Skip to content

Commit 398e246

Browse files
authored
feat: add Gemini support. (#177)
- Add Gemini support - Add `gemini` package - Add `gemini/func.go` file - Add `gemini/gemini.go` file - Add `gemini/options.go` file - Update `go.mod` file - Add `Float32Ptr` and `Int32Ptr` functions to `util/util.go` file Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent f208110 commit 398e246

7 files changed

Lines changed: 470 additions & 2 deletions

File tree

cmd/openai.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55

66
"github.com/appleboy/CodeGPT/core"
7+
"github.com/appleboy/CodeGPT/gemini"
78
"github.com/appleboy/CodeGPT/openai"
89

910
"github.com/spf13/viper"
@@ -30,11 +31,22 @@ func NewOpenAI() (*openai.Client, error) {
3031
)
3132
}
3233

34+
// NewGemini returns a new Gemini client
35+
func NewGemini() (*gemini.Client, error) {
36+
return gemini.New(
37+
gemini.WithToken(viper.GetString("openai.api_key")),
38+
gemini.WithModel(viper.GetString("openai.model")),
39+
gemini.WithMaxTokens(viper.GetInt("openai.max_tokens")),
40+
gemini.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
41+
gemini.WithTopP(float32(viper.GetFloat64("openai.top_p"))),
42+
)
43+
}
44+
3345
// GetClient returns the generative client based on the platform
3446
func GetClient(p core.Platform) (core.Generative, error) {
3547
switch p {
3648
case core.Gemini:
37-
// TODO: implement Gemini
49+
return NewGemini()
3850
case core.OpenAI, core.Azure:
3951
return NewOpenAI()
4052
}

gemini/func.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package gemini
2+
3+
import "github.com/google/generative-ai-go/genai"
4+
5+
var summaryPrefixFunc = &genai.Tool{
6+
FunctionDeclarations: []*genai.FunctionDeclaration{{
7+
Name: "get_summary_prefix",
8+
Description: "Get a summary prefix using function call",
9+
Parameters: &genai.Schema{
10+
Type: genai.TypeObject,
11+
Properties: map[string]*genai.Schema{
12+
"prefix": {
13+
Type: genai.TypeString,
14+
Description: "The prefix to use for the summary",
15+
Enum: []string{
16+
"build", "chore", "ci",
17+
"docs", "feat", "fix",
18+
"perf", "refactor", "style",
19+
"test",
20+
},
21+
},
22+
},
23+
Required: []string{"prefix"},
24+
},
25+
}},
26+
}

gemini/gemini.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package gemini
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/appleboy/CodeGPT/core"
9+
"github.com/appleboy/CodeGPT/util"
10+
11+
"github.com/google/generative-ai-go/genai"
12+
"google.golang.org/api/option"
13+
)
14+
15+
type Client struct {
16+
client *genai.GenerativeModel
17+
model string
18+
maxTokens int
19+
temperature float32
20+
topP float32
21+
debug bool
22+
}
23+
24+
// Completion is a method on the Client struct that takes a context.Context and a string argument
25+
func (c *Client) Completion(ctx context.Context, content string) (*core.Response, error) {
26+
resp, err := c.client.GenerateContent(ctx, genai.Text(content))
27+
if err != nil {
28+
return nil, err
29+
}
30+
31+
var ret string
32+
33+
for _, cand := range resp.Candidates {
34+
for _, part := range cand.Content.Parts {
35+
ret += fmt.Sprintf("%v", part)
36+
}
37+
}
38+
39+
return &core.Response{
40+
Content: ret,
41+
Usage: core.Usage{
42+
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
43+
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
44+
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
45+
},
46+
}, nil
47+
}
48+
49+
// GetSummaryPrefix is an API call to get a summary prefix using function call.
50+
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
51+
c.client.Tools = []*genai.Tool{summaryPrefixFunc}
52+
53+
// Start new chat session.
54+
session := c.client.StartChat()
55+
56+
// Send the message to the generative model.
57+
resp, err := session.SendMessage(ctx, genai.Text(content))
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
part := resp.Candidates[0].Content.Parts[0]
63+
64+
r := &core.Response{
65+
Content: strings.TrimSpace(strings.TrimSuffix(fmt.Sprintf("%v", part), "\n")),
66+
Usage: core.Usage{
67+
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
68+
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
69+
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
70+
},
71+
}
72+
73+
if c.debug {
74+
// Check that you got the expected function call back.
75+
funcall, ok := part.(genai.FunctionCall)
76+
if !ok {
77+
return nil, fmt.Errorf("expected type FunctionCall, got %T", part)
78+
}
79+
if g, e := funcall.Name, summaryPrefixFunc.FunctionDeclarations[0].Name; g != e {
80+
return nil, fmt.Errorf("expected FunctionCall.Name %q, got %q", e, g)
81+
}
82+
}
83+
84+
return r, nil
85+
}
86+
87+
func New(opts ...Option) (c *Client, err error) {
88+
// Create a new config object with the given options.
89+
cfg := newConfig(opts...)
90+
91+
// Validate the config object, returning an error if it is invalid.
92+
if err := cfg.valid(); err != nil {
93+
return nil, err
94+
}
95+
96+
// Create a new client instance with the necessary fields.
97+
engine := &Client{
98+
model: cfg.model,
99+
maxTokens: cfg.maxTokens,
100+
temperature: cfg.temperature,
101+
}
102+
103+
client, err := genai.NewClient(context.Background(), option.WithAPIKey(cfg.token))
104+
if err != nil {
105+
return nil, err
106+
}
107+
108+
engine.client = client.GenerativeModel(engine.model)
109+
engine.client.MaxOutputTokens = util.Int32Ptr(int32(engine.maxTokens))
110+
engine.client.Temperature = util.Float32Ptr(engine.temperature)
111+
engine.client.TopP = util.Float32Ptr(engine.topP)
112+
113+
return engine, nil
114+
}

