Skip to content

Commit e857292

Browse files
madhav-dbclaude
andcommitted
Add token provider infrastructure
This introduces a flexible TokenProvider interface that allows custom authentication implementations: - TokenProvider interface with static, external function support - Token struct with expiration handling - Authenticator wrapper for integration with existing auth system - Connector functions: WithTokenProvider, WithExternalToken, WithStaticToken This foundation enables custom token management strategies without requiring changes to the core driver. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent d0ebd21 commit e857292

9 files changed

Lines changed: 781 additions & 0 deletions

File tree

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
8+
"github.com/databricks/databricks-sql-go/auth"
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider
13+
type TokenProviderAuthenticator struct {
14+
provider TokenProvider
15+
}
16+
17+
// NewAuthenticator creates an authenticator from a token provider
18+
func NewAuthenticator(provider TokenProvider) auth.Authenticator {
19+
return &TokenProviderAuthenticator{
20+
provider: provider,
21+
}
22+
}
23+
24+
// Authenticate implements auth.Authenticator
25+
func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error {
26+
ctx := r.Context()
27+
if ctx == nil {
28+
ctx = context.Background()
29+
}
30+
31+
token, err := a.provider.GetToken(ctx)
32+
if err != nil {
33+
return fmt.Errorf("token provider authenticator: failed to get token: %w", err)
34+
}
35+
36+
if token.AccessToken == "" {
37+
return fmt.Errorf("token provider authenticator: empty access token")
38+
}
39+
40+
token.SetAuthHeader(r)
41+
log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name())
42+
43+
return nil
44+
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestTokenProviderAuthenticator(t *testing.T) {
14+
t.Run("successful_authentication", func(t *testing.T) {
15+
provider := NewStaticTokenProvider("test-token-123")
16+
authenticator := NewAuthenticator(provider)
17+
18+
req, _ := http.NewRequest("GET", "http://example.com", nil)
19+
err := authenticator.Authenticate(req)
20+
21+
require.NoError(t, err)
22+
assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization"))
23+
})
24+
25+
t.Run("authentication_with_custom_token_type", func(t *testing.T) {
26+
provider := NewStaticTokenProviderWithType("test-token", "MAC")
27+
authenticator := NewAuthenticator(provider)
28+
29+
req, _ := http.NewRequest("GET", "http://example.com", nil)
30+
err := authenticator.Authenticate(req)
31+
32+
require.NoError(t, err)
33+
assert.Equal(t, "MAC test-token", req.Header.Get("Authorization"))
34+
})
35+
36+
t.Run("authentication_error_propagation", func(t *testing.T) {
37+
provider := &mockProvider{
38+
tokenFunc: func() (*Token, error) {
39+
return nil, errors.New("provider failed")
40+
},
41+
}
42+
authenticator := NewAuthenticator(provider)
43+
44+
req, _ := http.NewRequest("GET", "http://example.com", nil)
45+
err := authenticator.Authenticate(req)
46+
47+
assert.Error(t, err)
48+
assert.Contains(t, err.Error(), "provider failed")
49+
assert.Empty(t, req.Header.Get("Authorization"))
50+
})
51+
52+
t.Run("empty_token_error", func(t *testing.T) {
53+
provider := &mockProvider{
54+
tokenFunc: func() (*Token, error) {
55+
return &Token{
56+
AccessToken: "",
57+
TokenType: "Bearer",
58+
}, nil
59+
},
60+
}
61+
authenticator := NewAuthenticator(provider)
62+
63+
req, _ := http.NewRequest("GET", "http://example.com", nil)
64+
err := authenticator.Authenticate(req)
65+
66+
assert.Error(t, err)
67+
assert.Contains(t, err.Error(), "empty access token")
68+
assert.Empty(t, req.Header.Get("Authorization"))
69+
})
70+
71+
t.Run("uses_request_context", func(t *testing.T) {
72+
ctx, cancel := context.WithCancel(context.Background())
73+
cancel() // Cancel immediately
74+
75+
provider := &mockProvider{
76+
tokenFunc: func() (*Token, error) {
77+
// This would normally check context cancellation
78+
return &Token{
79+
AccessToken: "test-token",
80+
TokenType: "Bearer",
81+
}, nil
82+
},
83+
}
84+
authenticator := NewAuthenticator(provider)
85+
86+
req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)
87+
err := authenticator.Authenticate(req)
88+
89+
// Even with cancelled context, this should work as our mock doesn't check it
90+
require.NoError(t, err)
91+
assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
92+
})
93+
94+
t.Run("external_token_integration", func(t *testing.T) {
95+
tokenFunc := func() (string, error) {
96+
return "external-token-456", nil
97+
}
98+
provider := NewExternalTokenProvider(tokenFunc)
99+
authenticator := NewAuthenticator(provider)
100+
101+
req, _ := http.NewRequest("POST", "http://example.com/api", nil)
102+
err := authenticator.Authenticate(req)
103+
104+
require.NoError(t, err)
105+
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
106+
})
107+
108+
t.Run("cached_provider_integration", func(t *testing.T) {
109+
callCount := 0
110+
baseProvider := &mockProvider{
111+
tokenFunc: func() (*Token, error) {
112+
callCount++
113+
return &Token{
114+
AccessToken: "cached-token",
115+
TokenType: "Bearer",
116+
}, nil
117+
},
118+
name: "test",
119+
}
120+
121+
cachedProvider := NewCachedTokenProvider(baseProvider)
122+
authenticator := NewAuthenticator(cachedProvider)
123+
124+
// Multiple authentication attempts
125+
for i := 0; i < 3; i++ {
126+
req, _ := http.NewRequest("GET", "http://example.com", nil)
127+
err := authenticator.Authenticate(req)
128+
require.NoError(t, err)
129+
assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization"))
130+
}
131+
132+
// Should only call base provider once due to caching
133+
assert.Equal(t, 1, callCount)
134+
})
135+
}

