Skip to content

Commit 6dd8400

Browse files
committed
mark LangChain roots in metadata
1 parent de09c32 commit 6dd8400

2 files changed

Lines changed: 146 additions & 6 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,28 @@ def _parse_langfuse_trace_attributes(
303303

304304
return attributes
305305

306+
def _get_langchain_observation_metadata(
307+
self,
308+
*,
309+
parent_run_id: Optional[UUID],
310+
tags: Optional[List[str]] = None,
311+
metadata: Optional[Dict[str, Any]] = None,
312+
keep_langfuse_trace_attributes: bool = False,
313+
) -> Optional[Dict[str, Any]]:
314+
observation_metadata = self.__join_tags_and_metadata(
315+
tags=tags,
316+
metadata=metadata,
317+
keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
318+
)
319+
320+
if parent_run_id is not None:
321+
return observation_metadata
322+
323+
root_metadata = observation_metadata.copy() if observation_metadata else {}
324+
root_metadata["is_langchain_root"] = True
325+
326+
return root_metadata
327+
306328
def on_chain_start(
307329
self,
308330
serialized: Optional[Dict[str, Any]],
@@ -325,7 +347,11 @@ def on_chain_start(
325347
)
326348

327349
span_name = self.get_langchain_run_name(serialized, **kwargs)
328-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
350+
span_metadata = self._get_langchain_observation_metadata(
351+
parent_run_id=parent_run_id,
352+
tags=tags,
353+
metadata=metadata,
354+
)
329355
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
330356

331357
observation_type = self._get_observation_type_from_serialized(
@@ -690,7 +716,11 @@ def on_tool_start(
690716
"on_tool_start", run_id, parent_run_id, input_str=input_str
691717
)
692718

693-
meta = self.__join_tags_and_metadata(tags, metadata)
719+
meta = self._get_langchain_observation_metadata(
720+
parent_run_id=parent_run_id,
721+
tags=tags,
722+
metadata=metadata,
723+
)
694724

695725
if not meta:
696726
meta = {}
@@ -734,7 +764,11 @@ def on_retriever_start(
734764
"on_retriever_start", run_id, parent_run_id, query=query
735765
)
736766
span_name = self.get_langchain_run_name(serialized, **kwargs)
737-
span_metadata = self.__join_tags_and_metadata(tags, metadata)
767+
span_metadata = self._get_langchain_observation_metadata(
768+
parent_run_id=parent_run_id,
769+
tags=tags,
770+
metadata=metadata,
771+
)
738772
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
739773

740774
observation_type = self._get_observation_type_from_serialized(
@@ -865,9 +899,10 @@ def __on_llm_action(
865899
content = {
866900
"name": self.get_langchain_run_name(serialized, **kwargs),
867901
"input": prompts,
868-
"metadata": self.__join_tags_and_metadata(
869-
tags,
870-
metadata,
902+
"metadata": self._get_langchain_observation_metadata(
903+
parent_run_id=parent_run_id,
904+
tags=tags,
905+
metadata=metadata,
871906
# If llm is run isolated and outside chain, keep trace attributes
872907
keep_langfuse_trace_attributes=True
873908
if parent_run_id is None

tests/test_langchain.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import importlib
12
import random
23
import string
34
import time
45
from time import sleep
56
from typing import Any, Dict, Literal
7+
from uuid import uuid4
68

79
import pytest
810
from langchain.messages import HumanMessage, SystemMessage
@@ -14,13 +16,116 @@
1416
from langgraph.checkpoint.memory import MemorySaver
1517
from langgraph.graph import END, START, MessagesState, StateGraph
1618
from langgraph.prebuilt import ToolNode
19+
from opentelemetry import trace as otel_trace
1720
from pydantic import BaseModel, Field
1821

1922
from langfuse._client.client import Langfuse
2023
from langfuse.langchain import CallbackHandler
2124
from tests.utils import create_uuid, encode_file_to_base64, get_api
2225

2326

27+
class _FakeLangchainObservation:
28+
def __init__(self, recorder, **kwargs):
29+
self._recorder = recorder
30+
self._otel_span = otel_trace.NonRecordingSpan(otel_trace.INVALID_SPAN_CONTEXT)
31+
self.trace_id = "test-trace-id"
32+
self.metadata = kwargs.get("metadata")
33+
self.name = kwargs.get("name")
34+
self.input = kwargs.get("input")
35+
self.as_type = kwargs.get("as_type")
36+
self.updates = []
37+
self.ended = False
38+
39+
def start_observation(self, **kwargs):
40+
return self._recorder.start_observation(**kwargs)
41+
42+
def update(self, **kwargs):
43+
self.updates.append(kwargs)
44+
return self
45+
46+
def end(self):
47+
self.ended = True
48+
return self
49+
50+
51+
class _FakeLangchainClient:
52+
def __init__(self):
53+
self.started_observations = []
54+
55+
def start_observation(self, **kwargs):
56+
observation = _FakeLangchainObservation(self, **kwargs)
57+
self.started_observations.append(observation)
58+
return observation
59+
60+
61+
def _patch_langchain_client(monkeypatch, fake_client):
62+
callback_handler_module = importlib.import_module(
63+
"langfuse.langchain.CallbackHandler"
64+
)
65+
monkeypatch.setattr(
66+
callback_handler_module,
67+
"get_client",
68+
lambda public_key=None: fake_client,
69+
)
70+
71+
72+
def test_root_langchain_chain_sets_is_langchain_root_metadata(monkeypatch):
73+
fake_client = _FakeLangchainClient()
74+
_patch_langchain_client(monkeypatch, fake_client)
75+
handler = CallbackHandler()
76+
77+
root_run_id = uuid4()
78+
child_run_id = uuid4()
79+
80+
handler.on_chain_start(
81+
serialized={"name": "RootChain"},
82+
inputs={"question": "hello"},
83+
run_id=root_run_id,
84+
tags=["root-tag"],
85+
metadata={"foo": "bar"},
86+
)
87+
handler.on_chain_start(
88+
serialized={"name": "ChildChain"},
89+
inputs={"question": "child"},
90+
run_id=child_run_id,
91+
parent_run_id=root_run_id,
92+
metadata={"child": "metadata"},
93+
)
94+
handler.on_chain_end(
95+
{"answer": "child"},
96+
run_id=child_run_id,
97+
parent_run_id=root_run_id,
98+
)
99+
handler.on_chain_end({"answer": "root"}, run_id=root_run_id)
100+
101+
root_observation, child_observation = fake_client.started_observations
102+
103+
assert root_observation.metadata == {
104+
"foo": "bar",
105+
"tags": ["root-tag"],
106+
"is_langchain_root": True,
107+
}
108+
assert child_observation.metadata == {"child": "metadata"}
109+
110+
111+
def test_root_langchain_llm_sets_is_langchain_root_metadata(monkeypatch):
112+
fake_client = _FakeLangchainClient()
113+
_patch_langchain_client(monkeypatch, fake_client)
114+
handler = CallbackHandler()
115+
116+
root_run_id = uuid4()
117+
118+
handler.on_llm_start(
119+
serialized={"name": "ChatOpenAI", "id": ["langchain", "ChatOpenAI"]},
120+
prompts=["hello"],
121+
run_id=root_run_id,
122+
invocation_params={"model_name": "gpt-4o-mini"},
123+
)
124+
handler._detach_observation(root_run_id)
125+
126+
assert fake_client.started_observations[0].metadata == {"is_langchain_root": True}
127+
128+
24129
def test_callback_generated_from_trace_chat():
25130
langfuse = Langfuse()
26131

0 commit comments

Comments
 (0)