Skip to content

Commit 1185281

Browse files
authored
Databricks MCP: throw a better error for incompatible workspace clients (#263)
Signed-off-by: Bryan Qiu <bryan.qiu@databricks.com>
1 parent 6d3d287 commit 1185281

5 files changed

Lines changed: 171 additions & 14 deletions

File tree

databricks_mcp/src/databricks_mcp/mcp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
27+
def _is_databricks_apps_url(url: str) -> bool:
28+
"""Check if the URL is hosted on Databricks Apps."""
29+
parsed = urlparse(url)
30+
return parsed.netloc.endswith(".databricksapps.com")
31+
32+
33+
def _is_oauth_auth(workspace_client: WorkspaceClient) -> bool:
34+
"""Check if the workspace client is using OAuth authentication.
35+
36+
Uses the SDK's oauth_token() method to determine if OAuth is available.
37+
This is more resilient than checking auth_type directly, as it handles
38+
various non-OAuth auth types (pat, runtime, etc.).
39+
"""
40+
try:
41+
workspace_client.config.oauth_token()
42+
return True
43+
except ValueError:
44+
# oauth_token() raises ValueError when not using OAuth-based auth
45+
return False
46+
47+
2648
# MCP URL types
2749
UC_FUNCTIONS_MCP = "uc_functions_mcp"
2850
VECTOR_SEARCH_MCP = "vector_search_mcp"
@@ -123,6 +145,15 @@ def __init__(self, server_url: str, workspace_client: Optional[WorkspaceClient]
123145
self.client = workspace_client or WorkspaceClient()
124146
self.server_url = server_url
125147

148+
# Early detection: error if using non-OAuth auth with Databricks Apps
149+
if _is_databricks_apps_url(server_url) and not _is_oauth_auth(self.client):
150+
raise ValueError(
151+
"OAuth authentication is required for MCP servers hosted on Databricks Apps. "
152+
"Your current authentication method is not supported. "
153+
"Please use OAuth authentication instead. "
154+
"For more information: https://docs.databricks.com/aws/en/generative-ai/mcp/custom-mcp"
155+
)
156+
126157
def _get_databricks_managed_mcp_url_type(self) -> str:
127158
"""Determine the MCP URL type based on the path."""
128159
path = urlparse(self.server_url).path

databricks_mcp/src/databricks_mcp/oauth_provider.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from databricks.sdk import WorkspaceClient
2+
from databricks.sdk.errors.platform import PermissionDenied
23
from mcp.client.auth import OAuthClientProvider, TokenStorage
34
from mcp.shared.auth import OAuthToken
45

@@ -47,6 +48,17 @@ class DatabricksOAuthClientProvider(OAuthClientProvider):
4748

4849
def __init__(self, workspace_client: WorkspaceClient):
4950
self.workspace_client = workspace_client
51+
52+
# Pre-flight check: verify the workspace client has basic permissions
53+
try:
54+
workspace_client.current_user.me()
55+
except PermissionDenied as e:
56+
raise PermissionError(
57+
f"The workspace client does not have permission to access the Databricks workspace. "
58+
f"Please ensure the service principal or user has the required permissions to call Databricks APIs. "
59+
f"Original error: {e}"
60+
) from e
61+
5062
self.databricks_token_storage = DatabricksTokenStorage(workspace_client)
5163

5264
super().__init__(
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from databricks.sdk import WorkspaceClient
2+
3+
from databricks_mcp import DatabricksMCPClient
4+
5+
# Replace with your deployed app URL
6+
mcp_server_url = "https://mcp-chloe-test-6051921418418893.staging.aws.databricksapps.com/mcp"
7+
8+
workspace_client = WorkspaceClient(
9+
host="https://e2-dogfood.staging.cloud.databricks.com",
10+
token="<token>",
11+
)
12+
13+
# print(workspace_client.current_user.me())
14+
15+
mcp_client = DatabricksMCPClient(server_url=mcp_server_url, workspace_client=workspace_client)
16+
17+
# List available tools
18+
tools = mcp_client.list_tools()
19+
# print(f"Available tools: {tools}")
20+
21+
from databricks.sdk import WorkspaceClient
22+
23+
from databricks_mcp import DatabricksMCPClient
24+
25+
ws_client = WorkspaceClient(
26+
host="https://e2-dogfood.staging.cloud.databricks.com/",
27+
client_id="<client-id>",
28+
client_secret="<client-secret>",
29+
)
30+
31+
mcp_client = DatabricksMCPClient(
32+
server_url=mcp_server_url,
33+
workspace_client=ws_client,
34+
)
35+
# print(mcp_client.list_tools())

databricks_mcp/tests/unit_tests/test_mcp.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
UC_FUNCTIONS_MCP,
1212
VECTOR_SEARCH_MCP,
1313
DatabricksMCPClient,
14+
_is_databricks_apps_url,
15+
_is_oauth_auth,
1416
)
1517

1618

@@ -131,6 +133,7 @@ async def test_get_tools_async(self):
131133
with (
132134
patch("databricks_mcp.mcp.streamablehttp_client") as mock_client,
133135
patch("databricks_mcp.mcp.ClientSession") as mock_session_class,
136+
patch("databricks_mcp.mcp.DatabricksOAuthClientProvider"),
134137
):
135138
mock_client.return_value.__aenter__.return_value = (AsyncMock(), AsyncMock(), None)
136139
mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -156,6 +159,7 @@ async def test_call_tools_async(self):
156159
with (
157160
patch("databricks_mcp.mcp.streamablehttp_client") as mock_client,
158161
patch("databricks_mcp.mcp.ClientSession") as mock_session_class,
162+
patch("databricks_mcp.mcp.DatabricksOAuthClientProvider"),
159163
):
160164
mock_client.return_value.__aenter__.return_value = (AsyncMock(), AsyncMock(), None)
161165
mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -434,6 +438,7 @@ def test_error_decorator_managed_server_reraises_original(self):
434438
client, "_get_databricks_managed_mcp_url_type", return_value=UC_FUNCTIONS_MCP
435439
),
436440
patch("databricks_mcp.mcp.streamablehttp_client") as mock_client,
441+
patch("databricks_mcp.mcp.DatabricksOAuthClientProvider"),
437442
patch("requests.request") as mock_request,
438443
):
439444
mock_client.side_effect = original_error
@@ -442,3 +447,57 @@ def test_error_decorator_managed_server_reraises_original(self):
442447
client.list_tools()
443448

