Skip to content

Commit f4a0d7a

Browse files
authored
refactor(model): remove model_name parameter (#165)
- Remove custom model checking logic from `commit.go` and `review.go` - Simplify model validation logic in `hepler.go` - Delete the `groq/model.go` file and its references - Remove the model mapping logic from `openai.go` - Simplify model assignment in `openai.go` - Remove the `OPENROUTER` provider and related custom model logic from `options.go` - Remove the `modelName` field and related validation logic from `options.go` - Adjust tests in `options_test.go` to reflect removal of custom model logic and update expected error to `errorsMissingModel` Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
1 parent 7b8f448 commit f4a0d7a

7 files changed

Lines changed: 12 additions & 128 deletions

File tree

cmd/commit.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ var commitCmd = &cobra.Command{
8585
}
8686

8787
currentModel := viper.GetString("openai.model")
88-
if openai.Provider(viper.GetString("openai.provider")).IsCustomModel() {
89-
currentModel = viper.GetString("openai.model_name")
90-
}
91-
9288
color.Green("Summarize the commit message use " + currentModel + " model")
9389
client, err := openai.New(
9490
openai.WithToken(viper.GetString("openai.api_key")),
@@ -101,7 +97,6 @@ var commitCmd = &cobra.Command{
10197
openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
10298
openai.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
10399
openai.WithProvider(viper.GetString("openai.provider")),
104-
openai.WithModelName(viper.GetString("openai.model_name")),
105100
openai.WithSkipVerify(viper.GetBool("openai.skip_verify")),
106101
openai.WithHeaders(viper.GetStringSlice("openai.headers")),
107102
openai.WithApiVersion(viper.GetString("openai.api_version")),

cmd/hepler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func check() error {
3131
viper.Set("output.lang", commitLang)
3232
}
3333

34-
if openai.GetModel(commitModel) != openai.DefaultModel {
34+
if commitModel != openai.DefaultModel {
3535
viper.Set("openai.model", commitModel)
3636
}
3737

cmd/review.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ var reviewCmd = &cobra.Command{
5252
}
5353

5454
currentModel := viper.GetString("openai.model")
55-
if openai.Provider(viper.GetString("openai.provider")).IsCustomModel() {
56-
currentModel = viper.GetString("openai.model_name")
57-
}
58-
5955
color.Green("Code review your changes using " + currentModel + " model")
6056
client, err := openai.New(
6157
openai.WithToken(viper.GetString("openai.api_key")),
@@ -68,7 +64,6 @@ var reviewCmd = &cobra.Command{
6864
openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
6965
openai.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
7066
openai.WithProvider(viper.GetString("openai.provider")),
71-
openai.WithModelName(viper.GetString("openai.model_name")),
7267
openai.WithSkipVerify(viper.GetBool("openai.skip_verify")),
7368
openai.WithHeaders(viper.GetStringSlice("openai.headers")),
7469
openai.WithApiVersion(viper.GetString("openai.api_version")),

groq/model.go

Lines changed: 0 additions & 23 deletions
This file was deleted.

openai/openai.go

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,61 +7,13 @@ import (
77
"net/http"
88
"net/url"
99

10-
"github.com/appleboy/CodeGPT/groq"
11-
1210
openai "github.com/sashabaranov/go-openai"
1311
"golang.org/x/net/proxy"
1412
)
1513

1614
// DefaultModel is the default OpenAI model to use if one is not provided.
1715
var DefaultModel = openai.GPT3Dot5Turbo
1816

19-
// modelMaps maps model names to their corresponding model ID strings.
20-
var modelMaps = map[string]string{
21-
"gpt-4-32k-0613": openai.GPT432K0613,
22-
"gpt-4-32k-0314": openai.GPT432K0314,
23-
"gpt-4-32k": openai.GPT432K,
24-
"gpt-4-0613": openai.GPT40613,
25-
"gpt-4-0314": openai.GPT40314,
26-
"gpt-4-turbo": openai.GPT4Turbo,
27-
"gpt-4-turbo-2024-04-09": openai.GPT4Turbo20240409,
28-
"gpt-4-0125-preview": openai.GPT4Turbo0125,
29-
"gpt-4-1106-preview": openai.GPT4Turbo1106,
30-
"gpt-4-turbo-preview": openai.GPT4TurboPreview,
31-
"gpt-4-vision-preview": openai.GPT4VisionPreview,
32-
"gpt-4": openai.GPT4,
33-
"gpt-3.5-turbo-0125": openai.GPT3Dot5Turbo0125,
34-
"gpt-3.5-turbo-1106": openai.GPT3Dot5Turbo1106,
35-
"gpt-3.5-turbo-0613": openai.GPT3Dot5Turbo0613,
36-
"gpt-3.5-turbo-0301": openai.GPT3Dot5Turbo0301,
37-
"gpt-3.5-turbo-16k": openai.GPT3Dot5Turbo16K,
38-
"gpt-3.5-turbo-16k-0613": openai.GPT3Dot5Turbo16K0613,
39-
"gpt-3.5-turbo": openai.GPT3Dot5Turbo,
40-
"gpt-3.5-turbo-instruct": openai.GPT3Dot5TurboInstruct,
41-
"davinci": openai.GPT3Davinci,
42-
"davinci-002": openai.GPT3Davinci002,
43-
"curie": openai.GPT3Curie,
44-
"curie-002": openai.GPT3Curie002,
45-
"ada": openai.GPT3Ada,
46-
"ada-002": openai.GPT3Ada002,
47-
"babbage": openai.GPT3Babbage,
48-
"babbage-002": openai.GPT3Babbage002,
49-
groq.LLaMA38b.String(): groq.LLaMA38b.String(),
50-
groq.LLaMA370b.String(): groq.LLaMA370b.String(),
51-
groq.Mixtral8x7b.String(): groq.Mixtral8x7b.String(),
52-
groq.Gemma7b.String(): groq.Gemma7b.String(),
53-
}
54-
55-
// GetModel returns the model ID corresponding to the given model name.
56-
// If the model name is not recognized, it returns the default model ID.
57-
func GetModel(model string) string {
58-
v, ok := modelMaps[model]
59-
if !ok {
60-
return DefaultModel
61-
}
62-
return v
63-
}
64-
6517
// Client is a struct that represents an OpenAI client.
6618
type Client struct {
6719
client *openai.Client
@@ -183,7 +135,7 @@ func New(opts ...Option) (*Client, error) {
183135

184136
// Create a new client instance with the necessary fields.
185137
engine := &Client{
186-
model: modelMaps[cfg.model],
138+
model: cfg.model,
187139
maxTokens: cfg.maxTokens,
188140
temperature: cfg.temperature,
189141
}
@@ -229,7 +181,7 @@ func New(opts ...Option) (*Client, error) {
229181
if cfg.provider == AZURE {
230182
defaultAzureConfig := openai.DefaultAzureConfig(cfg.token, cfg.baseURL)
231183
defaultAzureConfig.AzureModelMapperFunc = func(model string) string {
232-
return cfg.modelName
184+
return cfg.model
233185
}
234186
// Set the API version to the one with the specified options.
235187
if cfg.apiVersion != "" {
@@ -247,9 +199,6 @@ func New(opts ...Option) (*Client, error) {
247199
c.APIVersion = cfg.apiVersion
248200
}
249201

250-
if cfg.provider.IsCustomModel() {
251-
engine.model = cfg.modelName
252-
}
253202
engine.client = openai.NewClientWithConfig(c)
254203
}
255204

openai/options.go

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ import (
88
)
99

1010
var (
11-
errorsMissingToken = errors.New("please set OPENAI_API_KEY environment variable")
12-
errorsMissingModel = errors.New("missing model")
13-
errorsMissingCustomModel = errors.New("missing custom model name")
11+
errorsMissingToken = errors.New("please set OPENAI_API_KEY environment variable")
12+
errorsMissingModel = errors.New("missing model")
1413
)
1514

1615
type Provider string
@@ -21,21 +20,16 @@ func (p Provider) String() string {
2120

2221
func (p Provider) IsValid() bool {
2322
switch p {
24-
case OPENAI, AZURE, OPENROUTER:
23+
case OPENAI, AZURE:
2524
return true
2625
default:
2726
return false
2827
}
2928
}
3029

31-
func (p Provider) IsCustomModel() bool {
32-
return p != OPENAI
33-
}
34-
3530
var (
36-
OPENAI Provider = "openai"
37-
AZURE Provider = "azure"
38-
OPENROUTER Provider = "openrouter"
31+
OPENAI Provider = "openai"
32+
AZURE Provider = "azure"
3933
)
4034

4135
const (
@@ -152,15 +146,6 @@ func WithProvider(val string) Option {
152146
})
153147
}
154148

155-
// WithModelName sets the `modelName` variable to the provided `val` parameter.
156-
// This function returns an `Option` object.
157-
func WithModelName(val string) Option {
158-
// Return an `optionFunc` object with `c.modelName` set to `val`.
159-
return optionFunc(func(c *config) {
160-
c.modelName = val
161-
})
162-
}
163-
164149
// WithSkipVerify returns a new Option that sets the skipVerify for the client configuration.
165150
func WithSkipVerify(val bool) Option {
166151
return optionFunc(func(c *config) {
@@ -220,7 +205,6 @@ type config struct {
220205
frequencyPenalty float32
221206

222207
provider Provider
223-
modelName string
224208
skipVerify bool
225209
headers []string
226210
apiVersion string
@@ -233,17 +217,10 @@ func (cfg *config) valid() error {
233217
return errorsMissingToken
234218
}
235219

236-
// Check that the model exists in the model maps.
237-
modelExists := modelMaps[cfg.model] != ""
238-
if !modelExists {
220+
if cfg.model == "" {
239221
return errorsMissingModel
240222
}
241223

242-
// If the provider is Azure, check that the model name is not empty.
243-
if cfg.provider.IsCustomModel() && cfg.modelName == "" {
244-
return errorsMissingCustomModel
245-
}
246-
247224
// If all checks pass, return nil (no error).
248225
return nil
249226
}

openai/options_test.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func Test_config_valid(t *testing.T) {
3030
name: "missing model",
3131
cfg: newConfig(
3232
WithToken("test"),
33-
WithModel("test"),
33+
WithModel(""),
3434
WithProvider(OPENAI.String()),
3535
),
3636
wantErr: errorsMissingModel,
@@ -39,19 +39,10 @@ func Test_config_valid(t *testing.T) {
3939
name: "missing Azure deployment model",
4040
cfg: newConfig(
4141
WithToken("test"),
42-
WithModel(openai.GPT3Dot5Turbo),
42+
WithModel(""),
4343
WithProvider(AZURE.String()),
4444
),
45-
wantErr: errorsMissingCustomModel,
46-
},
47-
{
48-
name: "missing OpenRouter Custom model",
49-
cfg: newConfig(
50-
WithToken("test"),
51-
WithModel(openai.GPT3Dot5Turbo),
52-
WithProvider(OPENROUTER.String()),
53-
),
54-
wantErr: errorsMissingCustomModel,
45+
wantErr: errorsMissingModel,
5546
},
5647
}
5748
for _, tt := range tests {

0 commit comments

Comments
 (0)