1616import pydantic
1717from opentelemetry import context , trace
1818from opentelemetry .context import _RUNTIME_CONTEXT
19+ from opentelemetry .util ._decorator import _AgnosticContextManager
1920
21+ from langfuse import propagate_attributes
2022from langfuse ._client .attributes import LangfuseOtelSpanAttributes
2123from langfuse ._client .client import Langfuse
2224from langfuse ._client .get_client import get_client
3234from langfuse .langchain .utils import _extract_model_name
3335from langfuse .logger import langfuse_logger
3436from langfuse .types import TraceContext
37+ from langfuse .types import TraceContext
3538
3639try :
3740 import langchain
@@ -96,14 +99,12 @@ def __init__(
9699 self ,
97100 * ,
98101 public_key : Optional [str ] = None ,
99- update_trace : bool = False ,
100102 trace_context : Optional [TraceContext ] = None ,
101103 ) -> None :
102104 """Initialize the LangchainCallbackHandler.
103105
104106 Args:
105107 public_key: Optional Langfuse public key. If not provided, will use the default client configuration.
106- update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False.
107108 trace_context: Optional context for connecting to an existing trace (distributed tracing) or
108109 setting a custom trace id for the root LangChain run. Pass a `TraceContext` dict, e.g.
109110 `{"trace_id": "<trace_id>"}` (and optionally `{"parent_span_id": "<span_id>"}`) to link
@@ -118,10 +119,8 @@ def __init__(
118119 handler = CallbackHandler(trace_context={"trace_id": "my-trace-id"})
119120 ```
120121 """
121- self .client = get_client (public_key = public_key )
122- self .run_inline = True
123-
124- self .runs : Dict [
122+ self ._langfuse_client = get_client (public_key = public_key )
123+ self ._runs : Dict [
125124 UUID ,
126125 Union [
127126 LangfuseSpan ,
@@ -132,14 +131,13 @@ def __init__(
132131 LangfuseRetriever ,
133132 ],
134133 ] = {}
135- self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = {}
136- self .context_tokens : Dict [UUID , Token ] = {}
137- self .prompt_to_parent_run_map : Dict [UUID , Any ] = {}
138- self .updated_completion_start_time_memo : Set [UUID ] = set ()
134+ self ._context_tokens : Dict [UUID , Token ] = {}
135+ self ._prompt_to_parent_run_map : Dict [UUID , Any ] = {}
136+ self ._updated_completion_start_time_memo : Set [UUID ] = set ()
137+ self ._propagation_context_manager : Optional [_AgnosticContextManager ] = None
138+ self ._trace_context = trace_context
139139
140140 self .last_trace_id : Optional [str ] = None
141- self .update_trace = update_trace
142- self .trace_context = trace_context
143141
144142 def on_llm_new_token (
145143 self ,
@@ -154,14 +152,14 @@ def on_llm_new_token(
154152 f"on llm new token: run_id: { run_id } parent_run_id: { parent_run_id } "
155153 )
156154 if (
157- run_id in self .runs
158- and isinstance (self .runs [run_id ], LangfuseGeneration )
159- and run_id not in self .updated_completion_start_time_memo
155+ run_id in self ._runs
156+ and isinstance (self ._runs [run_id ], LangfuseGeneration )
157+ and run_id not in self ._updated_completion_start_time_memo
160158 ):
161- current_generation = cast (LangfuseGeneration , self .runs [run_id ])
159+ current_generation = cast (LangfuseGeneration , self ._runs [run_id ])
162160 current_generation .update (completion_start_time = _get_timestamp ())
163161
164- self .updated_completion_start_time_memo .add (run_id )
162+ self ._updated_completion_start_time_memo .add (run_id )
165163
166164 def _get_observation_type_from_serialized (
167165 self , serialized : Optional [Dict [str , Any ]], callback_type : str , ** kwargs : Any
@@ -268,12 +266,14 @@ def on_retriever_error(
268266 except Exception as e :
269267 langfuse_logger .exception (e )
270268
271- def _parse_langfuse_trace_attributes_from_metadata (
272- self ,
273- metadata : Optional [Dict [str , Any ]],
269+ def _parse_langfuse_trace_attributes (
270+ self , * , metadata : Optional [Dict [str , Any ]], tags : Optional [List [str ]]
274271 ) -> Dict [str , Any ]:
275272 attributes : Dict [str , Any ] = {}
276273
274+ if metadata is None and tags is not None :
275+ return {"tags" : tags }
276+
277277 if metadata is None :
278278 return attributes
279279
@@ -287,8 +287,13 @@ def _parse_langfuse_trace_attributes_from_metadata(
287287 ):
288288 attributes ["user_id" ] = metadata ["langfuse_user_id" ]
289289
290- if "langfuse_tags" in metadata and isinstance (metadata ["langfuse_tags" ], list ):
291- attributes ["tags" ] = [str (tag ) for tag in metadata ["langfuse_tags" ]]
# Merge the callback's `tags` with the optional `langfuse_tags` metadata entry.
# `langfuse_tags` is only honoured when it is present *and* a list; indexing it
# unconditionally raised KeyError whenever only `tags` was supplied.
metadata_tags = metadata.get("langfuse_tags")
has_metadata_tags = isinstance(metadata_tags, list)

if tags is not None or has_metadata_tags:
    combined_tags = [*(metadata_tags if has_metadata_tags else []), *(tags or [])]
    # Stringify before de-duplicating so e.g. 1 and "1" collapse into one tag;
    # dict.fromkeys keeps first-seen order, making the result deterministic.
    attributes["tags"] = list(dict.fromkeys(str(tag) for tag in combined_tags))
295+
296+ attributes ["metadata" ] = _strip_langfuse_keys_from_dict (metadata , False )
292297
293298 return attributes
294299
@@ -321,10 +326,26 @@ def on_chain_start(
321326 serialized , "chain" , ** kwargs
322327 )
323328
329+ # Handle trace attribute propagation at the root of the chain
330+ if parent_run_id is None :
331+ parsed_trace_attributes = self ._parse_langfuse_trace_attributes (
332+ metadata = metadata , tags = tags
333+ )
334+
335+ self ._propagation_context_manager = propagate_attributes (
336+ trace_name = span_name ,
337+ user_id = parsed_trace_attributes .get ("user_id" , None ),
338+ session_id = parsed_trace_attributes .get ("session_id" , None ),
339+ tags = parsed_trace_attributes .get ("tags" , None ),
340+ metadata = parsed_trace_attributes .get ("metadata" , None ),
341+ )
342+
343+ self ._propagation_context_manager .__enter__ ()
344+
324345 obs = self ._get_parent_observation (parent_run_id )
325346 if isinstance (obs , Langfuse ):
326347 span = obs .start_observation (
327- trace_context = self .trace_context ,
348+ trace_context = self ._trace_context ,
328349 name = span_name ,
329350 as_type = observation_type ,
330351 metadata = span_metadata ,
@@ -348,24 +369,7 @@ def on_chain_start(
348369
349370 self ._attach_observation (run_id , span )
350371
351- if parent_run_id is None :
352- span .update_trace (
353- ** (
354- cast (
355- Any ,
356- {
357- "input" : inputs ,
358- "name" : span_name ,
359- "metadata" : span_metadata ,
360- },
361- )
362- if self .update_trace
363- else {}
364- ),
365- ** self ._parse_langfuse_trace_attributes_from_metadata (metadata ),
366- )
367-
368- self .last_trace_id = self .runs [run_id ].trace_id
372+ self .last_trace_id = self ._runs [run_id ].trace_id
369373
370374 except Exception as e :
371375 langfuse_logger .exception (e )
@@ -388,17 +392,17 @@ def _register_langfuse_prompt(
388392 langfuse_prompt = metadata and metadata .get ("langfuse_prompt" , None )
389393
390394 if langfuse_prompt :
391- self .prompt_to_parent_run_map [parent_run_id ] = langfuse_prompt
395+ self ._prompt_to_parent_run_map [parent_run_id ] = langfuse_prompt
392396
393397 # If we have a registered prompt that has not been linked to a generation yet, we need to allow _children_ of that chain to link to it.
394398 # Otherwise, we only allow generations on the same level of the prompt rendering to be linked, not if they are nested.
395- elif parent_run_id in self .prompt_to_parent_run_map :
396- registered_prompt = self .prompt_to_parent_run_map [parent_run_id ]
397- self .prompt_to_parent_run_map [run_id ] = registered_prompt
399+ elif parent_run_id in self ._prompt_to_parent_run_map :
400+ registered_prompt = self ._prompt_to_parent_run_map [parent_run_id ]
401+ self ._prompt_to_parent_run_map [run_id ] = registered_prompt
398402
399403 def _deregister_langfuse_prompt (self , run_id : Optional [UUID ]) -> None :
400- if run_id is not None and run_id in self .prompt_to_parent_run_map :
401- del self .prompt_to_parent_run_map [run_id ]
404+ if run_id is not None and run_id in self ._prompt_to_parent_run_map :
405+ del self ._prompt_to_parent_run_map [run_id ]
402406
403407 def _get_parent_observation (
404408 self , parent_run_id : Optional [UUID ]
@@ -411,10 +415,10 @@ def _get_parent_observation(
411415 LangfuseSpan ,
412416 LangfuseTool ,
413417 ]:
414- if parent_run_id and parent_run_id in self .runs :
415- return self .runs [parent_run_id ]
418+ if parent_run_id and parent_run_id in self ._runs :
419+ return self ._runs [parent_run_id ]
416420
417- return self .client
421+ return self ._langfuse_client
418422
419423 def _attach_observation (
420424 self ,
@@ -431,8 +435,8 @@ def _attach_observation(
431435 ctx = trace .set_span_in_context (observation ._otel_span )
432436 token = context .attach (ctx )
433437
434- self .runs [run_id ] = observation
435- self .context_tokens [run_id ] = token
438+ self ._runs [run_id ] = observation
439+ self ._context_tokens [run_id ] = token
436440
437441 def _detach_observation (
438442 self , run_id : UUID
@@ -446,7 +450,7 @@ def _detach_observation(
446450 LangfuseTool ,
447451 ]
448452 ]:
449- token = self .context_tokens .pop (run_id , None )
453+ token = self ._context_tokens .pop (run_id , None )
450454
451455 if token :
452456 try :
@@ -471,7 +475,7 @@ def _detach_observation(
471475 LangfuseSpan ,
472476 LangfuseTool ,
473477 ],
474- self .runs .pop (run_id , None ),
478+ self ._runs .pop (run_id , None ),
475479 )
476480
477481 def on_agent_action (
@@ -490,7 +494,7 @@ def on_agent_action(
490494 "on_agent_action" , run_id , parent_run_id , action = action
491495 )
492496
493- agent_run = self .runs .get (run_id , None )
497+ agent_run = self ._runs .get (run_id , None )
494498
495499 if agent_run is not None :
496500 agent_run ._otel_span .set_attribute (
@@ -519,7 +523,7 @@ def on_agent_finish(
519523 )
520524 # Langchain is sending same run ID for both agent finish and chain end
521525 # handle cleanup of observation in the chain end callback
522- agent_run = self .runs .get (run_id , None )
526+ agent_run = self ._runs .get (run_id , None )
523527
524528 if agent_run is not None :
525529 agent_run ._otel_span .set_attribute (
@@ -555,8 +559,11 @@ def on_chain_end(
555559 input = kwargs .get ("inputs" ),
556560 )
557561
558- if parent_run_id is None and self .update_trace :
559- span .update_trace (output = outputs , input = kwargs .get ("inputs" ))
if parent_run_id is None and self._propagation_context_manager is not None:
    # The root chain is done: leave the attribute-propagation context that
    # on_chain_start entered. Clear the handle first so the handler never
    # keeps (or exits again) a finished context manager, even if __exit__ raises.
    propagation_context_manager = self._propagation_context_manager
    self._propagation_context_manager = None
    propagation_context_manager.__exit__(None, None, None)
560567
561568 span .end ()
562569
@@ -836,15 +843,11 @@ def __on_llm_action(
836843 model_name = self ._parse_model_and_log_errors (
837844 serialized = serialized , metadata = metadata , kwargs = kwargs
838845 )
839-
840- registered_prompt = None
841- current_parent_run_id = parent_run_id
842-
843- # Check all parents for registered prompt
844- while current_parent_run_id is not None :
845- registered_prompt = self .prompt_to_parent_run_map .get (
846- current_parent_run_id
847- )
# Only a prompt registered for the direct parent run is linked to this
# generation; a root-level LLM call has no parent to look up.
registered_prompt = (
    self._prompt_to_parent_run_map.get(parent_run_id)
    if parent_run_id is not None
    else None
)

if registered_prompt:
    # Deregister under the same key the prompt was looked up with. The former
    # `current_parent_run_id` loop variable no longer exists, so referencing it
    # raised NameError here.
    self._deregister_langfuse_prompt(parent_run_id)
@@ -875,7 +878,7 @@ def __on_llm_action(
875878 ) # type: ignore
876879 self ._attach_observation (run_id , generation )
877880
878- self .last_trace_id = self .runs [run_id ].trace_id
881+ self .last_trace_id = self ._runs [run_id ].trace_id
879882
880883 except Exception as e :
881884 langfuse_logger .exception (e )
@@ -982,10 +985,7 @@ def on_llm_end(
982985 langfuse_logger .exception (e )
983986
984987 finally :
985- self .updated_completion_start_time_memo .discard (run_id )
986-
987- if parent_run_id is None :
988- self ._reset ()
988+ self ._updated_completion_start_time_memo .discard (run_id )
989989
990990 def on_llm_error (
991991 self ,
0 commit comments