Skip to content

Commit c4aa64b

Browse files
committed
start
Signed-off-by: Bryan Qiu <bryan.qiu@databricks.com>
1 parent 1185281 commit c4aa64b

4 files changed

Lines changed: 277 additions & 29 deletions

File tree

integrations/langchain/src/databricks_langchain/chat_models.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,14 +1197,28 @@ def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
11971197
]
11981198

11991199
if tool_calls or invalid_tool_calls:
1200-
return tool_calls + invalid_tool_calls
1200+
# Merge thoughtSignature from additional_kwargs if present
1201+
all_tool_calls = tool_calls + invalid_tool_calls
1202+
additional_tool_calls = message.additional_kwargs.get("tool_calls", [])
1203+
if additional_tool_calls:
1204+
# Create a mapping of tool call IDs to their thoughtSignature
1205+
thought_signatures = {
1206+
tc.get("id"): tc.get("thoughtSignature")
1207+
for tc in additional_tool_calls
1208+
if tc.get("thoughtSignature")
1209+
}
1210+
# Add thoughtSignature to matching tool calls
1211+
for tc in all_tool_calls:
1212+
if tc["id"] in thought_signatures:
1213+
tc["thoughtSignature"] = thought_signatures[tc["id"]]
1214+
return all_tool_calls
12011215

12021216
# Get tool calls from additional kwargs if present.
12031217
return [
12041218
{
12051219
k: v
12061220
for k, v in tool_call.items() # type: ignore[union-attr]
1207-
if k in {"id", "type", "function"}
1221+
if k in {"id", "type", "function", "thoughtSignature"}
12081222
}
12091223
for tool_call in message.additional_kwargs.get("tool_calls", [])
12101224
]
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union
2+
3+
import mlflow
4+
from langchain.messages import AIMessage, AIMessageChunk, AnyMessage
5+
from langchain_core.runnables import RunnableConfig, RunnableLambda
6+
from langchain_core.tools import BaseTool
7+
from langgraph.graph import END, StateGraph
8+
from langgraph.graph.message import add_messages
9+
from langgraph.prebuilt.tool_node import ToolNode
10+
from mlflow.pyfunc import ResponsesAgent
11+
from mlflow.types.responses import (
12+
ResponsesAgentRequest,
13+
ResponsesAgentResponse,
14+
ResponsesAgentStreamEvent,
15+
output_to_responses_items_stream,
16+
to_chat_completions_input,
17+
)
18+
19+
from databricks_langchain import (
20+
ChatDatabricks,
21+
UCFunctionToolkit,
22+
)
23+
24+
############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
# LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
LLM_ENDPOINT_NAME = "databricks-gemini-3-pro"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Update with your system prompt
system_prompt = "You are a helpful assistant that can run Python code."

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
# Unity Catalog UDFs exposed as agent tools. The `system.ai.python_exec` UDF
# gives the agent a Python code-interpreter tool. You can also add local
# LangChain python tools — see https://python.langchain.com/docs/concepts/tools
# TODO: Add additional tools
UC_TOOL_NAMES = ["system.ai.python_exec"]
uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)

# Databricks vector search indexes as unstructured-retrieval tools.
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# To add vector search retriever tools, use VectorSearchRetrieverTool and
# append the instances to this list. Example:
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )
VECTOR_SEARCH_TOOLS = []

# The complete tool set handed to the agent (UC tools first, then retrievers).
tools = [*uc_toolkit.tools, *VECTOR_SEARCH_TOOLS]
70+
71+
#####################
72+
## Define agent logic
73+
#####################
74+
75+
76+
class AgentState(TypedDict):
    """LangGraph state: accumulated messages plus passthrough custom payloads."""

    # `add_messages` is a reducer: new messages are merged into the history
    # instead of replacing it on each node update.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]
80+
81+
82+
def create_tool_calling_agent(
    model: ChatDatabricks,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
):
    """Build and compile a ReAct-style tool-calling LangGraph agent.

    The graph alternates between an "agent" node (the tool-bound model) and a
    "tools" node, looping until the model's latest message carries no tool
    calls, at which point the graph ends.
    """
    bound_model = model.bind_tools(tools)

    def route_after_agent(state: AgentState):
        """Route to the tools node while the newest AI message has tool calls."""
        latest = state["messages"][-1]
        if isinstance(latest, AIMessage) and latest.tool_calls:
            return "continue"
        return "end"

    # Optionally prepend the system prompt before every model invocation.
    if system_prompt:
        prompt_step = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
        )
    else:
        prompt_step = RunnableLambda(lambda state: state["messages"])
    model_runnable = prompt_step | bound_model

    def run_model(state: AgentState, config: RunnableConfig):
        """Invoke the model and append its response to the message history."""
        response = model_runnable.invoke(state, config)
        return {"messages": [response]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", RunnableLambda(run_model))
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        route_after_agent,
        {"continue": "tools", "end": END},
    )
    graph.add_edge("tools", "agent")
    return graph.compile()
