|
1 | | -from unittest.mock import MagicMock, patch |
| 1 | +from unittest.mock import patch |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 | from databricks.sdk.service.dashboards import GenieSpace |
5 | 5 | from databricks_ai_bridge.genie import Genie, GenieResponse |
6 | 6 | from langchain_core.messages import AIMessage |
7 | | -from mcp.types import CallToolResult |
8 | 7 |
|
9 | 8 | from databricks_langchain.genie import ( |
10 | 9 | GenieAgent, |
|
13 | 12 | ) |
14 | 13 |
|
15 | 14 |
|
16 | | -@pytest.fixture(autouse=True) |
17 | | -def mock_databricks_oauth_provider(): |
18 | | - """Auto-mock DatabricksOAuthClientProvider for all tests to avoid OAuth validation errors.""" |
19 | | - with patch("databricks_mcp.mcp.DatabricksOAuthClientProvider") as mock_auth: |
20 | | - # Return a MagicMock instance that won't try to get OAuth tokens |
21 | | - mock_auth_instance = MagicMock() |
22 | | - mock_auth.return_value = mock_auth_instance |
23 | | - yield mock_auth |
24 | | - |
25 | | - |
26 | 15 | def test_concat_messages_array(): |
27 | 16 | # Test a simple case with multiple messages |
28 | 17 | messages = [ |
@@ -74,8 +63,8 @@ def test_query_genie_as_agent(MockWorkspaceClient): |
74 | 63 | input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} |
75 | 64 | genie = Genie("space-id", MockWorkspaceClient) |
76 | 65 |
|
77 | | - # Mock the ask_question method at the module level to avoid mlflow tracing issues |
78 | | - with patch("databricks_ai_bridge.genie.Genie.ask_question", return_value=mock_genie_response): |
| 66 | + # Mock the ask_question method to return our mock response |
| 67 | + with patch.object(genie, "ask_question", return_value=mock_genie_response): |
79 | 68 | # Test with include_context=False (default) |
80 | 69 | result = _query_genie_as_agent(input_data, genie, "Genie") |
81 | 70 | expected_message = { |
@@ -152,8 +141,8 @@ def test_query_genie_with_client(mock_workspace_client): |
152 | 141 | input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} |
153 | 142 | genie = Genie("space-id", mock_workspace_client) |
154 | 143 |
|
155 | | - # Mock the ask_question method at the module level to avoid mlflow tracing issues |
156 | | - with patch("databricks_ai_bridge.genie.Genie.ask_question", return_value=mock_genie_response): |
| 144 | + # Mock the ask_question method to return our mock response |
| 145 | + with patch.object(genie, "ask_question", return_value=mock_genie_response): |
157 | 146 | result = _query_genie_as_agent(input_data, genie, "Genie") |
158 | 147 | expected_message = { |
159 | 148 | "messages": [AIMessage(content="It is sunny.", name="query_result")], |
|
0 commit comments