6464
6565
6666class LangchainCallbackHandler (LangchainBaseCallbackHandler ):
67- def __init__ (self , * , public_key : Optional [str ] = None ) -> None :
67+ def __init__ (
68+ self , * , public_key : Optional [str ] = None , update_trace : bool = False
69+ ) -> None :
70+ """Initialize the LangchainCallbackHandler.
71+
72+ Args:
73+ public_key: Optional Langfuse public key. If not provided, will use the default client configuration.
74+ update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False.
75+ """
6876 self .client = get_client (public_key = public_key )
6977
7078 self .runs : Dict [
@@ -82,6 +90,7 @@ def __init__(self, *, public_key: Optional[str] = None) -> None:
8290 self .updated_completion_start_time_memo : Set [UUID ] = set ()
8391
8492 self .last_trace_id : Optional [str ] = None
93+ self .update_trace = update_trace
8594
8695 def on_llm_new_token (
8796 self ,
@@ -273,7 +282,19 @@ def on_chain_start(
273282 ),
274283 )
275284 span .update_trace (
276- ** self ._parse_langfuse_trace_attributes_from_metadata (metadata )
285+ ** (
286+ cast (
287+ Any ,
288+ {
289+ "input" : inputs ,
290+ "name" : span_name ,
291+ "metadata" : span_metadata ,
292+ },
293+ )
294+ if self .update_trace
295+ else {}
296+ ),
297+ ** self ._parse_langfuse_trace_attributes_from_metadata (metadata ),
277298 )
278299 self .runs [run_id ] = span
279300 else :
@@ -391,14 +412,21 @@ def on_chain_end(
391412 if run_id not in self .runs :
392413 raise Exception ("run not found" )
393414
394- self .runs [run_id ].update (
415+ span = self .runs [run_id ]
416+ span .update (
395417 output = outputs ,
396418 input = kwargs .get ("inputs" ),
397- ).end ()
419+ )
420+
421+ if parent_run_id is None and self .update_trace :
422+ span .update_trace (output = outputs , input = kwargs .get ("inputs" ))
423+
424+ span .end ()
398425
399426 del self .runs [run_id ]
400427
401428 self ._deregister_langfuse_prompt (run_id )
429+
402430 except Exception as e :
403431 langfuse_logger .exception (e )
404432
@@ -968,22 +996,41 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]) -> Any:
968996 usage_model = cast (Dict , usage .copy ()) # Copy all existing key-value pairs
969997
970998 # Skip OpenAI usage types as they are handled server side
971- if not all (
972- openai_key in usage_model
973- for openai_key in ["prompt_tokens" , "completion_tokens" , "total_tokens" ]
999+ if (
1000+ all (
1001+ openai_key in usage_model
1002+ for openai_key in [
1003+ "prompt_tokens" ,
1004+ "completion_tokens" ,
1005+ "total_tokens" ,
1006+ "prompt_tokens_details" ,
1007+ "completion_tokens_details" ,
1008+ ]
1009+ )
1010+ and len (usage_model .keys ()) == 5
1011+ ) or (
1012+ all (
1013+ openai_key in usage_model
1014+ for openai_key in [
1015+ "prompt_tokens" ,
1016+ "completion_tokens" ,
1017+ "total_tokens" ,
1018+ ]
1019+ )
1020+ and len (usage_model .keys ()) == 3
9741021 ):
975- for model_key , langfuse_key in conversion_list :
976- if model_key in usage_model :
977- captured_count = usage_model . pop ( model_key )
978- final_count = (
979- sum ( captured_count )
980- if isinstance ( captured_count , list )
981- else captured_count
982- ) # For Bedrock, the token count is a list when streamed
983-
984- usage_model [ langfuse_key ] = (
985- final_count # Translate key and keep the value
986- )
1022+ return usage_model
1023+
1024+ for model_key , langfuse_key in conversion_list :
1025+ if model_key in usage_model :
1026+ captured_count = usage_model . pop ( model_key )
1027+ final_count = (
1028+ sum ( captured_count )
1029+ if isinstance ( captured_count , list )
1030+ else captured_count
1031+ ) # For Bedrock, the token count is a list when streamed
1032+
1033+ usage_model [ langfuse_key ] = final_count # Translate key and keep the value
9871034
9881035 if isinstance (usage_model , dict ):
9891036 if "input_token_details" in usage_model :
@@ -1058,7 +1105,7 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]) -> Any:
10581105 if "input" in usage_model :
10591106 usage_model ["input" ] = max (0 , usage_model ["input" ] - value )
10601107
1061- usage_model = {k : v for k , v in usage_model .items () if not isinstance (v , str )}
1108+ usage_model = {k : v for k , v in usage_model .items () if isinstance (v , int )}
10621109
10631110 return usage_model if usage_model else None
10641111
0 commit comments