132+
133+
134+
class LangGraphResponsesAgent(ResponsesAgent):
    """Adapts a compiled LangGraph agent to the MLflow ResponsesAgent API."""

    def __init__(self, agent):
        # `agent` is a compiled LangGraph graph (supports .stream()).
        self.agent = agent

    @staticmethod
    def _resolve_session_id(request: ResponsesAgentRequest) -> Optional[str]:
        """Return the session id for this request, or None.

        Prefers an explicit ``custom_inputs["session_id"]``; falls back to the
        request context's conversation id.
        """
        if request.custom_inputs and "session_id" in request.custom_inputs:
            return request.custom_inputs.get("session_id")
        if request.context and request.context.conversation_id:
            return request.context.conversation_id
        return None

    def _tag_trace_session(self, request: ResponsesAgentRequest) -> None:
        """Attach the session id (when present) to the current MLflow trace."""
        session_id = self._resolve_session_id(request)
        if session_id:
            mlflow.update_current_trace(
                metadata={
                    "mlflow.trace.session": session_id,
                }
            )

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the agent to completion and collect the finished output items."""
        self._tag_trace_session(request)

        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        # custom_inputs are echoed back as custom_outputs (template convention).
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream agent output as Responses API events.

        Emits completed output items from graph "updates" events and
        incremental text deltas from "messages" events.
        """
        self._tag_trace_session(request)

        cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input])

        for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                for node_data in event[1].values():
                    if len(node_data.get("messages", [])) > 0:
                        yield from output_to_responses_items_stream(node_data["messages"])
            # filter the streamed messages to just the generated text messages
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    # Best-effort streaming of deltas: report and keep streaming
                    # rather than aborting the whole response.
                    print(e)
193+
194+
195+
# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
mlflow.langchain.autolog()  # auto-trace LangChain/LangGraph calls
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)

integrations/langchain/tests/integration_tests/test_chat_models.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ class GetWeather(BaseModel):
243243
return
244244

245245
# Models should make at least one tool call when tool_choice is not "none"
246-
assert len(response.tool_calls) >= 1, (
247-
f"Expected at least 1 tool call, got {len(response.tool_calls)}"
248-
)
246+
assert (
247+
len(response.tool_calls) >= 1
248+
), f"Expected at least 1 tool call, got {len(response.tool_calls)}"
249249

250250
# The first tool call should be for GetWeather
251251
first_call = response.tool_calls[0]
@@ -267,9 +267,9 @@ class GetWeather(BaseModel):
267267
]
268268
)
269269
# Should call GetWeather tool for the followup question
270-
assert len(response.tool_calls) >= 1, (
271-
f"Expected at least 1 tool call, got {len(response.tool_calls)}"
272-
)
270+
assert (
271+
len(response.tool_calls) >= 1
272+
), f"Expected at least 1 tool call, got {len(response.tool_calls)}"
273273
tool_call = response.tool_calls[0]
274274
assert tool_call["name"] == "GetWeather", f"Expected GetWeather tool, got {tool_call['name']}"
275275
assert "location" in tool_call["args"], f"Expected location in args, got {tool_call['args']}"
@@ -559,12 +559,8 @@ def test_chat_databricks_chatagent_invoke():
559559
):
560560
python_tool_used = True
561561

562-
assert has_tool_calls, (
563-
f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}"
564-
)
565-
assert python_tool_used, (
566-
f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}"
567-
)
562+
assert has_tool_calls, f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}"
563+
assert python_tool_used, f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}"
568564

569565

570566
@pytest.mark.st_endpoints
@@ -847,9 +843,9 @@ def test_chat_databricks_gpt5_stream_with_usage():
847843
]
848844

849845
# Should have exactly ONE usage chunk from the final usage-only chunk
850-
assert len(usage_chunks) == 1, (
851-
f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}"
852-
)
846+
assert (
847+
len(usage_chunks) == 1
848+
), f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}"
853849

854850
# Verify usage chunk has correct metadata structure
855851
usage_chunk = usage_chunks[0]
@@ -860,12 +856,12 @@ def test_chat_databricks_gpt5_stream_with_usage():
860856
assert "total_tokens" in usage_chunk.usage_metadata
861857