444449
mock_request.assert_not_called()
450+
451+
452+
class TestIsDatabricksAppsUrl:
453+
"""Test cases for _is_databricks_apps_url helper function."""
454+
455+
@pytest.mark.parametrize(
456+
"url,expected",
457+
[
458+
("https://my-app.staging.aws.databricksapps.com/mcp", True),
459+
("https://my-app.prod.azure.databricksapps.com/mcp", True),
460+
("https://my-app.databricksapps.com", True),
461+
("https://test.cloud.databricks.com/api/2.0/mcp/functions/a/b", False),
462+
("https://custom-server.example.com/mcp", False),
463+
("https://databricksapps.com.evil.com/mcp", False),
464+
("https://notdatabricksapps.com/mcp", False),
465+
],
466+
)
467+
def test_is_databricks_apps_url(self, url, expected):
468+
assert _is_databricks_apps_url(url) == expected
469+
470+
471+
class TestIsOauthAuth:
472+
@pytest.mark.parametrize(
473+
"side_effect,expected",
474+
[
475+
(None, True), # oauth_token succeeds
476+
(ValueError("not available"), False), # oauth_token raises
477+
],
478+
)
479+
def test_is_oauth_auth(self, side_effect, expected):
480+
mock_client = MagicMock(spec=WorkspaceClient)
481+
if side_effect:
482+
mock_client.config.oauth_token.side_effect = side_effect
483+
assert _is_oauth_auth(mock_client) is expected
484+
485+
486+
class TestDatabricksMCPClientOAuthValidation:
487+
@pytest.mark.parametrize("auth_type", ["pat", "runtime"])
488+
def test_raises_error_for_non_oauth_with_databricks_apps(self, auth_type):
489+
mock_client = MagicMock(spec=WorkspaceClient)
490+
mock_client.config.oauth_token.side_effect = ValueError(f"not available for {auth_type}")
491+
with pytest.raises(ValueError, match="OAuth authentication is required"):
492+
DatabricksMCPClient("https://my-app.databricksapps.com/mcp", mock_client)
493+
494+
def test_allows_oauth_with_databricks_apps(self):
495+
mock_client = MagicMock(spec=WorkspaceClient)
496+
client = DatabricksMCPClient("https://my-app.databricksapps.com/mcp", mock_client)
497+
assert client.server_url == "https://my-app.databricksapps.com/mcp"
498+
499+
def test_allows_non_oauth_with_non_databricks_apps(self):
500+
mock_client = MagicMock(spec=WorkspaceClient)
501+
mock_client.config.oauth_token.side_effect = ValueError("not available")
502+
client = DatabricksMCPClient("https://test.com/api/2.0/mcp/functions/a/b", mock_client)
503+
assert client.server_url == "https://test.com/api/2.0/mcp/functions/a/b"
Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,54 @@
1-
from unittest.mock import patch
1+
from unittest.mock import MagicMock, patch
22

