|
19 | 19 | _get_authorized_http_client, |
20 | 20 | _get_openai_api_key, |
21 | 21 | _should_strip_strict, |
| 22 | + _strip_strict_from_kwargs, |
22 | 23 | _strip_strict_from_tools, |
23 | 24 | _validate_oauth_for_apps, |
24 | 25 | _wrap_app_error, |
@@ -171,6 +172,35 @@ def test_bearer_auth_flow(self, mock_workspace_client): |
171 | 172 | class TestStrictFieldStripping: |
172 | 173 | """Tests for strict field stripping helper functions.""" |
173 | 174 |
|
| 175 | + def test_strip_strict_from_kwargs_removes_top_level_strict_only(self): |
| 176 | + kwargs = { |
| 177 | + "strict": True, |
| 178 | + "model": "databricks-claude-3-7-sonnet", |
| 179 | + "temperature": 0.2, |
| 180 | + "tools": [{"type": "function", "function": {"name": "test", "strict": True}}], |
| 181 | + } |
| 182 | + |
| 183 | + result = _strip_strict_from_kwargs(kwargs) |
| 184 | + |
| 185 | + assert "strict" not in result |
| 186 | + assert result["model"] == "databricks-claude-3-7-sonnet" |
| 187 | + assert result["temperature"] == 0.2 |
| 188 | + assert result["tools"][0]["function"]["strict"] is True |
| 189 | + |
| 190 | + def test_strip_strict_from_kwargs_is_noop_when_strict_absent(self): |
| 191 | + kwargs = { |
| 192 | + "model": "databricks-gpt-4o", |
| 193 | + "temperature": 0.2, |
| 194 | + "tools": [{"type": "function", "function": {"name": "test", "strict": True}}], |
| 195 | + } |
| 196 | + |
| 197 | + expected = kwargs.copy() |
| 198 | + result = _strip_strict_from_kwargs(kwargs) |
| 199 | + |
| 200 | + assert result is kwargs |
| 201 | + assert result == expected |
| 202 | + assert result["tools"][0]["function"]["strict"] is True |
| 203 | + |
174 | 204 | def test_strip_strict_from_tools_removes_strict(self): |
175 | 205 | tools = [ |
176 | 206 | {"type": "function", "function": {"name": "test", "strict": True, "parameters": {}}} |
@@ -310,6 +340,31 @@ def test_chat_completions_works_without_tools(self): |
310 | 340 | ) |
311 | 341 | mock_create.assert_called_once() |
312 | 342 |
|
| 343 | + def test_chat_completions_strips_top_level_strict_kwarg(self): |
| 344 | + with patch("databricks_openai.utils.clients.WorkspaceClient") as mock_ws: |
| 345 | + mock_client = MagicMock(spec=WorkspaceClient) |
| 346 | + mock_client.config.host = "https://test.databricks.com" |
| 347 | + mock_client.config.authenticate.return_value = {"Authorization": "Bearer token"} |
| 348 | + mock_ws.return_value = mock_client |
| 349 | + |
| 350 | + client = DatabricksOpenAI() |
| 351 | + |
| 352 | + with patch.object(Completions, "create") as mock_create: |
| 353 | + mock_create.return_value = MagicMock() |
| 354 | + request_kwargs = cast( |
| 355 | + Any, |
| 356 | + { |
| 357 | + "model": "databricks-gpt-4o", |
| 358 | + "messages": [{"role": "user", "content": "hi"}], |
| 359 | + "strict": True, |
| 360 | + }, |
| 361 | + ) |
| 362 | + client.chat.completions.create(**request_kwargs) |
| 363 | + |
| 364 | + call_kwargs = mock_create.call_args.kwargs |
| 365 | + assert "strict" not in call_kwargs |
| 366 | + assert call_kwargs["model"] == "databricks-gpt-4o" |
| 367 | + |
313 | 368 |
|
314 | 369 | class TestAsyncDatabricksOpenAIStrictStripping: |
315 | 370 | """Tests for strict stripping in AsyncDatabricksOpenAI.""" |
@@ -360,6 +415,31 @@ async def test_chat_completions_preserves_strict_for_gpt(self): |
360 | 415 | call_kwargs = mock_create.call_args.kwargs |
361 | 416 | assert call_kwargs["tools"][0]["function"]["strict"] is True |
362 | 417 |
|
| 418 | + @pytest.mark.asyncio |
| 419 | + async def test_chat_completions_strips_top_level_strict_kwarg(self): |
| 420 | + with patch("databricks_openai.utils.clients.WorkspaceClient") as mock_ws: |
| 421 | + mock_client = MagicMock(spec=WorkspaceClient) |
| 422 | + mock_client.config.host = "https://test.databricks.com" |
| 423 | + mock_client.config.authenticate.return_value = {"Authorization": "Bearer token"} |
| 424 | + mock_ws.return_value = mock_client |
| 425 | + |
| 426 | + client = AsyncDatabricksOpenAI() |
| 427 | + |
| 428 | + with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create: |
| 429 | + request_kwargs = cast( |
| 430 | + Any, |
| 431 | + { |
| 432 | + "model": "databricks-gpt-4o", |
| 433 | + "messages": [{"role": "user", "content": "hi"}], |
| 434 | + "strict": True, |
| 435 | + }, |
| 436 | + ) |
| 437 | + await client.chat.completions.create(**request_kwargs) |
| 438 | + |
| 439 | + call_kwargs = mock_create.call_args.kwargs |
| 440 | + assert "strict" not in call_kwargs |
| 441 | + assert call_kwargs["model"] == "databricks-gpt-4o" |
| 442 | + |
363 | 443 |
|
364 | 444 | class TestDatabricksAppsSupport: |
365 | 445 | """Tests for Databricks Apps support.""" |
|
0 commit comments