862858
# Verify token counts are positive
863-
assert usage_chunk.usage_metadata["input_tokens"] > 0, (
864-
f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}"
865-
)
866-
assert usage_chunk.usage_metadata["output_tokens"] > 0, (
867-
f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}"
868-
)
859+
assert (
860+
usage_chunk.usage_metadata["input_tokens"] > 0
861+
), f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}"
862+
assert (
863+
usage_chunk.usage_metadata["output_tokens"] > 0
864+
), f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}"
869865

870866
# Verify total_tokens equals sum of input and output
871867
expected_total = (
@@ -875,3 +871,12 @@ def test_chat_databricks_gpt5_stream_with_usage():
875871
f"Expected total_tokens ({usage_chunk.usage_metadata['total_tokens']}) "
876872
f"to equal input_tokens + output_tokens ({expected_total})"
877873
)
874+
875+
876+
def test_chat_databricks_with_gemini():
    """Smoke-test the Gemini-backed LangGraph agent end to end.

    Temporarily points the Databricks SDK at the "dogfood" config profile and
    restores the previous value afterwards, so the override does not leak into
    other tests in the session (the original assignment mutated os.environ
    permanently).
    """
    previous_profile = os.environ.get("DATABRICKS_CONFIG_PROFILE")
    os.environ["DATABRICKS_CONFIG_PROFILE"] = "dogfood"
    try:
        from .agent import AGENT

        result = AGENT.predict({"input": [{"role": "user", "content": "What is 6*7 in Python?"}]})
        assert result is not None
        assert result.output is not None
    finally:
        # Restore the environment exactly as it was before the test.
        if previous_profile is None:
            os.environ.pop("DATABRICKS_CONFIG_PROFILE", None)
        else:
            os.environ["DATABRICKS_CONFIG_PROFILE"] = previous_profile

integrations/langchain/tests/unit_tests/test_chat_models.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,9 @@ def test_chat_model_stream_usage_only_chunk_missing_tokens():
479479
usage_chunks = [
480480
chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
481481
]
482-
assert len(usage_chunks) == 0, (
483-
f"Expected 0 usage chunks when tokens are missing, got {len(usage_chunks)}"
484-
)
482+
assert (
483+
len(usage_chunks) == 0
484+
), f"Expected 0 usage chunks when tokens are missing, got {len(usage_chunks)}"
485485

486486

487487
def test_chat_model_stream_usage_only_chunk_stream_usage_false():
@@ -532,9 +532,9 @@ def test_chat_model_stream_usage_only_chunk_stream_usage_false():
532532
usage_chunks = [
533533
chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
534534
]
535-
assert len(usage_chunks) == 0, (
536-
f"Expected 0 usage chunks when stream_usage=False, got {len(usage_chunks)}"
537-
)
535+
assert (
536+
len(usage_chunks) == 0
537+
), f"Expected 0 usage chunks when stream_usage=False, got {len(usage_chunks)}"
538538

539539

540540
class GetWeather(BaseModel):
@@ -713,6 +713,35 @@ def test_convert_message_with_tool_calls() -> None:
713713
assert dict_result == message_with_tools
714714

715715

716+
def test_convert_message_with_tool_calls_and_thought_signature() -> None:
    """Tool calls carried in additional_kwargs keep their thoughtSignature
    field when the AIMessage is converted to the chat-completions dict form."""
    call_id = "system__ai__python_exec"
    signature = "CikBjz1rXxsXPO9F7LWvkXdS3Fkl7lMvmk9yp2iIuuTv0vWI2wRd0vHm5QpZAY89a1"
    message = AIMessage(
        content="",
        additional_kwargs={
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {
                        "name": "system__ai__python_exec",
                        "arguments": '{"code":"print(6 * 7)"}',
                    },
                    "thoughtSignature": signature,
                }
            ]
        },
    )

    dict_result = _convert_message_to_dict(message)

    assert "tool_calls" in dict_result
    converted = dict_result["tool_calls"]
    assert len(converted) == 1
    assert converted[0]["id"] == call_id
    assert converted[0]["type"] == "function"
    assert converted[0]["function"]["name"] == "system__ai__python_exec"
    assert converted[0]["thoughtSignature"] == signature
743+
744+
716745
def test_convert_tool_message() -> None:
717746
tool_message = ToolMessage(content="result", tool_call_id="call_123")
718747
result = _convert_message_to_dict(tool_message)

0 commit comments

Comments
 (0)