-
Notifications
You must be signed in to change notification settings - Fork 140
449 lines (374 loc) · 21.3 KB
/
token-federation-test.yml
File metadata and controls
449 lines (374 loc) · 21.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
name: Token Federation Test
# Exercises token federation with GitHub Actions OIDC tokens in the
# databricks-sql-python connector, so that CI/CD authentication keeps working.
on:
  # Manual trigger with required inputs
  workflow_dispatch:
    inputs:
      databricks_host:
        description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
        required: true
      databricks_http_path:
        description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
        required: true
      identity_federation_client_id:
        description: 'Identity federation client ID'
        required: true

  # Automatically run on PRs that change token federation files (or this
  # workflow). Without a paths filter this would fire for every PR to main.
  pull_request:
    branches:
      - main
    paths:
      - 'src/databricks/sql/auth/token_federation.py'
      - 'src/databricks/sql/auth/auth.py'
      - 'examples/token_federation_*.py'
      - '.github/workflows/token-federation-test.yml'

  # Run on pushes to main that affect token federation
  push:
    branches:
      - main
    paths:
      - 'src/databricks/sql/auth/token_federation.py'
      - 'src/databricks/sql/auth/auth.py'
      - 'examples/token_federation_*.py'
      - '.github/workflows/token-federation-test.yml'

permissions:
  # Required for GitHub OIDC token
  id-token: write
  contents: read