auth/tokenprovider/external.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
)
8+
9+
// ExternalTokenProvider provides tokens from an external source (passthrough)
10+
type ExternalTokenProvider struct {
11+
tokenFunc func() (string, error)
12+
tokenType string
13+
}
14+
15+
// NewExternalTokenProvider creates a provider that gets tokens from an external function
16+
func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider {
17+
return &ExternalTokenProvider{
18+
tokenFunc: tokenFunc,
19+
tokenType: "Bearer",
20+
}
21+
}
22+
23+
// NewExternalTokenProviderWithType creates a provider with a custom token type
24+
func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider {
25+
return &ExternalTokenProvider{
26+
tokenFunc: tokenFunc,
27+
tokenType: tokenType,
28+
}
29+
}
30+
31+
// GetToken retrieves the token from the external source
32+
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
33+
if p.tokenFunc == nil {
34+
return nil, fmt.Errorf("external token provider: token function is nil")
35+
}
36+
37+
accessToken, err := p.tokenFunc()
38+
if err != nil {
39+
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
40+
}
41+
42+
if accessToken == "" {
43+
return nil, fmt.Errorf("external token provider: empty token returned")
44+
}
45+
46+
return &Token{
47+
AccessToken: accessToken,
48+
TokenType: p.tokenType,
49+
ExpiresAt: time.Time{}, // External tokens don't provide expiry info
50+
}, nil
51+
}
52+
53+
// Name returns the provider name
54+
func (p *ExternalTokenProvider) Name() string {
55+
return "external"
56+
}

auth/tokenprovider/provider.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"time"
7+
)
8+
9+
// TokenProvider is the interface for providing tokens from various sources
10+
type TokenProvider interface {
11+
// GetToken retrieves a valid access token
12+
GetToken(ctx context.Context) (*Token, error)
13+
14+
// Name returns the provider name for logging/debugging
15+
Name() string
16+
}
17+
18+
// Token represents an access token with metadata
19+
type Token struct {
20+
AccessToken string
21+
TokenType string
22+
ExpiresAt time.Time
23+
RefreshToken string
24+
Scopes []string
25+
}
26+
27+
// IsExpired checks if the token has expired
28+
func (t *Token) IsExpired() bool {
29+
if t.ExpiresAt.IsZero() {
30+
return false // No expiry means token doesn't expire
31+
}
32+
// Consider token expired 5 minutes before actual expiry for safety
33+
return time.Now().Add(5 * time.Minute).After(t.ExpiresAt)
34+
}
35+
36+
// SetAuthHeader sets the Authorization header on an HTTP request
37+
func (t *Token) SetAuthHeader(r *http.Request) {
38+
tokenType := t.TokenType
39+
if tokenType == "" {
40+
tokenType = "Bearer"
41+
}
42+
r.Header.Set("Authorization", tokenType+" "+t.AccessToken)
43+
}

0 commit comments

Comments
 (0)