Skip to content

Commit 32cebf4

Browse files
authored
Add support for AI Gateway native provider endpoints with databricks_openai (#390)
1 parent 8d72cae commit 32cebf4

3 files changed

Lines changed: 218 additions & 10 deletions

File tree

integrations/openai/src/databricks_openai/utils/clients.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,15 @@ def _fix_empty_assistant_content_in_messages(messages: Any) -> None:
109109
message["content"] = " "
110110

111111

112-
def _get_ai_gateway_base_url(
112+
def _discover_ai_gateway_host(
113113
http_client: Client,
114114
host: str,
115115
) -> str | None:
116-
"""Check if AI Gateway V2 is enabled and return its base URL.
116+
"""Discover the AI Gateway host URL (scheme + netloc only).
117117
118118
Calls GET /api/ai-gateway/v2/endpoints. If successful and endpoints exist,
119-
extracts the ai_gateway_url from the first endpoint response.
120-
Returns None if gateway is not available.
119+
extracts the ai_gateway_url from the first endpoint response and returns
120+
just the scheme + netloc (no path). Returns None if gateway is not available.
121121
"""
122122
try:
123123
response = http_client.get(f"{host}/api/ai-gateway/v2/endpoints")
@@ -131,24 +131,53 @@ def _get_ai_gateway_base_url(
131131
if not gateway_url:
132132
return None
133133
parsed = urlparse(gateway_url)
134-
return f"{parsed.scheme}://{parsed.netloc}/mlflow/v1"
134+
return f"{parsed.scheme}://{parsed.netloc}"
135135
except Exception:
136136
return None
137137

138138

139+
def _get_ai_gateway_base_url(
140+
http_client: Client,
141+
host: str,
142+
) -> str | None:
143+
"""Check if AI Gateway V2 is enabled and return its MLflow base URL.
144+
145+
Returns the AI Gateway base URL with /mlflow/v1 path appended, or None if
146+
the gateway is not available.
147+
"""
148+
gateway_host = _discover_ai_gateway_host(http_client, host)
149+
return f"{gateway_host}/mlflow/v1" if gateway_host else None
150+
151+
139152
def _resolve_base_url(
140153
workspace_client: WorkspaceClient,
141154
base_url: str | None,
142155
use_ai_gateway: bool,
143156
http_client: Client,
157+
use_ai_gateway_native_api: bool,
144158
) -> str:
145159
"""Resolve the target base URL for the OpenAI client."""
160+
if use_ai_gateway_native_api and base_url is not None:
161+
raise ValueError("Cannot specify both 'use_ai_gateway_native_api' and 'base_url'.")
162+
if use_ai_gateway_native_api and use_ai_gateway:
163+
raise ValueError("Cannot specify both 'use_ai_gateway_native_api' and 'use_ai_gateway'.")
164+
146165
if base_url is not None:
147166
if _DATABRICKS_APPS_DOMAIN in base_url:
148167
_validate_oauth_for_apps(workspace_client)
149168
return base_url
150169

151-
# Prioritize using AI Gateway endpoints
170+
# Native provider API via AI Gateway (e.g. OpenAI-compatible /openai path)
171+
if use_ai_gateway_native_api:
172+
gateway_host = _discover_ai_gateway_host(http_client, workspace_client.config.host)
173+
if gateway_host:
174+
return f"{gateway_host}/openai/v1"
175+
raise ValueError(
176+
"Please ensure AI Gateway V2 is enabled for the workspace "
177+
"when use_ai_gateway_native_api is set to True."
178+
)
179+
180+
# MLflow-format AI Gateway endpoints
152181
if use_ai_gateway:
153182
gateway_url = _get_ai_gateway_base_url(http_client, workspace_client.config.host)
154183
if gateway_url:
@@ -362,8 +391,12 @@ class DatabricksOpenAI(OpenAI):
362391
base_url: Optional base URL to override the default serving endpoints URL. When the URL
363392
points to a Databricks App (contains "databricksapps"), OAuth authentication is
364393
required.
394+
use_ai_gateway_native_api: If True, auto-detect AI Gateway V2 and route requests through
395+
its native OpenAI-compatible API (``<ai_gateway_url>/openai/v1``). This allows use of
396+
provider-native features not available through the MLflow API. Cannot be combined
397+
with ``base_url`` or ``use_ai_gateway``. Defaults to False.
365398
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
366-
requests through it. Defaults to False.
399+
requests through it using the MLflow API. Defaults to False.
367400
368401
Example - Query a serving or AI gateway endpoint:
369402
>>> client = DatabricksOpenAI()
@@ -372,6 +405,13 @@ class DatabricksOpenAI(OpenAI):
372405
... messages=[{"role": "user", "content": "Hello!"}],
373406
... )
374407
408+
Example - Query AI Gateway endpoints via the native OpenAI-compatible API:
409+
>>> client = DatabricksOpenAI(use_ai_gateway_native_api=True)
410+
>>> response = client.chat.completions.create(
411+
... model="databricks-meta-llama-3-1-70b-instruct",
412+
... messages=[{"role": "user", "content": "Hello!"}],
413+
... )
414+
375415
Example - Query a Databricks App directly by URL (requires OAuth):
376416
>>> # WorkspaceClient must be configured with OAuth authentication
377417
>>> # See: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-u2m.html
@@ -397,6 +437,7 @@ def __init__(
397437
self,
398438
workspace_client: WorkspaceClient | None = None,
399439
base_url: str | None = None,
440+
use_ai_gateway_native_api: bool = False,
400441
use_ai_gateway: bool = False,
401442
):
402443
if workspace_client is None:
@@ -405,7 +446,9 @@ def __init__(
405446
self._workspace_client = workspace_client
406447

407448
http_client = _get_authorized_http_client(workspace_client)
408-
target_base_url = _resolve_base_url(workspace_client, base_url, use_ai_gateway, http_client)
449+
target_base_url = _resolve_base_url(
450+
workspace_client, base_url, use_ai_gateway, http_client, use_ai_gateway_native_api
451+
)
409452

410453
# Authentication is handled via http_client, not api_key
411454
super().__init__(
@@ -510,8 +553,12 @@ class AsyncDatabricksOpenAI(AsyncOpenAI):
510553
base_url: Optional base URL to override the default serving endpoints URL. When the URL
511554
points to a Databricks App (contains "databricksapps"), OAuth authentication is
512555
required.
556+
use_ai_gateway_native_api: If True, auto-detect AI Gateway V2 and route requests through
557+
its native OpenAI-compatible API (``<ai_gateway_url>/openai/v1``). This allows use of
558+
provider-native features not available through the MLflow API. Cannot be combined
559+
with ``base_url`` or ``use_ai_gateway``. Defaults to False.
513560
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
514-
requests through it. Defaults to False.
561+
requests through it using the MLflow API. Defaults to False.
515562
516563
Example - Query a serving or AI gateway endpoint:
517564
>>> client = AsyncDatabricksOpenAI()
@@ -520,6 +567,13 @@ class AsyncDatabricksOpenAI(AsyncOpenAI):
520567
... messages=[{"role": "user", "content": "Hello!"}],
521568
... )
522569
570+
Example - Query AI Gateway endpoints via the native OpenAI-compatible API:
571+
>>> client = AsyncDatabricksOpenAI(use_ai_gateway_native_api=True)
572+
>>> response = await client.chat.completions.create(
573+
... model="databricks-meta-llama-3-1-70b-instruct",
574+
... messages=[{"role": "user", "content": "Hello!"}],
575+
... )
576+
523577
Example - Query a Databricks App directly by URL (requires OAuth):
524578
>>> # WorkspaceClient must be configured with OAuth authentication
525579
>>> # See: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-u2m.html
@@ -545,6 +599,7 @@ def __init__(
545599
self,
546600
workspace_client: WorkspaceClient | None = None,
547601
base_url: str | None = None,
602+
use_ai_gateway_native_api: bool = False,
548603
use_ai_gateway: bool = False,
549604
):
550605
if workspace_client is None:
@@ -554,7 +609,7 @@ def __init__(
554609

555610
sync_http_client = _get_authorized_http_client(workspace_client)
556611
target_base_url = _resolve_base_url(
557-
workspace_client, base_url, use_ai_gateway, sync_http_client
612+
workspace_client, base_url, use_ai_gateway, sync_http_client, use_ai_gateway_native_api
558613
)
559614

560615
# Authentication is handled via http_client, not api_key
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
Integration tests for DatabricksOpenAI with use_ai_gateway_native_api=True.
3+
4+
Prerequisites:
5+
- AI Gateway V2 must be enabled on the test workspace.
6+
- Set DATABRICKS_CONFIG_PROFILE (or use default credentials).
7+
8+
Run with:
9+
RUN_AI_GATEWAY_NATIVE_API_TESTS=1 python -m pytest \
10+
tests/integration_tests/test_ai_gateway_native_api.py -v
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import os
16+
17+
import pytest
18+
import pytest_asyncio
19+
from databricks.sdk import WorkspaceClient
20+
21+
from databricks_openai import AsyncDatabricksOpenAI, DatabricksOpenAI
22+
23+
pytestmark = pytest.mark.skipif(
24+
os.environ.get("RUN_AI_GATEWAY_NATIVE_API_TESTS") != "1",
25+
reason="AI Gateway native API tests disabled. Set RUN_AI_GATEWAY_NATIVE_API_TESTS=1 to enable.",
26+
)
27+
28+
_TEST_MODEL = os.environ.get("AI_GATEWAY_NATIVE_API_MODEL", "databricks-gpt-5-4")
29+
_TEST_INPUT = [{"role": "user", "content": "Reply with exactly the word PONG and nothing else."}]
30+
31+
32+
@pytest.fixture(scope="module")
33+
def workspace_client():
34+
return WorkspaceClient()
35+
36+
37+
@pytest.fixture(scope="module")
38+
def sync_client(workspace_client):
39+
return DatabricksOpenAI(workspace_client=workspace_client, use_ai_gateway_native_api=True)
40+
41+
42+
@pytest_asyncio.fixture(scope="module")
43+
async def async_client(workspace_client):
44+
return AsyncDatabricksOpenAI(workspace_client=workspace_client, use_ai_gateway_native_api=True)
45+
46+
47+
class TestAIGatewayNativeAPISync:
48+
def test_base_url_uses_openai_path(self, sync_client):
49+
assert "/openai/v1" in str(sync_client.base_url)
50+
assert "ai-gateway" in str(sync_client.base_url)
51+
52+
def test_responses(self, sync_client):
53+
response = sync_client.responses.create(
54+
model=_TEST_MODEL,
55+
input=_TEST_INPUT,
56+
max_output_tokens=50,
57+
)
58+
assert response.output_text is not None
59+
60+
61+
@pytest.mark.asyncio
62+
class TestAIGatewayNativeAPIAsync:
63+
async def test_base_url_uses_openai_path(self, async_client):
64+
assert "/openai/v1" in str(async_client.base_url)
65+
assert "ai-gateway" in str(async_client.base_url)
66+
67+
async def test_responses(self, async_client):
68+
response = await async_client.responses.create(
69+
model=_TEST_MODEL,
70+
input=_TEST_INPUT,
71+
max_output_tokens=50,
72+
)
73+
assert response.output_text is not None

integrations/openai/tests/unit_tests/test_clients.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,3 +893,83 @@ def test_explicit_base_url_skips_gateway_check(self, client_cls_name, mock_works
893893
)
894894
mock_gateway.assert_not_called()
895895
assert "custom.example.com" in str(client.base_url)
896+
897+
898+
class TestAIGatewayNativeAPI:
899+
"""Tests for use_ai_gateway_native_api parameter in DatabricksOpenAI and AsyncDatabricksOpenAI."""
900+
901+
@pytest.mark.parametrize("client_cls_name", ["DatabricksOpenAI", "AsyncDatabricksOpenAI"])
902+
def test_native_api_uses_openai_base_path(self, client_cls_name, mock_workspace_client):
903+
client_cls = (
904+
DatabricksOpenAI if client_cls_name == "DatabricksOpenAI" else AsyncDatabricksOpenAI
905+
)
906+
with patch(
907+
"databricks_openai.utils.clients._discover_ai_gateway_host",
908+
return_value="https://12345.ai-gateway.cloud.databricks.com",
909+
):
910+
client = client_cls(
911+
workspace_client=mock_workspace_client,
912+
use_ai_gateway_native_api=True,
913+
)
914+
assert "12345.ai-gateway.cloud.databricks.com" in str(client.base_url)
915+
assert "/openai/v1" in str(client.base_url)
916+
917+
@pytest.mark.parametrize("client_cls_name", ["DatabricksOpenAI", "AsyncDatabricksOpenAI"])
918+
def test_native_api_unavailable_raises_error(self, client_cls_name, mock_workspace_client):
919+
client_cls = (
920+
DatabricksOpenAI if client_cls_name == "DatabricksOpenAI" else AsyncDatabricksOpenAI
921+
)
922+
with patch(
923+
"databricks_openai.utils.clients._discover_ai_gateway_host",
924+
return_value=None,
925+
):
926+
with pytest.raises(ValueError, match="Please ensure AI Gateway V2 is enabled"):
927+
client_cls(
928+
workspace_client=mock_workspace_client,
929+
use_ai_gateway_native_api=True,
930+
)
931+
932+
@pytest.mark.parametrize("client_cls_name", ["DatabricksOpenAI", "AsyncDatabricksOpenAI"])
933+
def test_native_api_and_base_url_raises(self, client_cls_name, mock_workspace_client):
934+
client_cls = (
935+
DatabricksOpenAI if client_cls_name == "DatabricksOpenAI" else AsyncDatabricksOpenAI
936+
)
937+
with pytest.raises(
938+
ValueError, match="Cannot specify both 'use_ai_gateway_native_api' and 'base_url'"
939+
):
940+
client_cls(
941+
workspace_client=mock_workspace_client,
942+
use_ai_gateway_native_api=True,
943+
base_url="https://custom.example.com/v1",
944+
)
945+
946+
@pytest.mark.parametrize("client_cls_name", ["DatabricksOpenAI", "AsyncDatabricksOpenAI"])
947+
def test_native_api_and_use_ai_gateway_raises(self, client_cls_name, mock_workspace_client):
948+
client_cls = (
949+
DatabricksOpenAI if client_cls_name == "DatabricksOpenAI" else AsyncDatabricksOpenAI
950+
)
951+
with pytest.raises(
952+
ValueError,
953+
match="Cannot specify both 'use_ai_gateway_native_api' and 'use_ai_gateway'",
954+
):
955+
client_cls(
956+
workspace_client=mock_workspace_client,
957+
use_ai_gateway_native_api=True,
958+
use_ai_gateway=True,
959+
)
960+
961+
def test_discover_ai_gateway_host_strips_path(self):
962+
"""_discover_ai_gateway_host returns only scheme+netloc, stripping any path."""
963+
from unittest.mock import MagicMock
964+
965+
from databricks_openai.utils.clients import _discover_ai_gateway_host
966+
967+
mock_http = MagicMock()
968+
mock_http.get.return_value.status_code = 200
969+
mock_http.get.return_value.json.return_value = {
970+
"endpoints": [
971+
{"ai_gateway_url": "https://12345.ai-gateway.cloud.databricks.com/some/path"}
972+
]
973+
}
974+
result = _discover_ai_gateway_host(mock_http, "https://test.databricks.com")
975+
assert result == "https://12345.ai-gateway.cloud.databricks.com"

0 commit comments

Comments
 (0)