jobs:
  test-token-federation:
    runs-on:
      group: databricks-protected-runner-group
      labels: linux-ubuntu-latest
    steps:
      # NOTE(review): this third-party action tracks a moving branch while the
      # job holds `id-token: write`; consider pinning it to a full commit SHA.
      - name: Debug OIDC Claims
        uses: github/actions-oidc-debugger@main
        with:
          audience: '${{ github.server_url }}/${{ github.repository_owner }}'

      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install pyarrow

      # The patch scripts below inject print statements into the (editable)
      # connector source for this CI run only. Each snippet is re-indented to
      # match the line it is anchored to; a missing anchor is reported rather
      # than silently ignored.
      - name: Create debugging patch script
        run: |
          cat > patch_for_debugging.py << 'EOF'
          #!/usr/bin/env python3
          """Inject debug prints into token_federation.py for this CI run."""

          TARGET = 'src/databricks/sql/auth/token_federation.py'
          EXCHANGE_DEF = 'def _exchange_token(self, token, force_refresh=False):'

          # Inserted as the first statements of _exchange_token
          TOKEN_DEBUG = '''\
          # Debug token info
          import jwt
          try:
              decoded = jwt.decode(token, options={"verify_signature": False})
              print(f"Token issuer: {decoded.get('iss')}")
              print(f"Token subject: {decoded.get('sub')}")
              print(f"Token audience: {decoded.get('aud') if isinstance(decoded.get('aud'), str) else decoded.get('aud', [])[0] if decoded.get('aud') else ''}")
          except Exception as e:
              print(f"Unable to decode token: {str(e)}")
          '''

          # Inserted just before the token exchange request is made
          REQUEST_DEBUG = '''\
          import urllib.parse
          # Debug full request
          print(f"Connecting to Databricks at {self.host}")
          print(f"Token endpoint: {self.token_endpoint}")
          print(f"Request parameters: {urllib.parse.urlencode(params)}")
          print(f"Request headers: {headers}")
          '''

          # Inserted right after the token exchange POST
          RESPONSE_DEBUG = '''\
          print(f"Response status: {response.status_code}")
          print(f"Response headers: {dict(response.headers)}")
          print(f"Response body: {response.text}")
          '''

          # Inserted at the top of the RequestException handler
          ERROR_DEBUG = '''\
          print(f"Failed to perform token exchange: {str(e)}")
          # A requests.Response is falsy for 4xx/5xx, so compare against None
          if getattr(e, "response", None) is not None:
              print(f"Error response status: {e.response.status_code}")
              print(f"Error response headers: {dict(e.response.headers)}")
              print(f"Error response text: {e.response.text}")
          '''


          def inject(content, anchor, snippet, before=False, deeper=False):
              """Insert snippet lines next to every line that contains anchor.

              The snippet is re-indented to the anchor line's indentation, plus
              one level (4 spaces) when deeper is true, i.e. when the anchor line
              opens a new block. Returns content unchanged, printing a warning,
              if no line contains anchor.
              """
              out = []
              hits = 0
              for line in content.split('\n'):
                  if anchor not in line:
                      out.append(line)
                      continue
                  hits += 1
                  indent = line[:len(line) - len(line.lstrip())]
                  if deeper:
                      indent += '    '
                  block = [indent + s if s.strip() else s for s in snippet.splitlines()]
                  if before:
                      out.extend(block + [line])
                  else:
                      out.extend([line] + block)
              if hits == 0:
                  print(f"WARNING: patch anchor not found, skipping: {anchor!r}")
              return '\n'.join(out)


          def patch_code():
              with open(TARGET, 'r', encoding='utf-8') as f:
                  content = f.read()
              content = inject(content, EXCHANGE_DEF, TOKEN_DEBUG, deeper=True)
              content = inject(content, '# Make the token exchange request', REQUEST_DEBUG, before=True)
              content = inject(content, 'response = requests.post(self.token_endpoint, data=params, headers=headers)', RESPONSE_DEBUG)
              content = inject(content, 'except RequestException as e:', ERROR_DEBUG, deeper=True)
              with open(TARGET, 'w', encoding='utf-8') as f:
                  f.write(content)


          if __name__ == "__main__":
              patch_code()
          EOF
          chmod +x patch_for_debugging.py

      - name: Install PyJWT for token debugging
        run: pip install pyjwt

      - name: Apply debugging patches to token_federation.py
        run: python patch_for_debugging.py

      - name: Create audience fix patch script
        run: |
          cat > patch_for_audience_fix.py << 'EOF'
          #!/usr/bin/env python3
          """Inject an audience check (debug output only) into _exchange_token.

          Reuses the indentation-aware helper from patch_for_debugging.py, which
          lives in the same directory and is import-safe (main-guarded).
          """
          from patch_for_debugging import TARGET, EXCHANGE_DEF, inject

          # The token is never modified: that would invalidate its signature.
          AUDIENCE_DEBUG = '''\
          # Additional handling for different audience formats (debug only)
          import jwt
          try:
              dbg_expected_aud = "https://github.com/databricks"
              dbg_aud = jwt.decode(token, options={"verify_signature": False}).get("aud")
              # Check if aud is a list and convert to string if needed
              if isinstance(dbg_aud, list) and len(dbg_aud) > 0:
                  dbg_aud = dbg_aud[0]
              print(f"Original token audience: {dbg_aud}")
              if dbg_aud != dbg_expected_aud:
                  print(f"WARNING: Token audience '{dbg_aud}' doesn't match expected audience '{dbg_expected_aud}'")
          except Exception as e:
              print(f"Audience debug error: {str(e)}")
          '''


          def patch_code():
              with open(TARGET, 'r', encoding='utf-8') as f:
                  content = f.read()
              content = inject(content, EXCHANGE_DEF, AUDIENCE_DEBUG, deeper=True)
              with open(TARGET, 'w', encoding='utf-8') as f:
                  f.write(content)


          if __name__ == "__main__":
              patch_code()
          EOF
          chmod +x patch_for_audience_fix.py

      - name: Apply audience fix patches
        run: python patch_for_audience_fix.py

      - name: Get GitHub OIDC token
        id: get-id-token
        uses: actions/github-script@v7
        with:
          script: |
            const token = await core.getIDToken('https://github.com/databricks')
            core.setSecret(token)
            core.setOutput('token', token)

      - name: Decode and display OIDC token claims
        env:
          OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
        run: |
          echo "Decoding GitHub OIDC token claims..."
          # Read the token from the environment inside Python so it is never
          # interpolated into shell-quoted source code.
          python - << 'EOF'
          import base64, json, os
          token = os.environ.get("OIDC_TOKEN", "")
          # Parse the token
          try:
              header, payload, signature = token.split(".")
              # JWT segments are unpadded base64url: restore padding, use URL-safe alphabet
              payload_padding = payload + "=" * (-len(payload) % 4)
              decoded_payload = base64.urlsafe_b64decode(payload_padding).decode("utf-8")
              claims = json.loads(decoded_payload)
              # Print important claims
              print("\n=== GITHUB OIDC TOKEN CLAIMS ===")
              print(f"Issuer (iss): {claims.get('iss')}")
              print(f"Subject (sub): {claims.get('sub')}")
              print(f"Audience (aud): {claims.get('aud')}")
              print(f"Repository: {claims.get('repository')}")
              print(f"Repository owner: {claims.get('repository_owner')}")
              print(f"Event name: {claims.get('event_name')}")
              print(f"Ref: {claims.get('ref')}")
              print(f"Workflow ref: {claims.get('workflow_ref')}")
              print("\n=== FULL CLAIMS ===")
              print(json.dumps(claims, indent=2))
              print("===========================\n")
          except Exception as e:
              print(f"Failed to decode token: {str(e)}")
          EOF

      # Diagnostic only: a failure here must not fail the workflow.
      - name: Debug token exchange with curl
        continue-on-error: true
        env:
          DATABRICKS_HOST: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
          IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }}
          OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
        run: |
          echo "Attempting direct token exchange with curl..."
          echo "Host: $DATABRICKS_HOST"
          echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID"
          # Debug token claims before making the request
          echo "Token claims:"
          python3 - << 'EOF'
          import base64, json, os
          parts = os.environ.get("OIDC_TOKEN", "").split(".")
          if len(parts) >= 2:
              padding = "=" * (-len(parts[1]) % 4)
              claims = json.loads(base64.urlsafe_b64decode(parts[1] + padding).decode("utf-8"))
              print(f"Token issuer: {claims.get('iss', 'unknown')}")
              print(f"Token subject: {claims.get('sub', 'unknown')}")
              print(f"Token audience: {claims.get('aud', 'unknown')}")
          else:
              print("Invalid token format")
          EOF
          # Print request details (except the token)
          echo "Request URL: https://$DATABRICKS_HOST/oidc/v1/token"
          echo "Request data: client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=REDACTED&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql"
          # Let curl form-encode every field; no eval and no unquoted '&'.
          echo "Sending request..."
          response=$(curl -v -s -X POST "https://$DATABRICKS_HOST/oidc/v1/token" \
            --data-urlencode "client_id=$IDENTITY_FEDERATION_CLIENT_ID" \
            --data-urlencode "subject_token=$OIDC_TOKEN" \
            --data-urlencode "subject_token_type=urn:ietf:params:oauth:token-type:jwt" \
            --data-urlencode "grant_type=urn:ietf:params:oauth:grant-type:token-exchange" \
            --data-urlencode "scope=sql" \
            -H "Content-Type: application/x-www-form-urlencoded" \
            -H "Accept: application/json" \
            2>&1)
          # Extract and display results
          echo "Response:"
          echo "$response"
          # Extract HTTP status if possible
          status_code=$(echo "$response" | grep -o "< HTTP/[0-9.]* [0-9]*" | grep -o "[0-9]*$" || echo "unknown")
          echo "HTTP Status Code: $status_code"
          # Don't fail the workflow if curl fails
          exit 0

      - name: Create test script
        run: |
          cat > test_github_token_federation.py << 'EOF'
          #!/usr/bin/env python3
          """
          Test script for Databricks SQL token federation with GitHub Actions OIDC tokens.

          This script demonstrates how to use the Databricks SQL connector with token federation
          using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse,
          runs a simple query, and shows the connected user.
          """
          import os
          import sys
          import json
          import base64
          import requests
          from databricks import sql
          import time


          def decode_jwt(token):
              """Decode and return the claims from a JWT token (None on failure)."""
              try:
                  parts = token.split(".")
                  if len(parts) != 3:
                      raise ValueError("Invalid JWT format")
                  payload = parts[1]
                  # JWT segments are unpadded base64url; restore exactly the missing padding
                  payload += '=' * (-len(payload) % 4)
                  decoded = base64.urlsafe_b64decode(payload)
                  return json.loads(decoded)
              except Exception as e:
                  print(f"Failed to decode token: {str(e)}")
                  return None


          def test_direct_token_exchange(host, token, client_id, audience=None):
              """Directly test token exchange with the Databricks API.

              `audience` is not sent to the server; it is only compared with the
              token's `aud` claim to print a mismatch warning. Returns the
              exchanged access token on HTTP 200, otherwise None.
              """
              try:
                  url = f"https://{host}/oidc/v1/token"
                  data = {
                      "client_id": client_id,
                      "subject_token": token,
                      "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
                      "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                      "scope": "sql",
                      "return_original_token_if_authenticated": "true"
                  }
                  headers = {
                      "Content-Type": "application/x-www-form-urlencoded",
                      "Accept": "application/json"
                  }
                  print(f"Testing direct token exchange with {url}")
                  print(f"Request parameters: {data}")

                  # Add debugging info
                  claims = decode_jwt(token)
                  if claims:
                      print(f"Token issuer: {claims.get('iss', 'unknown')}")
                      print(f"Token subject: {claims.get('sub', 'unknown')}")
                      print(f"Token audience: {claims.get('aud', 'unknown')}")
                      # If audience was specified in policy but doesn't match token
                      if audience and audience != claims.get('aud'):
                          print("WARNING: Expected audience and token audience don't match")
                          print(f"Expected: {audience}")
                          print(f"Actual: {claims.get('aud')}")

                  response = requests.post(url, data=data, headers=headers)
                  print(f"Status code: {response.status_code}")
                  print(f"Response headers: {dict(response.headers)}")
                  print(f"Response content: {response.text}")

                  if response.status_code == 200:
                      try:
                          return json.loads(response.text).get("access_token")
                      except json.JSONDecodeError:
                          print("Failed to parse response JSON")
                          return None
                  return None
              except Exception as e:
                  print(f"Direct token exchange failed: {str(e)}")
                  return None


          def main():
              """Exchange the GitHub OIDC token directly, then query via the connector.

              Exits with status 1 when the OIDC token or connection parameters are
              missing, or when connecting/querying through the connector fails.
              """
              # Get GitHub OIDC token
              github_token = os.environ.get("OIDC_TOKEN")
              if not github_token:
                  print("GitHub OIDC token not available")
                  sys.exit(1)

              # Get Databricks connection parameters
              host = os.environ.get("DATABRICKS_HOST_FOR_TF")
              http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
              identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID_FOR_TF")

              if not host or not http_path:
                  print("Missing Databricks connection parameters")
                  sys.exit(1)

              claims = decode_jwt(github_token)
              if claims:
                  print("\n=== GitHub OIDC Token Claims ===")
                  print(f"Token issuer: {claims.get('iss')}")
                  print(f"Token subject: {claims.get('sub')}")
                  print(f"Token audience: {claims.get('aud')}")
                  print(f"Token expiration: {claims.get('exp', 'unknown')}")
                  print(f"Repository: {claims.get('repository', 'unknown')}")
                  print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
                  print(f"Event name: {claims.get('event_name', 'unknown')}")
                  print("===============================\n")

              # Try token exchange with several expected audience values. The
              # request is identical each time; only the mismatch warning differs.
              audience_values = [
                  "https://github.com/databricks",  # Standard audience for GitHub tokens
                  "https://github.com",  # Alternative audience
                  None  # No audience
              ]

              # Direct token exchange test
              access_token = None
              for audience in audience_values:
                  print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===")
                  result = test_direct_token_exchange(host, github_token, identity_federation_client_id, audience)
                  if result:
                      print("Direct token exchange successful!")
                      access_token = result
                      token_claims = decode_jwt(result)
                      if token_claims:
                          print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}")
                      break
                  print(f"Token exchange failed with audience={audience}")
                  # Add a small delay between attempts
                  time.sleep(1)

              if not access_token:
                  print("All token exchange attempts failed")
                  print("=====================================\n")
              else:
                  print("=====================================\n")

              try:
                  # Connect to Databricks using token federation
                  print("\n=== Testing Connection via Connector ===")
                  print(f"Connecting to Databricks at {host}{http_path}")
                  print(f"Using client ID: {identity_federation_client_id}")

                  connection_params = {
                      "server_hostname": host,
                      "http_path": http_path,
                      "access_token": github_token,
                      "auth_type": "token-federation",
                      "identity_federation_client_id": identity_federation_client_id,
                  }

                  print("Connection parameters:")
                  print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2))

                  with sql.connect(**connection_params) as connection:
                      print("Connection established successfully")

                      # Execute a simple query
                      cursor = connection.cursor()
                      cursor.execute("SELECT 1 + 1 as result")
                      result = cursor.fetchall()
                      print(f"Query result: {result[0][0]}")

                      # Show current user
                      cursor.execute("SELECT current_user() as user")
                      result = cursor.fetchall()
                      print(f"Connected as user: {result[0][0]}")

                      print("Token federation test successful!")
                      return True
              except Exception as e:
                  print(f"Error connecting to Databricks: {str(e)}")
                  print("===================================\n")
                  sys.exit(1)


          if __name__ == "__main__":
              main()
          EOF
          chmod +x test_github_token_federation.py

      - name: Test token federation with GitHub OIDC token
        env:
          DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
          DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
          # Must match the variable name read by the test script
          IDENTITY_FEDERATION_CLIENT_ID_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }}
          OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
        run: |
          python test_github_token_federation.py