Skip to content

Commit e10e522

Browse files
committed
fix(langchain): pass trace_name to propagate_attributes in on_chain_start
When CallbackHandler.on_chain_start fires at the root of a chain (parent_run_id is None), propagate_attributes was called without a trace_name, so the trace name was determined by whichever internal node's on_chain_start happened to fire first. On LangGraph resume (e.g. after a human-in-the-loop interrupt) that node is often an internal subgraph whose name is "", which produces a blank trace name. The fix passes span_name — the name already computed from the serialized runnable and kwargs — as trace_name to propagate_attributes. This ensures the trace name is always pinned to the root chain's name regardless of execution order on resume. As a companion change, _parse_langfuse_trace_attributes now also reads a langfuse_trace_name key from LangChain metadata, consistent with the existing langfuse_session_id / langfuse_user_id / langfuse_tags pattern. When present, metadata langfuse_trace_name takes priority over the computed span_name. The key is also added to the strip-list in _strip_langfuse_keys_from_dict so it does not leak into observation metadata. Fixes #1602
1 parent 3e530af commit e10e522

2 files changed

Lines changed: 274 additions & 0 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ def _parse_langfuse_trace_attributes(
287287
):
288288
attributes["user_id"] = metadata["langfuse_user_id"]
289289

290+
if "langfuse_trace_name" in metadata and isinstance(
291+
metadata["langfuse_trace_name"], str
292+
):
293+
attributes["trace_name"] = metadata["langfuse_trace_name"]
294+
290295
if tags is not None or (
291296
"langfuse_tags" in metadata and isinstance(metadata["langfuse_tags"], list)
292297
):
@@ -365,6 +370,7 @@ def on_chain_start(
365370
)
366371

