|
| 1 | +from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union |
| 2 | + |
| 3 | +import mlflow |
| 4 | +from langchain.messages import AIMessage, AIMessageChunk, AnyMessage |
| 5 | +from langchain_core.runnables import RunnableConfig, RunnableLambda |
| 6 | +from langchain_core.tools import BaseTool |
| 7 | +from langgraph.graph import END, StateGraph |
| 8 | +from langgraph.graph.message import add_messages |
| 9 | +from langgraph.prebuilt.tool_node import ToolNode |
| 10 | +from mlflow.pyfunc import ResponsesAgent |
| 11 | +from mlflow.types.responses import ( |
| 12 | + ResponsesAgentRequest, |
| 13 | + ResponsesAgentResponse, |
| 14 | + ResponsesAgentStreamEvent, |
| 15 | + output_to_responses_items_stream, |
| 16 | + to_chat_completions_input, |
| 17 | +) |
| 18 | + |
| 19 | +from databricks_langchain import ( |
| 20 | + ChatDatabricks, |
| 21 | + UCFunctionToolkit, |
| 22 | +) |
| 23 | + |
############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
# LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
LLM_ENDPOINT_NAME = "databricks-gemini-3-pro"
# Chat-model client bound to the Databricks model-serving endpoint above.
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Update with your system prompt
system_prompt = "You are a helpful assistant that can run Python code."

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
# Aggregate list of every tool the agent can call; extended below.
tools = []

# You can use UDFs in Unity Catalog as agent tools
# Below, we add the `system.ai.python_exec` UDF, which provides
# a python code interpreter tool to our agent
# You can also add local LangChain python tools. See https://python.langchain.com/docs/concepts/tools

# TODO: Add additional tools
UC_TOOL_NAMES = ["system.ai.python_exec"]
# Wraps each Unity Catalog function name as a LangChain-compatible tool.
uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# List to store vector search tool instances for unstructured retrieval.
VECTOR_SEARCH_TOOLS = []

# To add vector search retriever tools,
# use VectorSearchRetrieverTool and create_tool_info,
# then append the result to TOOL_INFOS.
# Example:
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )

tools.extend(VECTOR_SEARCH_TOOLS)
| 70 | + |
| 71 | +##################### |
| 72 | +## Define agent logic |
| 73 | +##################### |
| 74 | + |
| 75 | + |
class AgentState(TypedDict):
    """LangGraph state shared across nodes of the agent loop."""

    # Running conversation; LangGraph's `add_messages` reducer appends/merges
    # messages returned by each node instead of replacing the list.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Optional caller-supplied extras (passed through the graph untouched).
    custom_inputs: Optional[dict[str, Any]]
    # Optional agent-produced extras surfaced back to the caller.
    custom_outputs: Optional[dict[str, Any]]
| 80 | + |
| 81 | + |
def create_tool_calling_agent(
    model: ChatDatabricks,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
):
    """Build and compile a tool-calling agent graph: model -> tools -> model ...

    Args:
        model: Chat model; tool schemas are bound onto it so it can emit tool calls.
        tools: Tools the agent may invoke (a ToolNode or a sequence of BaseTool).
        system_prompt: Optional system message prepended to every model call.

    Returns:
        A compiled LangGraph runnable implementing the agent loop.
    """
    model = model.bind_tools(tools)

    def route_after_model(state: AgentState):
        # Loop back through the tool node while the newest AI message
        # still requests tool calls; otherwise finish.
        last = state["messages"][-1]
        return "continue" if isinstance(last, AIMessage) and last.tool_calls else "end"

    if system_prompt:

        def _with_system(state):
            # Prepend the system message without mutating the state's list.
            return [{"role": "system", "content": system_prompt}] + state["messages"]

        preprocessor = RunnableLambda(_with_system)
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(state: AgentState, config: RunnableConfig):
        # Single model turn; the add_messages reducer appends the response.
        return {"messages": [model_runnable.invoke(state, config)]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        route_after_model,
        {"continue": "tools", "end": END},
    )
    graph.add_edge("tools", "agent")
    return graph.compile()
| 132 | + |
| 133 | + |
class LangGraphResponsesAgent(ResponsesAgent):
    """Adapts a compiled LangGraph agent to MLflow's ResponsesAgent interface."""

    def __init__(self, agent):
        # `agent` is the compiled LangGraph runnable (see create_tool_calling_agent).
        self.agent = agent

    @staticmethod
    def _resolve_session_id(request: ResponsesAgentRequest) -> Optional[str]:
        """Return the session id for this request, or None.

        An explicit `custom_inputs["session_id"]` takes precedence over the
        request context's `conversation_id`.
        """
        if request.custom_inputs and "session_id" in request.custom_inputs:
            return request.custom_inputs.get("session_id")
        if request.context and request.context.conversation_id:
            return request.context.conversation_id
        return None

    def _tag_trace_session(self, request: ResponsesAgentRequest) -> None:
        """Attach the session id (when present) to the active MLflow trace.

        Previously this lookup + tagging was duplicated verbatim in both
        `predict` and `predict_stream`; it is centralized here.
        """
        session_id = self._resolve_session_id(request)
        if session_id:
            mlflow.update_current_trace(
                metadata={
                    "mlflow.trace.session": session_id,
                }
            )

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the agent to completion and return only the finished output items."""
        self._tag_trace_session(request)

        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        # NOTE(review): custom_inputs is echoed back as custom_outputs — this
        # mirrors the original template behavior; confirm it is intentional.
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream agent output as Responses-API events (deltas + completed items)."""
        self._tag_trace_session(request)

        # Convert Responses-API input items to chat-completions-style messages.
        cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input])

        for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                # Node-level state updates: emit each completed message as
                # finished Responses output items.
                for node_data in event[1].values():
                    if len(node_data.get("messages", [])) > 0:
                        yield from output_to_responses_items_stream(node_data["messages"])
            # filter the streamed messages to just the generated text messages
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    # Best-effort: a malformed chunk must not kill the stream;
                    # the "updates" path still yields the completed items.
                    print(e)
| 193 | + |
| 194 | + |
# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
# Enable automatic MLflow tracing for all LangChain/LangGraph invocations.
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)
# (removed web-page artifact: "0 commit comments")