gemini/options.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package gemini
2+
3+
import (
4+
"errors"
5+
)
6+
7+
var (
8+
errorsMissingToken = errors.New("missing gemini api key")
9+
errorsMissingModel = errors.New("missing model")
10+
)
11+
12+
const (
13+
defaultMaxTokens = 300
14+
defaultModel = "gemini-1.5-flash-latest"
15+
defaultTemperature = 1.0
16+
defaultTopP = 1.0
17+
)
18+
19+
// Option is an interface that specifies instrumentation configuration options.
20+
type Option interface {
21+
apply(*config)
22+
}
23+
24+
// optionFunc is a type of function that can be used to implement the Option interface.
25+
// It takes a pointer to a config struct and modifies it.
26+
type optionFunc func(*config)
27+
28+
// Ensure that optionFunc satisfies the Option interface.
29+
var _ Option = (*optionFunc)(nil)
30+
31+
// The apply method of optionFunc type is implemented here to modify the config struct based on the function passed.
32+
func (o optionFunc) apply(c *config) {
33+
o(c)
34+
}
35+
36+
// WithToken is a function that returns an Option, which sets the token field of the config struct.
37+
func WithToken(val string) Option {
38+
return optionFunc(func(c *config) {
39+
c.token = val
40+
})
41+
}
42+
43+
// WithModel is a function that returns an Option, which sets the model field of the config struct.
44+
func WithModel(val string) Option {
45+
return optionFunc(func(c *config) {
46+
c.model = val
47+
})
48+
}
49+
50+
// WithMaxTokens returns a new Option that sets the max tokens for the client configuration.
51+
// The maximum number of tokens to generate in the chat completion.
52+
// The total length of input tokens and generated tokens is limited by the model's context length.
53+
func WithMaxTokens(val int) Option {
54+
if val <= 0 {
55+
val = defaultMaxTokens
56+
}
57+
return optionFunc(func(c *config) {
58+
c.maxTokens = val
59+
})
60+
}
61+
62+
// WithTemperature returns a new Option that sets the temperature for the client configuration.
63+
// What sampling temperature to use, between 0 and 2.
64+
// Higher values like 0.8 will make the output more random,
65+
// while lower values like 0.2 will make it more focused and deterministic.
66+
func WithTemperature(val float32) Option {
67+
if val <= 0 {
68+
val = defaultTemperature
69+
}
70+
return optionFunc(func(c *config) {
71+
c.temperature = val
72+
})
73+
}
74+
75+
// WithTopP returns a new Option that sets the topP for the client configuration.
76+
func WithTopP(val float32) Option {
77+
return optionFunc(func(c *config) {
78+
c.topP = val
79+
})
80+
}
81+
82+
// config is a struct that stores configuration options for the instrumentation.
83+
type config struct {
84+
token string
85+
model string
86+
maxTokens int
87+
temperature float32
88+
topP float32
89+
}
90+
91+
// valid checks whether a config object is valid, returning an error if it is not.
92+
func (cfg *config) valid() error {
93+
// Check that the token is not empty.
94+
if cfg.token == "" {
95+
return errorsMissingToken
96+
}
97+
98+
if cfg.model == "" {
99+
return errorsMissingModel
100+
}
101+
102+
// If all checks pass, return nil (no error).
103+
return nil
104+
}
105+
106+
// newConfig creates a new config object with default values, and applies the given options.
107+
func newConfig(opts ...Option) *config {
108+
// Create a new config object with default values.
109+
c := &config{
110+
model: defaultModel,
111+
maxTokens: defaultMaxTokens,
112+
temperature: defaultTemperature,
113+
topP: defaultTopP,
114+
}
115+
116+
// Apply each of the given options to the config object.
117+
for _, opt := range opts {
118+
opt.apply(c)
119+
}
120+
121+
// Return the resulting config object.
122+
return c
123+
}

