Skip to content

Commit cc5d7ed

Browse files
committed
update callbackhandler
1 parent 6b1eedb commit cc5d7ed

1 file changed

Lines changed: 72 additions & 72 deletions

File tree

langfuse/langchain/CallbackHandler.py

Lines changed: 72 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import pydantic
1717
from opentelemetry import context, trace
1818
from opentelemetry.context import _RUNTIME_CONTEXT
19+
from opentelemetry.util._decorator import _AgnosticContextManager
1920

21+
from langfuse import propagate_attributes
2022
from langfuse._client.attributes import LangfuseOtelSpanAttributes
2123
from langfuse._client.client import Langfuse
2224
from langfuse._client.get_client import get_client
@@ -32,6 +34,7 @@
3234
from langfuse.langchain.utils import _extract_model_name
3335
from langfuse.logger import langfuse_logger
3436
from langfuse.types import TraceContext
37+
from langfuse.types import TraceContext
3538

3639
try:
3740
import langchain
@@ -96,14 +99,12 @@ def __init__(
9699
self,
97100
*,
98101
public_key: Optional[str] = None,
99-
update_trace: bool = False,
100102
trace_context: Optional[TraceContext] = None,
101103
) -> None:
102104
"""Initialize the LangchainCallbackHandler.
103105
104106
Args:
105107
public_key: Optional Langfuse public key. If not provided, will use the default client configuration.
106-
update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False.
107108
trace_context: Optional context for connecting to an existing trace (distributed tracing) or
108109
setting a custom trace id for the root LangChain run. Pass a `TraceContext` dict, e.g.
109110
`{"trace_id": "<trace_id>"}` (and optionally `{"parent_span_id": "<span_id>"}`) to link
@@ -118,10 +119,8 @@ def __init__(
118119
handler = CallbackHandler(trace_context={"trace_id": "my-trace-id"})
119120
```
120121
"""
121-
self.client = get_client(public_key=public_key)
122-
self.run_inline = True
123-
124-
self.runs: Dict[
122+
self._langfuse_client = get_client(public_key=public_key)
123+
self._runs: Dict[
125124
UUID,
126125
Union[
127126
LangfuseSpan,
@@ -132,14 +131,13 @@ def __init__(
132131
LangfuseRetriever,
133132
],
134133
] = {}
135-
self._child_to_parent_run_id_map: Dict[UUID, Optional[UUID]] = {}
136-
self.context_tokens: Dict[UUID, Token] = {}
137-
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
138-
self.updated_completion_start_time_memo: Set[UUID] = set()
134+
self._context_tokens: Dict[UUID, Token] = {}
135+
self._prompt_to_parent_run_map: Dict[UUID, Any] = {}
136+
self._updated_completion_start_time_memo: Set[UUID] = set()
137+
self._propagation_context_manager: Optional[_AgnosticContextManager] = None
138+
self._trace_context = trace_context
139139

140140
self.last_trace_id: Optional[str] = None
141-
self.update_trace = update_trace
142-
self.trace_context = trace_context
143141

144142
def on_llm_new_token(
145143
self,
@@ -154,14 +152,14 @@ def on_llm_new_token(
154152
f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}"
155153
)
156154
if (
157-
run_id in self.runs
158-
and isinstance(self.runs[run_id], LangfuseGeneration)
159-
and run_id not in self.updated_completion_start_time_memo
155+
run_id in self._runs
156+
and isinstance(self._runs[run_id], LangfuseGeneration)
157+
and run_id not in self._updated_completion_start_time_memo
160158
):
161-
current_generation = cast(LangfuseGeneration, self.runs[run_id])
159+
current_generation = cast(LangfuseGeneration, self._runs[run_id])
162160
current_generation.update(completion_start_time=_get_timestamp())
163161

164-
self.updated_completion_start_time_memo.add(run_id)
162+
self._updated_completion_start_time_memo.add(run_id)
165163

166164
def _get_observation_type_from_serialized(
167165
self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
@@ -268,12 +266,14 @@ def on_retriever_error(
268266
except Exception as e:
269267
langfuse_logger.exception(e)
270268

271-
def _parse_langfuse_trace_attributes_from_metadata(
272-
self,
273-
metadata: Optional[Dict[str, Any]],
269+
def _parse_langfuse_trace_attributes(
270+
self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]]
274271
) -> Dict[str, Any]:
275272
attributes: Dict[str, Any] = {}
276273

274+
if metadata is None and tags is not None:
275+
return {"tags": tags}
276+
277277
if metadata is None:
278278
return attributes
279279

@@ -287,8 +287,13 @@ def _parse_langfuse_trace_attributes_from_metadata(
287287
):
288288
attributes["user_id"] = metadata["langfuse_user_id"]
289289

290-
if "langfuse_tags" in metadata and isinstance(metadata["langfuse_tags"], list):
291-
attributes["tags"] = [str(tag) for tag in metadata["langfuse_tags"]]
290+
if tags is not None or (
291+
"langfuse_tags" in metadata and isinstance(metadata["langfuse_tags"], list)
292+
):
293+
merged_tags = list(set(metadata["langfuse_tags"]) | set(tags or []))
294+
attributes["tags"] = [str(tag) for tag in set(merged_tags)]
295+
296+
attributes["metadata"] = _strip_langfuse_keys_from_dict(metadata, False)
292297

293298
return attributes
294299

@@ -321,10 +326,26 @@ def on_chain_start(
321326
serialized, "chain", **kwargs
322327
)
323328

329+
# Handle trace attribute propagation at the root of the chain
330+
if parent_run_id is None:
331+
parsed_trace_attributes = self._parse_langfuse_trace_attributes(
332+
metadata=metadata, tags=tags
333+
)
334+
335+
self._propagation_context_manager = propagate_attributes(
336+
trace_name=span_name,
337+
user_id=parsed_trace_attributes.get("user_id", None),
338+
session_id=parsed_trace_attributes.get("session_id", None),
339+
tags=parsed_trace_attributes.get("tags", None),
340+
metadata=parsed_trace_attributes.get("metadata", None),
341+
)
342+
343+
self._propagation_context_manager.__enter__()
344+
324345
obs = self._get_parent_observation(parent_run_id)
325346
if isinstance(obs, Langfuse):
326347
span = obs.start_observation(
327-
trace_context=self.trace_context,
348+
trace_context=self._trace_context,
328349
name=span_name,
329350
as_type=observation_type,
330351
metadata=span_metadata,
@@ -348,24 +369,7 @@ def on_chain_start(
348369

349370
self._attach_observation(run_id, span)
350371

351-
if parent_run_id is None:
352-
span.update_trace(
353-
**(
354-
cast(
355-
Any,
356-
{
357-
"input": inputs,
358-
"name": span_name,
359-
"metadata": span_metadata,
360-
},
361-
)
362-
if self.update_trace
363-
else {}
364-
),
365-
**self._parse_langfuse_trace_attributes_from_metadata(metadata),
366-
)
367-
368-
self.last_trace_id = self.runs[run_id].trace_id
372+
self.last_trace_id = self._runs[run_id].trace_id
369373

370374
except Exception as e:
371375
langfuse_logger.exception(e)
@@ -388,17 +392,17 @@ def _register_langfuse_prompt(
388392
langfuse_prompt = metadata and metadata.get("langfuse_prompt", None)
389393

390394
if langfuse_prompt:
391-
self.prompt_to_parent_run_map[parent_run_id] = langfuse_prompt
395+
self._prompt_to_parent_run_map[parent_run_id] = langfuse_prompt
392396

393397
# If we have a registered prompt that has not been linked to a generation yet, we need to allow _children_ of that chain to link to it.
394398
# Otherwise, we only allow generations on the same level of the prompt rendering to be linked, not if they are nested.
395-
elif parent_run_id in self.prompt_to_parent_run_map:
396-
registered_prompt = self.prompt_to_parent_run_map[parent_run_id]
397-
self.prompt_to_parent_run_map[run_id] = registered_prompt
399+
elif parent_run_id in self._prompt_to_parent_run_map:
400+
registered_prompt = self._prompt_to_parent_run_map[parent_run_id]
401+
self._prompt_to_parent_run_map[run_id] = registered_prompt
398402

399403
def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None:
400-
if run_id is not None and run_id in self.prompt_to_parent_run_map:
401-
del self.prompt_to_parent_run_map[run_id]
404+
if run_id is not None and run_id in self._prompt_to_parent_run_map:
405+
del self._prompt_to_parent_run_map[run_id]
402406

403407
def _get_parent_observation(
404408
self, parent_run_id: Optional[UUID]
@@ -411,10 +415,10 @@ def _get_parent_observation(
411415
LangfuseSpan,
412416
LangfuseTool,
413417
]:
414-
if parent_run_id and parent_run_id in self.runs:
415-
return self.runs[parent_run_id]
418+
if parent_run_id and parent_run_id in self._runs:
419+
return self._runs[parent_run_id]
416420

417-
return self.client
421+
return self._langfuse_client
418422

419423
def _attach_observation(
420424
self,
@@ -431,8 +435,8 @@ def _attach_observation(
431435
ctx = trace.set_span_in_context(observation._otel_span)
432436
token = context.attach(ctx)
433437

434-
self.runs[run_id] = observation
435-
self.context_tokens[run_id] = token
438+
self._runs[run_id] = observation
439+
self._context_tokens[run_id] = token
436440

437441
def _detach_observation(
438442
self, run_id: UUID
@@ -446,7 +450,7 @@ def _detach_observation(
446450
LangfuseTool,
447451
]
448452
]:
449-
token = self.context_tokens.pop(run_id, None)
453+
token = self._context_tokens.pop(run_id, None)
450454

451455
if token:
452456
try:
@@ -471,7 +475,7 @@ def _detach_observation(
471475
LangfuseSpan,
472476
LangfuseTool,
473477
],
474-
self.runs.pop(run_id, None),
478+
self._runs.pop(run_id, None),
475479
)
476480

477481
def on_agent_action(
@@ -490,7 +494,7 @@ def on_agent_action(
490494
"on_agent_action", run_id, parent_run_id, action=action
491495
)
492496

493-
agent_run = self.runs.get(run_id, None)
497+
agent_run = self._runs.get(run_id, None)
494498

495499
if agent_run is not None:
496500
agent_run._otel_span.set_attribute(
@@ -519,7 +523,7 @@ def on_agent_finish(
519523
)
520524
# Langchain is sending same run ID for both agent finish and chain end
521525
# handle cleanup of observation in the chain end callback
522-
agent_run = self.runs.get(run_id, None)
526+
agent_run = self._runs.get(run_id, None)
523527

524528
if agent_run is not None:
525529
agent_run._otel_span.set_attribute(
@@ -555,8 +559,11 @@ def on_chain_end(
555559
input=kwargs.get("inputs"),
556560
)
557561

558-
if parent_run_id is None and self.update_trace:
559-
span.update_trace(output=outputs, input=kwargs.get("inputs"))
562+
if (
563+
parent_run_id is None
564+
and self._propagation_context_manager is not None
565+
):
566+
self._propagation_context_manager.__exit__(None, None, None)
560567

561568
span.end()
562569

@@ -836,15 +843,11 @@ def __on_llm_action(
836843
model_name = self._parse_model_and_log_errors(
837844
serialized=serialized, metadata=metadata, kwargs=kwargs
838845
)
839-
840-
registered_prompt = None
841-
current_parent_run_id = parent_run_id
842-
843-
# Check all parents for registered prompt
844-
while current_parent_run_id is not None:
845-
registered_prompt = self.prompt_to_parent_run_map.get(
846-
current_parent_run_id
847-
)
846+
registered_prompt = (
847+
self._prompt_to_parent_run_map.get(parent_run_id)
848+
if parent_run_id is not None
849+
else None
850+
)
848851

849852
if registered_prompt:
850853
self._deregister_langfuse_prompt(current_parent_run_id)
@@ -875,7 +878,7 @@ def __on_llm_action(
875878
) # type: ignore
876879
self._attach_observation(run_id, generation)
877880

878-
self.last_trace_id = self.runs[run_id].trace_id
881+
self.last_trace_id = self._runs[run_id].trace_id
879882

880883
except Exception as e:
881884
langfuse_logger.exception(e)
@@ -982,10 +985,7 @@ def on_llm_end(
982985
langfuse_logger.exception(e)
983986

984987
finally:
985-
self.updated_completion_start_time_memo.discard(run_id)
986-
987-
if parent_run_id is None:
988-
self._reset()
988+
self._updated_completion_start_time_memo.discard(run_id)
989989

990990
def on_llm_error(
991991
self,

0 commit comments

Comments
 (0)