Skip to content

Commit a0c0d9c

Browse files
committed
.
Signed-off-by: Bryan Qiu <bryan.qiu@databricks.com>
1 parent c4aa64b commit a0c0d9c

2 files changed

Lines changed: 38 additions & 1 deletion

File tree

integrations/langchain/src/databricks_langchain/chat_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
10551055
return {"role": "user", **message_dict}
10561056
elif isinstance(message, AIMessage):
10571057
if tool_calls := _get_tool_calls_from_ai_message(message):
1058+
print(tool_calls)
10581059
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
10591060
# If tool calls present, content null value should be None not empty string.
10601061
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
@@ -1196,6 +1197,14 @@ def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
11961197
for tc in message.invalid_tool_calls
11971198
]
11981199

1200+
"""
1201+
The thought signature encodes the model's reasoning.
1202+
It is required on each tool call sent to Gemini 3 Pro - https://arc.net/l/quote/jhoeoqbl
1203+
This means we need to encode this info in the Responses events in order to fix this bug, in addition to the work in this PR.
1204+
1205+
We will have to change _langchain_message_stream_to_responses_stream.
1206+
"""
1207+
11991208
if tool_calls or invalid_tool_calls:
12001209
# Merge thoughtSignature from additional_kwargs if present
12011210
all_tool_calls = tool_calls + invalid_tool_calls

integrations/langchain/tests/integration_tests/test_chat_models.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,34 @@ def test_chat_databricks_with_gemini():
877877
os.environ["DATABRICKS_CONFIG_PROFILE"] = "dogfood"
878878
from .agent import AGENT
879879

880-
result = AGENT.predict({"input": [{"role": "user", "content": "What is 6*7 in Python?"}]})
880+
result = AGENT.predict(
881+
{
882+
"input": [
883+
{"role": "user", "content": "What is 6*7 in Python?"},
884+
{
885+
"type": "function_call",
886+
"id": "lc_run--e58dec26-ce5d-4597-b4f8-28e6db62cd49",
887+
"call_id": "system__ai__python_exec",
888+
"name": "system__ai__python_exec",
889+
"arguments": '{"code": "print(6 * 7)"}',
890+
},
891+
{
892+
"type": "function_call_output",
893+
"call_id": "system__ai__python_exec",
894+
"output": '{"format": "SCALAR", "value": "42\\n"}',
895+
},
896+
{
897+
"type": "message",
898+
"id": "lc_run--dd658def-dfdc-4bc7-b0d9-b6e25d1ecc48",
899+
"content": [
900+
{"text": "The result of `6 * 7` in Python is 42.", "type": "output_text"}
901+
],
902+
"role": "assistant",
903+
},
904+
]
905+
}
906+
)
881907
assert result is not None
882908
assert result.output is not None
909+
print(result.model_dump())
910+
assert False

0 commit comments

Comments (0)