go.mod

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,40 @@
11
module github.com/appleboy/CodeGPT
22

3-
go 1.20
3+
go 1.21
4+
5+
toolchain go1.22.2
46

57
require (
68
github.com/appleboy/com v0.1.7
79
github.com/appleboy/graceful v1.1.1
810
github.com/fatih/color v1.17.0
11+
github.com/google/generative-ai-go v0.13.0
912
github.com/joho/godotenv v1.5.1
1013
github.com/rodaine/table v1.2.0
1114
github.com/sashabaranov/go-openai v1.24.0
1215
github.com/spf13/cobra v1.8.0
1316
github.com/spf13/viper v1.18.2
1417
golang.org/x/net v0.25.0
18+
google.golang.org/api v0.178.0
1519
)
1620

1721
require (
22+
cloud.google.com/go v0.113.0 // indirect
23+
cloud.google.com/go/ai v0.5.0 // indirect
24+
cloud.google.com/go/auth v0.4.0 // indirect
25+
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
26+
cloud.google.com/go/compute/metadata v0.3.0 // indirect
27+
cloud.google.com/go/longrunning v0.5.7 // indirect
28+
github.com/felixge/httpsnoop v1.0.4 // indirect
1829
github.com/fsnotify/fsnotify v1.7.0 // indirect
30+
github.com/go-logr/logr v1.4.1 // indirect
31+
github.com/go-logr/stdr v1.2.2 // indirect
32+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
33+
github.com/golang/protobuf v1.5.4 // indirect
34+
github.com/google/s2a-go v0.1.7 // indirect
35+
github.com/google/uuid v1.6.0 // indirect
36+
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
37+
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
1938
github.com/hashicorp/hcl v1.0.0 // indirect
2039
github.com/inconshreveable/mousetrap v1.1.0 // indirect
2140
github.com/magiconair/properties v1.8.7 // indirect
@@ -30,10 +49,25 @@ require (
3049
github.com/spf13/cast v1.6.0 // indirect
3150
github.com/spf13/pflag v1.0.5 // indirect
3251
github.com/subosito/gotenv v1.6.0 // indirect
52+
go.opencensus.io v0.24.0 // indirect
53+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
54+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
55+
go.opentelemetry.io/otel v1.26.0 // indirect
56+
go.opentelemetry.io/otel/metric v1.26.0 // indirect
57+
go.opentelemetry.io/otel/trace v1.26.0 // indirect
3358
go.uber.org/multierr v1.11.0 // indirect
59+
golang.org/x/crypto v0.23.0 // indirect
3460
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
61+
golang.org/x/oauth2 v0.20.0 // indirect
62+
golang.org/x/sync v0.7.0 // indirect
3563
golang.org/x/sys v0.20.0 // indirect
3664
golang.org/x/text v0.15.0 // indirect
65+
golang.org/x/time v0.5.0 // indirect
66+
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect
67+
google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae // indirect
68+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae // indirect
69+
google.golang.org/grpc v1.63.2 // indirect
70+
google.golang.org/protobuf v1.34.1 // indirect
3771
gopkg.in/ini.v1 v1.67.0 // indirect
3872
gopkg.in/yaml.v3 v3.0.1 // indirect
3973
)

0 commit comments

Comments
 (0)