2828 LangfuseSpan ,
2929 LangfuseTool ,
3030)
31- from langfuse .types import TraceContext
3231from langfuse ._utils import _get_timestamp
3332from langfuse .langchain .utils import _extract_model_name
3433from langfuse .logger import langfuse_logger
34+ from langfuse .types import TraceContext
3535
3636try :
3737 import langchain
@@ -132,6 +132,7 @@ def __init__(
132132 LangfuseRetriever ,
133133 ],
134134 ] = {}
135+ self ._child_to_parent_run_id_map : Dict [UUID , Optional [UUID ]] = {}
135136 self .context_tokens : Dict [UUID , Token ] = {}
136137 self .prompt_to_parent_run_map : Dict [UUID , Any ] = {}
137138 self .updated_completion_start_time_memo : Set [UUID ] = set ()
@@ -302,6 +303,8 @@ def on_chain_start(
302303 metadata : Optional [Dict [str , Any ]] = None ,
303304 ** kwargs : Any ,
304305 ) -> Any :
306+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
307+
305308 try :
306309 self ._log_debug_event (
307310 "on_chain_start" , run_id , parent_run_id , inputs = inputs
@@ -480,6 +483,8 @@ def on_agent_action(
480483 ** kwargs : Any ,
481484 ) -> Any :
482485 """Run on agent action."""
486+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
487+
483488 try :
484489 self ._log_debug_event (
485490 "on_agent_action" , run_id , parent_run_id , action = action
@@ -560,6 +565,10 @@ def on_chain_end(
560565 except Exception as e :
561566 langfuse_logger .exception (e )
562567
568+ finally :
569+ if parent_run_id is None :
570+ self ._reset ()
571+
563572 def on_chain_error (
564573 self ,
565574 error : BaseException ,
@@ -603,6 +612,8 @@ def on_chat_model_start(
603612 metadata : Optional [Dict [str , Any ]] = None ,
604613 ** kwargs : Any ,
605614 ) -> Any :
615+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
616+
606617 try :
607618 self ._log_debug_event (
608619 "on_chat_model_start" , run_id , parent_run_id , messages = messages
@@ -635,6 +646,8 @@ def on_llm_start(
635646 metadata : Optional [Dict [str , Any ]] = None ,
636647 ** kwargs : Any ,
637648 ) -> Any :
649+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
650+
638651 try :
639652 self ._log_debug_event (
640653 "on_llm_start" , run_id , parent_run_id , prompts = prompts
@@ -662,6 +675,8 @@ def on_tool_start(
662675 metadata : Optional [Dict [str , Any ]] = None ,
663676 ** kwargs : Any ,
664677 ) -> Any :
678+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
679+
665680 try :
666681 self ._log_debug_event (
667682 "on_tool_start" , run_id , parent_run_id , input_str = input_str
@@ -704,6 +719,8 @@ def on_retriever_start(
704719 metadata : Optional [Dict [str , Any ]] = None ,
705720 ** kwargs : Any ,
706721 ) -> Any :
722+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
723+
707724 try :
708725 self ._log_debug_event (
709726 "on_retriever_start" , run_id , parent_run_id , query = query
@@ -809,6 +826,8 @@ def __on_llm_action(
809826 metadata : Optional [Dict [str , Any ]] = None ,
810827 ** kwargs : Any ,
811828 ) -> None :
829+ self ._child_to_parent_run_id_map [run_id ] = parent_run_id
830+
812831 try :
813832 tools = kwargs .get ("invocation_params" , {}).get ("tools" , None )
814833 if tools and isinstance (tools , list ):
@@ -817,14 +836,23 @@ def __on_llm_action(
817836 model_name = self ._parse_model_and_log_errors (
818837 serialized = serialized , metadata = metadata , kwargs = kwargs
819838 )
820- registered_prompt = (
821- self .prompt_to_parent_run_map .get (parent_run_id )
822- if parent_run_id is not None
823- else None
824- )
825839
826- if registered_prompt :
827- self ._deregister_langfuse_prompt (parent_run_id )
840+ registered_prompt = None
841+ current_parent_run_id = parent_run_id
842+
843+ # Check all parents for registered prompt
844+ while current_parent_run_id is not None :
845+ registered_prompt = self .prompt_to_parent_run_map .get (
846+ current_parent_run_id
847+ )
848+
849+ if registered_prompt :
850+ self ._deregister_langfuse_prompt (current_parent_run_id )
851+ break
852+ else :
853+ current_parent_run_id = self ._child_to_parent_run_id_map .get (
854+ current_parent_run_id , None
855+ )
828856
829857 content = {
830858 "name" : self .get_langchain_run_name (serialized , ** kwargs ),
@@ -956,6 +984,9 @@ def on_llm_end(
956984 finally :
957985 self .updated_completion_start_time_memo .discard (run_id )
958986
987+ if parent_run_id is None :
988+ self ._reset ()
989+
959990 def on_llm_error (
960991 self ,
961992 error : BaseException ,
@@ -980,6 +1011,9 @@ def on_llm_error(
9801011 except Exception as e :
9811012 langfuse_logger .exception (e )
9821013
1014+ def _reset (self ) -> None :
1015+ self ._child_to_parent_run_id_map = {}
1016+
9831017 def __join_tags_and_metadata (
9841018 self ,
9851019 tags : Optional [List [str ]] = None ,
@@ -1047,7 +1081,7 @@ def _log_debug_event(
10471081 ** kwargs : Any ,
10481082 ) -> None :
10491083 langfuse_logger .debug (
1050- f"Event: { event_name } , run_id: { str ( run_id )[: 5 ] } , parent_run_id: { str ( parent_run_id )[: 5 ] } "
1084+ f"Event: { event_name } , run_id: { run_id } , parent_run_id: { parent_run_id } "
10511085 )
10521086
10531087
@@ -1210,7 +1244,9 @@ def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
12101244 usage_model ["input" ] = max (0 , usage_model ["input" ] - value )
12111245
12121246 if f"input_modality_{ item ['modality' ]} " in usage_model :
1213- usage_model [f"input_modality_{ item ['modality' ]} " ] = max (0 , usage_model [f"input_modality_{ item ['modality' ]} " ] - value )
1247+ usage_model [f"input_modality_{ item ['modality' ]} " ] = max (
1248+ 0 , usage_model [f"input_modality_{ item ['modality' ]} " ] - value
1249+ )
12141250
12151251 usage_model = {k : v for k , v in usage_model .items () if isinstance (v , int )}
12161252
0 commit comments