33
import pytest
44
from databricks.sdk import WorkspaceClient
5+
from databricks.sdk.errors.platform import PermissionDenied
56

67
from databricks_mcp import DatabricksOAuthClientProvider
78

89

910
@pytest.mark.asyncio
1011
async def test_oauth_provider():
1112
workspace_client = WorkspaceClient(host="https://test-databricks.com", token="test-token")
12-
provider = DatabricksOAuthClientProvider(workspace_client=workspace_client)
13-
oauth_token = await provider.context.storage.get_tokens()
14-
assert oauth_token.access_token == "test-token"
15-
assert oauth_token.expires_in == 60
16-
assert oauth_token.token_type.lower() == "bearer"
13+
with patch.object(workspace_client.current_user, "me", return_value=MagicMock()):
14+
provider = DatabricksOAuthClientProvider(workspace_client=workspace_client)
15+
oauth_token = await provider.context.storage.get_tokens()
16+
assert oauth_token.access_token == "test-token"
17+
assert oauth_token.expires_in == 60
18+
assert oauth_token.token_type.lower() == "bearer"
1719

1820

1921
@pytest.mark.asyncio
2022
async def test_authenticate_raises_exception():
2123
workspace_client = WorkspaceClient(host="https://test-databricks.com", token="test-token")
2224

25+
with patch.object(workspace_client.current_user, "me", return_value=MagicMock()):
26+
with patch.object(
27+
workspace_client.config, "authenticate", return_value={"Authorization": "Basic abc123"}
28+
):
29+
with pytest.raises(
30+
ValueError, match="Invalid authentication token format. Expected Bearer token."
31+
):
32+
provider = DatabricksOAuthClientProvider(workspace_client=workspace_client)
33+
34+
oauth_token = await provider.context.storage.get_tokens()
35+
assert oauth_token.access_token == "test-token"
36+
assert oauth_token.expires_in == 60
37+
assert oauth_token.token_type.lower() == "bearer"
38+
39+
40+
def test_preflight_check_raises_permission_denied():
41+
workspace_client = WorkspaceClient(host="https://test-databricks.com", token="test-token")
42+
2343
with patch.object(
24-
workspace_client.config, "authenticate", return_value={"Authorization": "Basic abc123"}
44+
workspace_client.current_user,
45+
"me",
46+
side_effect=PermissionDenied(
47+
"This API is disabled for users without the workspace-access entitlement."
48+
),
2549
):
2650
with pytest.raises(
27-
ValueError, match="Invalid authentication token format. Expected Bearer token."
51+
PermissionError,
52+
match="The workspace client does not have permission to access the Databricks workspace",
2853
):
29-
provider = DatabricksOAuthClientProvider(workspace_client=workspace_client)
30-
31-
oauth_token = await provider.context.storage.get_tokens()
32-
assert oauth_token.access_token == "test-token"
33-
assert oauth_token.expires_in == 60
34-
assert oauth_token.token_type.lower() == "bearer"
54+
DatabricksOAuthClientProvider(workspace_client=workspace_client)

0 commit comments

Comments
 (0)