|
1 | 1 | package cmd |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "errors" |
5 | | - "strings" |
6 | | - "time" |
7 | | - |
8 | | - "github.com/appleboy/com/array" |
9 | | - "github.com/fatih/color" |
10 | 4 | "github.com/spf13/cobra" |
11 | | - "github.com/spf13/viper" |
12 | 5 | ) |
13 | 6 |
|
14 | | -var availableKeys = []string{ |
15 | | - "git.diff_unified", |
16 | | - "git.exclude_list", |
17 | | - "git.template_file", |
18 | | - "git.template_string", |
19 | | - "openai.socks", |
20 | | - "openai.api_key", |
21 | | - "openai.model", |
22 | | - "openai.org_id", |
23 | | - "openai.proxy", |
24 | | - "output.lang", |
25 | | - "openai.base_url", |
26 | | - "openai.timeout", |
27 | | - "openai.max_tokens", |
28 | | - "openai.temperature", |
29 | | - "openai.provider", |
30 | | - "openai.model_name", |
31 | | - "openai.skip_verify", |
32 | | - "openai.headers", |
33 | | - "openai.api_version", |
34 | | - "openai.top_p", |
35 | | - "openai.frequency_penalty", |
36 | | - "openai.presence_penalty", |
37 | | -} |
38 | | - |
39 | | -func init() { |
40 | | - configCmd.PersistentFlags().StringP("base_url", "b", "", "what API base url to use.") |
41 | | - configCmd.PersistentFlags().StringP("api_key", "k", "", "openai api key") |
42 | | - configCmd.PersistentFlags().StringP("model", "m", "gpt-3.5-turbo", "openai model") |
43 | | - configCmd.PersistentFlags().StringP("lang", "l", "en", "summarizing language uses English by default") |
44 | | - configCmd.PersistentFlags().StringP("org_id", "o", "", "openai requesting organization") |
45 | | - configCmd.PersistentFlags().StringP("proxy", "", "", "http proxy") |
46 | | - configCmd.PersistentFlags().StringP("socks", "", "", "socks proxy") |
47 | | - configCmd.PersistentFlags().DurationP("timeout", "t", 10*time.Second, "http timeout") |
48 | | - configCmd.PersistentFlags().StringP("template_file", "", "", "git commit message file") |
49 | | - configCmd.PersistentFlags().StringP("template_string", "", "", "git commit message string") |
50 | | - configCmd.PersistentFlags().IntP("diff_unified", "", 3, "generate diffs with <n> lines of context, default is 3") |
51 | | - configCmd.PersistentFlags().IntP("max_tokens", "", 300, "the maximum number of tokens to generate in the chat completion.") |
52 | | - configCmd.PersistentFlags().Float32P("temperature", "", 1.0, "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.") |
53 | | - configCmd.PersistentFlags().Float32P("top_p", "", 1.0, "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.") |
54 | | - configCmd.PersistentFlags().Float32P("frequency_penalty", "", 0.0, "Number between 0.0 and 1.0 that penalizes new tokens based on their existing frequency in the text so far. Decreases the model's likelihood to repeat the same line verbatim.") |
55 | | - configCmd.PersistentFlags().Float32P("presence_penalty", "", 0.0, "Number between 0.0 and 1.0 that penalizes new tokens based on whether they appear in the text so far. Increases the model's likelihood to talk about new topics.") |
56 | | - configCmd.PersistentFlags().StringP("exclude_list", "", "", "exclude file from `git diff` command") |
57 | | - |
58 | | - configCmd.PersistentFlags().StringP("provider", "", "openai", "service provider, only support 'openai' or 'azure'") |
59 | | - configCmd.PersistentFlags().StringP("model_name", "", "", "model deployment name for Azure cognitive service") |
60 | | - configCmd.PersistentFlags().BoolP("skip_verify", "", false, "skip verify TLS certificate") |
61 | | - configCmd.PersistentFlags().StringP("headers", "", "", "custom headers for openai request") |
62 | | - configCmd.PersistentFlags().StringP("api_version", "", "", "openai api version") |
63 | | - |
64 | | - _ = viper.BindPFlag("openai.base_url", configCmd.PersistentFlags().Lookup("base_url")) |
65 | | - _ = viper.BindPFlag("openai.org_id", configCmd.PersistentFlags().Lookup("org_id")) |
66 | | - _ = viper.BindPFlag("openai.api_key", configCmd.PersistentFlags().Lookup("api_key")) |
67 | | - _ = viper.BindPFlag("openai.model", configCmd.PersistentFlags().Lookup("model")) |
68 | | - _ = viper.BindPFlag("openai.proxy", configCmd.PersistentFlags().Lookup("proxy")) |
69 | | - _ = viper.BindPFlag("openai.socks", configCmd.PersistentFlags().Lookup("socks")) |
70 | | - _ = viper.BindPFlag("openai.timeout", configCmd.PersistentFlags().Lookup("timeout")) |
71 | | - _ = viper.BindPFlag("openai.max_tokens", configCmd.PersistentFlags().Lookup("max_tokens")) |
72 | | - _ = viper.BindPFlag("openai.temperature", configCmd.PersistentFlags().Lookup("temperature")) |
73 | | - _ = viper.BindPFlag("openai.top_p", configCmd.PersistentFlags().Lookup("top_p")) |
74 | | - _ = viper.BindPFlag("openai.frequency_penalty", configCmd.PersistentFlags().Lookup("frequency_penalty")) |
75 | | - _ = viper.BindPFlag("openai.presence_penalty", configCmd.PersistentFlags().Lookup("presence_penalty")) |
76 | | - _ = viper.BindPFlag("output.lang", configCmd.PersistentFlags().Lookup("lang")) |
77 | | - _ = viper.BindPFlag("git.diff_unified", configCmd.PersistentFlags().Lookup("diff_unified")) |
78 | | - _ = viper.BindPFlag("git.exclude_list", configCmd.PersistentFlags().Lookup("exclude_list")) |
79 | | - _ = viper.BindPFlag("git.template_file", configCmd.PersistentFlags().Lookup("template_file")) |
80 | | - _ = viper.BindPFlag("git.template_string", configCmd.PersistentFlags().Lookup("template_string")) |
81 | | - |
82 | | - _ = viper.BindPFlag("openai.provider", configCmd.PersistentFlags().Lookup("provider")) |
83 | | - _ = viper.BindPFlag("openai.model_name", configCmd.PersistentFlags().Lookup("model_name")) |
84 | | - _ = viper.BindPFlag("openai.skip_verify", configCmd.PersistentFlags().Lookup("skip_verify")) |
85 | | - _ = viper.BindPFlag("openai.headers", configCmd.PersistentFlags().Lookup("headers")) |
86 | | - _ = viper.BindPFlag("openai.api_version", configCmd.PersistentFlags().Lookup("api_version")) |
87 | | -} |
88 | | - |
89 | 7 | var configCmd = &cobra.Command{ |
90 | 8 | Use: "config", |
91 | | - Short: "Add openai config (openai.api_key, openai.model ...)", |
92 | | - Args: cobra.MinimumNArgs(3), |
93 | | - RunE: func(cmd *cobra.Command, args []string) error { |
94 | | - // Check if command is 'set' |
95 | | - if args[0] != "set" { |
96 | | - return errors.New("config set key value. ex: config set openai.api_key sk-...") |
97 | | - } |
98 | | - |
99 | | - // Check if key is available |
100 | | - if !array.InSlice(args[1], availableKeys) { |
101 | | - return errors.New("available key list: " + strings.Join(availableKeys, ", ")) |
102 | | - } |
103 | | - |
104 | | - // Set config value in viper |
105 | | - if args[1] == "git.exclude_list" { |
106 | | - viper.Set(args[1], strings.Split(args[2], ",")) |
107 | | - } else { |
108 | | - viper.Set(args[1], args[2]) |
109 | | - } |
110 | | - |
111 | | - // Write config to file |
112 | | - if err := viper.WriteConfig(); err != nil { |
113 | | - return err |
114 | | - } |
115 | | - |
116 | | - // Print success message with config file location |
117 | | - color.Green("you can see the config file: %s", viper.ConfigFileUsed()) |
118 | | - return nil |
119 | | - }, |
| 9 | + Short: "custom config (openai.api_key, openai.model ...)", |
120 | 10 | } |
0 commit comments