367372
self._propagation_context_manager = propagate_attributes(
373+
trace_name=parsed_trace_attributes.get("trace_name") or span_name,
368374
user_id=parsed_trace_attributes.get("user_id", None),
369375
session_id=parsed_trace_attributes.get("session_id", None),
370376
tags=parsed_trace_attributes.get("tags", None),
@@ -1403,6 +1409,7 @@ def _strip_langfuse_keys_from_dict(
14031409
"langfuse_session_id",
14041410
"langfuse_user_id",
14051411
"langfuse_tags",
1412+
"langfuse_trace_name",
14061413
]
14071414

14081415
metadata_copy = metadata.copy()
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
"""Unit tests for CallbackHandler trace-name propagation.
2+
3+
These tests cover the fix for on_chain_start not passing trace_name to
4+
propagate_attributes, which caused non-deterministic trace names on LangGraph
5+
resume (e.g. after a human-in-the-loop interrupt).
6+
7+
No real API calls are made — propagate_attributes and get_client are mocked.
8+
"""
9+
10+
import uuid
11+
from contextlib import contextmanager
12+
from typing import Any, Dict, Optional
13+
from unittest.mock import MagicMock, call, patch
14+
15+
import pytest
16+
17+
from langfuse.langchain import CallbackHandler
18+
from langfuse.langchain.CallbackHandler import (
19+
LangchainCallbackHandler,
20+
_strip_langfuse_keys_from_dict,
21+
)
22+
23+
24+
# ---------------------------------------------------------------------------
25+
# Helpers
26+
# ---------------------------------------------------------------------------
27+
28+
29+
def _make_handler() -> CallbackHandler:
30+
"""Return a CallbackHandler with Langfuse SDK calls mocked out."""
31+
with patch("langfuse.langchain.CallbackHandler.get_client") as mock_get_client:
32+
mock_client = MagicMock()
33+
mock_client.start_observation.return_value = MagicMock(trace_id="trace-123")
34+
mock_get_client.return_value = mock_client
35+
handler = CallbackHandler()
36+
# Keep a reference so tests can inspect it
37+
handler._langfuse_client = MagicMock()
38+
handler._langfuse_client.start_observation.return_value = MagicMock(
39+
trace_id="trace-123"
40+
)
41+
return handler
42+
43+
44+
def _make_run_id() -> uuid.UUID:
45+
return uuid.uuid4()
46+
47+
48+
# ---------------------------------------------------------------------------
49+
# Tests: _parse_langfuse_trace_attributes
50+
# ---------------------------------------------------------------------------
51+
52+
53+
class TestParseLangfuseTraceAttributes:
54+
def _parse(self, handler, metadata=None, tags=None):
55+
return handler._parse_langfuse_trace_attributes(
56+
metadata=metadata, tags=tags
57+
)
58+
59+
def test_extracts_trace_name_from_metadata(self):
60+
handler = _make_handler()
61+
result = self._parse(
62+
handler, metadata={"langfuse_trace_name": "my-agent"}
63+
)
64+
assert result["trace_name"] == "my-agent"
65+
66+
def test_ignores_non_string_trace_name(self):
67+
handler = _make_handler()
68+
result = self._parse(handler, metadata={"langfuse_trace_name": 42})
69+
assert "trace_name" not in result
70+
71+
def test_does_not_set_trace_name_when_absent(self):
72+
handler = _make_handler()
73+
result = self._parse(handler, metadata={"langfuse_session_id": "s1"})
74+
assert "trace_name" not in result
75+
76+
def test_extracts_all_attributes_together(self):
77+
handler = _make_handler()
78+
result = self._parse(
79+
handler,
80+
metadata={
81+
"langfuse_trace_name": "agent",
82+
"langfuse_session_id": "sess-1",
83+
"langfuse_user_id": "user-1",
84+
},
85+
)
86+
assert result["trace_name"] == "agent"
87+
assert result["session_id"] == "sess-1"
88+
assert result["user_id"] == "user-1"
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# Tests: on_chain_start passes trace_name to propagate_attributes
93+
# ---------------------------------------------------------------------------
94+
95+
96+
class TestOnChainStartTraceNamePropagation:
97+
"""Verify that on_chain_start forwards trace_name to propagate_attributes."""
98+
99+
def _run_on_chain_start(
100+
self,
101+
handler: CallbackHandler,
102+
serialized: Optional[Dict[str, Any]] = None,
103+
name: Optional[str] = None,
104+
parent_run_id: Optional[uuid.UUID] = None,
105+
metadata: Optional[Dict[str, Any]] = None,
106+
) -> uuid.UUID:
107+
run_id = _make_run_id()
108+
kwargs: Dict[str, Any] = {}
109+
if name is not None:
110+
kwargs["name"] = name
111+
handler.on_chain_start(
112+
serialized=serialized or {},
113+
inputs={},
114+
run_id=run_id,
115+
parent_run_id=parent_run_id,
116+
metadata=metadata,
117+
**kwargs,
118+
)
119+
return run_id
120+
121+
def test_trace_name_passed_to_propagate_attributes(self):
122+
"""span_name derived from serialized['name'] is forwarded as trace_name."""
123+
handler = _make_handler()
124+
125+
@contextmanager
126+
def _noop_ctx(*args, **kwargs):
127+
yield
128+
129+
with patch(
130+
"langfuse.langchain.CallbackHandler.propagate_attributes"
131+
) as mock_pa:
132+
mock_pa.return_value = MagicMock(
133+
__enter__=MagicMock(return_value=None),
134+
__exit__=MagicMock(return_value=False),
135+
)
136+
self._run_on_chain_start(
137+
handler,
138+
serialized={"name": "my-agent"},
139+
parent_run_id=None,
140+
)
141+
142+
mock_pa.assert_called_once()
143+
_, kwargs = mock_pa.call_args
144+
assert kwargs.get("trace_name") == "my-agent"
145+
146+
def test_trace_name_uses_kwargs_name_over_serialized(self):
147+
"""The 'name' kwarg takes priority over serialized dict (LangChain convention)."""
148+
handler = _make_handler()
149+
150+
with patch(
151+
"langfuse.langchain.CallbackHandler.propagate_attributes"
152+
) as mock_pa:
153+
mock_pa.return_value = MagicMock(
154+
__enter__=MagicMock(return_value=None),
155+
__exit__=MagicMock(return_value=False),
156+
)
157+
self._run_on_chain_start(
158+
handler,
159+
serialized={"name": "fallback-name"},
160+
name="explicit-name",
161+
parent_run_id=None,
162+
)
163+
164+
_, kwargs = mock_pa.call_args
165+
assert kwargs.get("trace_name") == "explicit-name"
166+
167+
def test_metadata_langfuse_trace_name_overrides_span_name(self):
168+
"""langfuse_trace_name in metadata takes priority over computed span_name."""
169+
handler = _make_handler()
170+
171+
with patch(
172+
"langfuse.langchain.CallbackHandler.propagate_attributes"
173+
) as mock_pa:
174+
mock_pa.return_value = MagicMock(
175+
__enter__=MagicMock(return_value=None),
176+
__exit__=MagicMock(return_value=False),
177+
)
178+
self._run_on_chain_start(
179+
handler,
180+
serialized={"name": "computed-name"},
181+
metadata={"langfuse_trace_name": "override-name"},
182+
parent_run_id=None,
183+
)
184+
185+
_, kwargs = mock_pa.call_args
186+
assert kwargs.get("trace_name") == "override-name"
187+
188+
def test_propagate_attributes_not_called_for_child_runs(self):
189+
"""propagate_attributes must only be called at the root (parent_run_id=None)."""
190+
handler = _make_handler()
191+
root_run_id = _make_run_id()
192+
handler._child_to_parent_run_id_map[root_run_id] = None
193+
194+
with patch(
195+
"langfuse.langchain.CallbackHandler.propagate_attributes"
196+
) as mock_pa:
197+
mock_pa.return_value = MagicMock(
198+
__enter__=MagicMock(return_value=None),
199+
__exit__=MagicMock(return_value=False),
200+
)
201+
# Child run — parent_run_id is set
202+
self._run_on_chain_start(
203+
handler,
204+
serialized={"name": "child-node"},
205+
parent_run_id=root_run_id,
206+
)
207+
208+
mock_pa.assert_not_called()
209+
210+
def test_empty_span_name_still_propagated(self):
211+
"""Even when span_name resolves to '', it should still be forwarded."""
212+
handler = _make_handler()
213+
214+
with patch(
215+
"langfuse.langchain.CallbackHandler.propagate_attributes"
216+
) as mock_pa:
217+
mock_pa.return_value = MagicMock(
218+
__enter__=MagicMock(return_value=None),
219+
__exit__=MagicMock(return_value=False),
220+
)
221+
# Empty serialized — span_name will be '<unknown>'
222+
self._run_on_chain_start(
223+
handler,
224+
serialized=None,
225+
parent_run_id=None,
226+
)
227+
228+
_, kwargs = mock_pa.call_args
229+
# '<unknown>' is what get_langchain_run_name returns for None serialized
230+
assert kwargs.get("trace_name") == "<unknown>"
231+
232+
233+
# ---------------------------------------------------------------------------
234+
# Tests: _strip_langfuse_keys_from_dict strips langfuse_trace_name
235+
# ---------------------------------------------------------------------------
236+
237+
238+
class TestStripLangfuseKeys:
239+
def test_strips_langfuse_trace_name(self):
240+
metadata = {
241+
"langfuse_trace_name": "my-agent",
242+
"other_key": "value",
243+
}
244+
result = _strip_langfuse_keys_from_dict(metadata, keep_langfuse_trace_attributes=False)
245+
assert "langfuse_trace_name" not in result
246+
assert result["other_key"] == "value"
247+
248+
def test_keeps_langfuse_trace_name_when_flag_set(self):
249+
metadata = {
250+
"langfuse_trace_name": "my-agent",
251+
"other_key": "value",
252+
}
253+
result = _strip_langfuse_keys_from_dict(metadata, keep_langfuse_trace_attributes=True)
254+
assert result["langfuse_trace_name"] == "my-agent"
255+
256+
def test_strips_all_trace_attribute_keys_together(self):
257+
metadata = {
258+
"langfuse_trace_name": "n",
259+
"langfuse_session_id": "s",
260+
"langfuse_user_id": "u",
261+
"langfuse_tags": ["t"],
262+
"keep_me": 1,
263+
}
264+
result = _strip_langfuse_keys_from_dict(metadata, keep_langfuse_trace_attributes=False)
265+
for key in ("langfuse_trace_name", "langfuse_session_id", "langfuse_user_id", "langfuse_tags"):
266+
assert key not in result
267+
assert result["keep_me"] == 1

0 commit comments

Comments
 (0)