Skip to content

Commit a749d4f

Browse files
authored
databricks-openai - Remove injection of 'strict' kwarg from Langchain create_agent (#372)
1 parent d6862fa commit a749d4f

2 files changed

Lines changed: 88 additions & 0 deletions

File tree

integrations/openai/src/databricks_openai/utils/clients.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def _strip_strict_from_tools(tools: Any) -> Any:
5151
return tools
5252

5353

54+
def _strip_strict_from_kwargs(kwargs: dict) -> dict:
55+
"""Strip 'strict' from top-level kwargs which causes issues for GPT models."""
56+
kwargs.pop("strict", None) # Remove top-level strict if present
57+
return kwargs
58+
59+
5460
def _should_strip_strict(model: str | None) -> bool:
5561
"""Determine if strict should be stripped based on model name.
5662
@@ -253,6 +259,7 @@ def create(self, **kwargs):
253259
_strip_strict_from_tools(kwargs.get("tools"))
254260
if _is_claude_model(model):
255261
_fix_empty_assistant_content_in_messages(kwargs.get("messages"))
262+
kwargs = _strip_strict_from_kwargs(kwargs)
256263
return super().create(**kwargs)
257264

258265

@@ -436,6 +443,7 @@ async def create(self, **kwargs):
436443
_strip_strict_from_tools(kwargs.get("tools"))
437444
if _is_claude_model(model):
438445
_fix_empty_assistant_content_in_messages(kwargs.get("messages"))
446+
kwargs = _strip_strict_from_kwargs(kwargs)
439447
return await super().create(**kwargs)
440448

441449

integrations/openai/tests/unit_tests/test_clients.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_get_authorized_http_client,
2020
_get_openai_api_key,
2121
_should_strip_strict,
22+
_strip_strict_from_kwargs,
2223
_strip_strict_from_tools,
2324
_validate_oauth_for_apps,
2425
_wrap_app_error,
@@ -171,6 +172,35 @@ def test_bearer_auth_flow(self, mock_workspace_client):
171172
class TestStrictFieldStripping:
172173
"""Tests for strict field stripping helper functions."""
173174

175+
def test_strip_strict_from_kwargs_removes_top_level_strict_only(self):
176+
kwargs = {
177+
"strict": True,
178+
"model": "databricks-claude-3-7-sonnet",
179+
"temperature": 0.2,
180+
"tools": [{"type": "function", "function": {"name": "test", "strict": True}}],
181+
}
182+
183+
result = _strip_strict_from_kwargs(kwargs)
184+
185+
assert "strict" not in result
186+
assert result["model"] == "databricks-claude-3-7-sonnet"
187+
assert result["temperature"] == 0.2
188+
assert result["tools"][0]["function"]["strict"] is True
189+
190+
def test_strip_strict_from_kwargs_is_noop_when_strict_absent(self):
191+
kwargs = {
192+
"model": "databricks-gpt-4o",
193+
"temperature": 0.2,
194+
"tools": [{"type": "function", "function": {"name": "test", "strict": True}}],
195+
}
196+
197+
expected = kwargs.copy()
198+
result = _strip_strict_from_kwargs(kwargs)
199+
200+
assert result is kwargs
201+
assert result == expected
202+
assert result["tools"][0]["function"]["strict"] is True
203+
174204
def test_strip_strict_from_tools_removes_strict(self):
175205
tools = [
176206
{"type": "function", "function": {"name": "test", "strict": True, "parameters": {}}}
@@ -310,6 +340,31 @@ def test_chat_completions_works_without_tools(self):
310340
)
311341
mock_create.assert_called_once()
312342

343+
def test_chat_completions_strips_top_level_strict_kwarg(self):
344+
with patch("databricks_openai.utils.clients.WorkspaceClient") as mock_ws:
345+
mock_client = MagicMock(spec=WorkspaceClient)
346+
mock_client.config.host = "https://test.databricks.com"
347+
mock_client.config.authenticate.return_value = {"Authorization": "Bearer token"}
348+
mock_ws.return_value = mock_client
349+
350+
client = DatabricksOpenAI()
351+
352+
with patch.object(Completions, "create") as mock_create:
353+
mock_create.return_value = MagicMock()
354+
request_kwargs = cast(
355+
Any,
356+
{
357+
"model": "databricks-gpt-4o",
358+
"messages": [{"role": "user", "content": "hi"}],
359+
"strict": True,
360+
},
361+
)
362+
client.chat.completions.create(**request_kwargs)
363+
364+
call_kwargs = mock_create.call_args.kwargs
365+
assert "strict" not in call_kwargs
366+
assert call_kwargs["model"] == "databricks-gpt-4o"
367+
313368

314369
class TestAsyncDatabricksOpenAIStrictStripping:
315370
"""Tests for strict stripping in AsyncDatabricksOpenAI."""
@@ -360,6 +415,31 @@ async def test_chat_completions_preserves_strict_for_gpt(self):
360415
call_kwargs = mock_create.call_args.kwargs
361416
assert call_kwargs["tools"][0]["function"]["strict"] is True
362417

418+
@pytest.mark.asyncio
419+
async def test_chat_completions_strips_top_level_strict_kwarg(self):
420+
with patch("databricks_openai.utils.clients.WorkspaceClient") as mock_ws:
421+
mock_client = MagicMock(spec=WorkspaceClient)
422+
mock_client.config.host = "https://test.databricks.com"
423+
mock_client.config.authenticate.return_value = {"Authorization": "Bearer token"}
424+
mock_ws.return_value = mock_client
425+
426+
client = AsyncDatabricksOpenAI()
427+
428+
with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create:
429+
request_kwargs = cast(
430+
Any,
431+
{
432+
"model": "databricks-gpt-4o",
433+
"messages": [{"role": "user", "content": "hi"}],
434+
"strict": True,
435+
},
436+
)
437+
await client.chat.completions.create(**request_kwargs)
438+
439+
call_kwargs = mock_create.call_args.kwargs
440+
assert "strict" not in call_kwargs
441+
assert call_kwargs["model"] == "databricks-gpt-4o"
442+
363443

364444
class TestDatabricksAppsSupport:
365445
"""Tests for Databricks Apps support."""

0 commit comments

Comments
 (0)