Skip to content

Commit a7fd1da

Browse files
authored
refactor(ai): use official provider SDKs (#5845)
1 parent f394e94 commit a7fd1da

12 files changed

Lines changed: 466 additions & 494 deletions

File tree

go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/lib/pq v1.11.2
2020
github.com/lithammer/shortuuid/v4 v4.2.0
2121
github.com/mark3labs/mcp-go v0.45.0
22+
github.com/openai/openai-go/v3 v3.31.0
2223
github.com/pkg/errors v0.9.1
2324
github.com/spf13/cobra v1.10.2
2425
github.com/spf13/viper v1.21.0
@@ -32,6 +33,7 @@ require (
3233
golang.org/x/net v0.52.0
3334
golang.org/x/oauth2 v0.36.0
3435
golang.org/x/sync v0.20.0
36+
google.golang.org/genai v1.54.0
3537
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5
3638
google.golang.org/genproto/googleapis/api v0.0.0-20260316172706-e463d84ca32d
3739
google.golang.org/grpc v1.79.2
@@ -40,6 +42,9 @@ require (
4042

4143
require (
4244
cel.dev/expr v0.25.1 // indirect
45+
cloud.google.com/go v0.116.0 // indirect
46+
cloud.google.com/go/auth v0.9.3 // indirect
47+
cloud.google.com/go/compute/metadata v0.9.0 // indirect
4348
dario.cat/mergo v1.0.2 // indirect
4449
filippo.io/edwards25519 v1.1.0 // indirect
4550
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
@@ -66,6 +71,11 @@ require (
6671
github.com/go-logr/stdr v1.2.2 // indirect
6772
github.com/go-ole/go-ole v1.2.6 // indirect
6873
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
74+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
75+
github.com/google/go-cmp v0.7.0 // indirect
76+
github.com/google/s2a-go v0.1.8 // indirect
77+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
78+
github.com/gorilla/websocket v1.5.3 // indirect
6979
github.com/inconshreveable/mousetrap v1.1.0 // indirect
7080
github.com/invopop/jsonschema v0.13.0 // indirect
7181
github.com/klauspost/compress v1.18.2 // indirect
@@ -94,11 +104,16 @@ require (
94104
github.com/spf13/cast v1.10.0 // indirect
95105
github.com/spf13/pflag v1.0.10 // indirect
96106
github.com/subosito/gotenv v1.6.0 // indirect
107+
github.com/tidwall/gjson v1.18.0 // indirect
108+
github.com/tidwall/match v1.1.1 // indirect
109+
github.com/tidwall/pretty v1.2.1 // indirect
110+
github.com/tidwall/sjson v1.2.5 // indirect
97111
github.com/tklauser/go-sysconf v0.3.16 // indirect
98112
github.com/tklauser/numcpus v0.11.0 // indirect
99113
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
100114
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
101115
github.com/yusufpapurcu/wmi v1.2.4 // indirect
116+
go.opencensus.io v0.24.0 // indirect
102117
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
103118
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
104119
go.opentelemetry.io/otel v1.41.0 // indirect

go.sum

Lines changed: 111 additions & 0 deletions
Large diffs are not rendered by default.

internal/ai/client.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package ai
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"strings"
7+
"time"
8+
9+
"github.com/pkg/errors"
10+
)
11+
12+
const defaultHTTPTimeout = 2 * time.Minute
13+
14+
type transcriberOptions struct {
15+
httpClient *http.Client
16+
}
17+
18+
// TranscriberOption configures a transcriber.
19+
type TranscriberOption func(*transcriberOptions)
20+
21+
// WithHTTPClient sets the HTTP client used by a transcriber.
22+
func WithHTTPClient(client *http.Client) TranscriberOption {
23+
return func(options *transcriberOptions) {
24+
if client != nil {
25+
options.httpClient = client
26+
}
27+
}
28+
}
29+
30+
// NewTranscriber creates a transcriber for a provider.
31+
func NewTranscriber(config ProviderConfig, options ...TranscriberOption) (Transcriber, error) {
32+
transcriberOptions := transcriberOptions{
33+
httpClient: &http.Client{Timeout: defaultHTTPTimeout},
34+
}
35+
for _, applyOption := range options {
36+
applyOption(&transcriberOptions)
37+
}
38+
39+
switch config.Type {
40+
case ProviderOpenAI:
41+
return newOpenAITranscriber(config, transcriberOptions)
42+
case ProviderGemini:
43+
return newGeminiTranscriber(config, transcriberOptions)
44+
default:
45+
return nil, errors.Wrapf(ErrCapabilityUnsupported, "provider type %q", config.Type)
46+
}
47+
}
48+
49+
func normalizeEndpoint(endpoint string, defaultEndpoint string, providerName string) (string, error) {
50+
endpoint = strings.TrimSpace(endpoint)
51+
if endpoint == "" {
52+
endpoint = defaultEndpoint
53+
}
54+
if _, err := url.ParseRequestURI(endpoint); err != nil {
55+
return "", errors.Wrapf(err, "invalid %s endpoint", providerName)
56+
}
57+
return strings.TrimRight(endpoint, "/"), nil
58+
}
59+
60+
func requireAPIKey(apiKey string, providerName string) error {
61+
if apiKey == "" {
62+
return errors.Errorf("%s API key is required", providerName)
63+
}
64+
return nil
65+
}

internal/ai/gemini.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
"io"
6+
"mime"
7+
"net/url"
8+
"strings"
9+
10+
"github.com/pkg/errors"
11+
"google.golang.org/genai"
12+
)
13+
14+
const (
15+
defaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta"
16+
geminiTranscriptionPrompt = `Transcribe the audio accurately. Return only the transcript text. Do not summarize, explain, or add content that is not spoken.`
17+
maxGeminiInlineAudioSize = 14 * 1024 * 1024
18+
defaultGeminiAPIVersion = "v1beta"
19+
geminiProviderDisplayName = "Gemini"
20+
geminiDefaultTemperature = float32(0)
21+
)
22+
23+
var geminiSupportedContentTypes = map[string]string{
24+
"audio/wav": "audio/wav",
25+
"audio/x-wav": "audio/wav",
26+
"audio/mp3": "audio/mp3",
27+
"audio/mpeg": "audio/mp3",
28+
"audio/aiff": "audio/aiff",
29+
"audio/aac": "audio/aac",
30+
"audio/ogg": "audio/ogg",
31+
"audio/flac": "audio/flac",
32+
"audio/x-flac": "audio/flac",
33+
}
34+
35+
type geminiTranscriber struct {
36+
client *genai.Client
37+
}
38+
39+
func newGeminiTranscriber(config ProviderConfig, options transcriberOptions) (*geminiTranscriber, error) {
40+
endpoint, err := normalizeEndpoint(config.Endpoint, defaultGeminiEndpoint, geminiProviderDisplayName)
41+
if err != nil {
42+
return nil, err
43+
}
44+
if err := requireAPIKey(config.APIKey, geminiProviderDisplayName); err != nil {
45+
return nil, err
46+
}
47+
baseURL, apiVersion, err := normalizeGeminiEndpoint(endpoint)
48+
if err != nil {
49+
return nil, err
50+
}
51+
httpOptions := genai.HTTPOptions{
52+
BaseURL: baseURL,
53+
APIVersion: apiVersion,
54+
}
55+
if options.httpClient.Timeout > 0 {
56+
timeout := options.httpClient.Timeout
57+
httpOptions.Timeout = &timeout
58+
}
59+
60+
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
61+
APIKey: config.APIKey,
62+
Backend: genai.BackendGeminiAPI,
63+
HTTPClient: options.httpClient,
64+
HTTPOptions: httpOptions,
65+
})
66+
if err != nil {
67+
return nil, errors.Wrap(err, "failed to create Gemini client")
68+
}
69+
return &geminiTranscriber{client: client}, nil
70+
}
71+
72+
// Transcribe transcribes audio with Gemini generateContent.
73+
func (t *geminiTranscriber) Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error) {
74+
if strings.TrimSpace(request.Model) == "" {
75+
return nil, errors.New("model is required")
76+
}
77+
if request.Audio == nil {
78+
return nil, errors.New("audio is required")
79+
}
80+
audio, err := io.ReadAll(request.Audio)
81+
if err != nil {
82+
return nil, errors.Wrap(err, "failed to read audio")
83+
}
84+
if len(audio) == 0 {
85+
return nil, errors.New("audio is required")
86+
}
87+
if len(audio) > maxGeminiInlineAudioSize {
88+
return nil, errors.Errorf("audio is too large for Gemini inline transcription; maximum size is %d bytes", maxGeminiInlineAudioSize)
89+
}
90+
91+
contentType, err := normalizeGeminiContentType(request.ContentType)
92+
if err != nil {
93+
return nil, err
94+
}
95+
prompt := buildGeminiTranscriptionPrompt(request.Prompt, request.Language)
96+
temperature := geminiDefaultTemperature
97+
response, err := t.client.Models.GenerateContent(ctx, normalizeGeminiModelName(request.Model), []*genai.Content{
98+
genai.NewContentFromParts([]*genai.Part{
99+
genai.NewPartFromBytes(audio, contentType),
100+
genai.NewPartFromText(prompt),
101+
}, genai.RoleUser),
102+
}, &genai.GenerateContentConfig{
103+
Temperature: &temperature,
104+
})
105+
if err != nil {
106+
return nil, errors.Wrap(err, "failed to send Gemini transcription request")
107+
}
108+
text := strings.TrimSpace(response.Text())
109+
if text == "" {
110+
return nil, errors.New("Gemini transcription response did not include text")
111+
}
112+
return &TranscribeResponse{
113+
Text: text,
114+
}, nil
115+
}
116+
117+
func normalizeGeminiEndpoint(endpoint string) (string, string, error) {
118+
parsed, err := url.Parse(endpoint)
119+
if err != nil {
120+
return "", "", errors.Wrap(err, "invalid Gemini endpoint")
121+
}
122+
path := strings.TrimRight(parsed.Path, "/")
123+
apiVersion := defaultGeminiAPIVersion
124+
for _, supportedVersion := range []string{"v1alpha", "v1beta", "v1"} {
125+
if path == "/"+supportedVersion || strings.HasSuffix(path, "/"+supportedVersion) {
126+
apiVersion = supportedVersion
127+
parsed.Path = strings.TrimSuffix(path, "/"+supportedVersion)
128+
break
129+
}
130+
}
131+
return strings.TrimRight(parsed.String(), "/"), apiVersion, nil
132+
}
133+
134+
func normalizeGeminiContentType(contentType string) (string, error) {
135+
mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType))
136+
if err != nil {
137+
return "", errors.Wrap(err, "invalid audio content type")
138+
}
139+
mediaType = strings.ToLower(mediaType)
140+
normalized, ok := geminiSupportedContentTypes[mediaType]
141+
if !ok {
142+
return "", errors.Errorf("audio content type %q is not supported by Gemini", mediaType)
143+
}
144+
return normalized, nil
145+
}
146+
147+
func buildGeminiTranscriptionPrompt(prompt string, language string) string {
148+
parts := []string{geminiTranscriptionPrompt}
149+
language = strings.TrimSpace(language)
150+
if language != "" {
151+
parts = append(parts, "The input language is "+language+".")
152+
}
153+
prompt = strings.TrimSpace(prompt)
154+
if prompt != "" {
155+
parts = append(parts, "Context and spelling hints:\n"+prompt)
156+
}
157+
return strings.Join(parts, "\n\n")
158+
}
159+
160+
func normalizeGeminiModelName(model string) string {
161+
return strings.TrimPrefix(strings.TrimSpace(model), "models/")
162+
}

internal/ai/gemini/client.go

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)