1111 Dict ,
1212 Generator ,
1313 Iterable ,
14+ List ,
1415 Optional ,
1516 Tuple ,
1617 TypeVar ,
@@ -468,64 +469,11 @@ def _wrap_sync_generator_result(
468469 generator : Generator ,
469470 transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
470471 ) -> Any :
471- # Capture the current context while the span is still active
472472 preserved_context = contextvars .copy_context ()
473- items : list [Any ] = []
474-
475- class ContextPreservedSyncGeneratorWrapper :
476- """Sync generator wrapper that ensures each iteration runs in preserved context."""
477-
478- def __init__ (
479- self ,
480- generator : Generator ,
481- context : contextvars .Context ,
482- items : list [Any ],
483- span : Union [
484- LangfuseSpan ,
485- LangfuseGeneration ,
486- LangfuseAgent ,
487- LangfuseTool ,
488- LangfuseChain ,
489- LangfuseRetriever ,
490- LangfuseEvaluator ,
491- LangfuseEmbedding ,
492- LangfuseGuardrail ,
493- ],
494- transform_fn : Optional [Callable [[Iterable ], str ]],
495- ) -> None :
496- self .generator = generator
497- self .context = context
498- self .items = items
499- self .span = span
500- self .transform_fn = transform_fn
501-
502- def __iter__ (self ) -> "ContextPreservedSyncGeneratorWrapper" :
503- return self
504-
505- def __next__ (self ) -> Any :
506- try :
507- # Run the generator's __next__ in the preserved context
508- item = self .context .run (next , self .generator )
509- self .items .append (item )
510- return item
511-
512- except StopIteration :
513- # Handle output and span cleanup when generator is exhausted
514- output : Any = self .items
515-
516- if self .transform_fn is not None :
517- output = self .transform_fn (self .items )
518- elif all (isinstance (item , str ) for item in self .items ):
519- output = "" .join (self .items )
520-
521- self .span .update (output = output )
522- self .span .end ()
523- raise # Re-raise StopIteration
524-
525- return ContextPreservedSyncGeneratorWrapper (
473+
474+ return _ContextPreservedSyncGeneratorWrapper (
526475 generator ,
527476 preserved_context ,
528- items ,
529477 langfuse_span_or_generation ,
530478 transform_to_string ,
531479 )
@@ -546,75 +494,11 @@ def _wrap_async_generator_result(
546494 generator : AsyncGenerator ,
547495 transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
548496 ) -> Any :
549- import asyncio
550-
551- # Capture the current context while the span is still active
552497 preserved_context = contextvars .copy_context ()
553- items : list [Any ] = []
554-
555- class ContextPreservedAsyncGeneratorWrapper :
556- """Async generator wrapper that ensures each iteration runs in preserved context."""
557-
558- def __init__ (
559- self ,
560- generator : AsyncGenerator ,
561- context : contextvars .Context ,
562- items : list [Any ],
563- span : Union [
564- LangfuseSpan ,
565- LangfuseGeneration ,
566- LangfuseAgent ,
567- LangfuseTool ,
568- LangfuseChain ,
569- LangfuseRetriever ,
570- LangfuseEvaluator ,
571- LangfuseEmbedding ,
572- LangfuseGuardrail ,
573- ],
574- transform_fn : Optional [Callable [[Iterable ], str ]],
575- ) -> None :
576- self .generator = generator
577- self .context = context
578- self .items = items
579- self .span = span
580- self .transform_fn = transform_fn
581-
582- def __aiter__ (self ) -> "ContextPreservedAsyncGeneratorWrapper" :
583- return self
584-
585- async def __anext__ (self ) -> Any :
586- try :
587- # Run the generator's __anext__ in the preserved context
588- try :
589- # Python 3.10+ approach with context parameter
590- item = await asyncio .create_task (
591- self .generator .__anext__ (), # type: ignore
592- context = self .context ,
593- ) # type: ignore
594- except TypeError :
595- # Python < 3.10 fallback - context parameter not supported
596- item = await self .generator .__anext__ ()
597-
598- self .items .append (item )
599- return item
600-
601- except StopAsyncIteration :
602- # Handle output and span cleanup when generator is exhausted
603- output : Any = self .items
604-
605- if self .transform_fn is not None :
606- output = self .transform_fn (self .items )
607- elif all (isinstance (item , str ) for item in self .items ):
608- output = "" .join (self .items )
609-
610- self .span .update (output = output )
611- self .span .end ()
612- raise # Re-raise StopAsyncIteration
613-
614- return ContextPreservedAsyncGeneratorWrapper (
498+
499+ return _ContextPreservedAsyncGeneratorWrapper (
615500 generator ,
616501 preserved_context ,
617- items ,
618502 langfuse_span_or_generation ,
619503 transform_to_string ,
620504 )
@@ -623,3 +507,125 @@ async def __anext__(self) -> Any:
623507_decorator = LangfuseDecorator ()
624508
625509observe = _decorator .observe
510+
511+
512+ class _ContextPreservedSyncGeneratorWrapper :
513+ """Sync generator wrapper that ensures each iteration runs in preserved context."""
514+
515+ def __init__ (
516+ self ,
517+ generator : Generator ,
518+ context : contextvars .Context ,
519+ span : Union [
520+ LangfuseSpan ,
521+ LangfuseGeneration ,
522+ LangfuseAgent ,
523+ LangfuseTool ,
524+ LangfuseChain ,
525+ LangfuseRetriever ,
526+ LangfuseEvaluator ,
527+ LangfuseEmbedding ,
528+ LangfuseGuardrail ,
529+ ],
530+ transform_fn : Optional [Callable [[Iterable ], str ]],
531+ ) -> None :
532+ self .generator = generator
533+ self .context = context
534+ self .items : List [Any ] = []
535+ self .span = span
536+ self .transform_fn = transform_fn
537+
538+ def __iter__ (self ) -> "_ContextPreservedSyncGeneratorWrapper" :
539+ return self
540+
541+ def __next__ (self ) -> Any :
542+ try :
543+ # Run the generator's __next__ in the preserved context
544+ item = self .context .run (next , self .generator )
545+ self .items .append (item )
546+
547+ return item
548+
549+ except StopIteration :
550+ # Handle output and span cleanup when generator is exhausted
551+ output : Any = self .items
552+
553+ if self .transform_fn is not None :
554+ output = self .transform_fn (self .items )
555+
556+ elif all (isinstance (item , str ) for item in self .items ):
557+ output = "" .join (self .items )
558+
559+ self .span .update (output = output ).end ()
560+
561+ raise # Re-raise StopIteration
562+
563+ except Exception as e :
564+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
565+
566+ raise e
567+
568+
569+ class _ContextPreservedAsyncGeneratorWrapper :
570+ """Async generator wrapper that ensures each iteration runs in preserved context."""
571+
572+ def __init__ (
573+ self ,
574+ generator : AsyncGenerator ,
575+ context : contextvars .Context ,
576+ span : Union [
577+ LangfuseSpan ,
578+ LangfuseGeneration ,
579+ LangfuseAgent ,
580+ LangfuseTool ,
581+ LangfuseChain ,
582+ LangfuseRetriever ,
583+ LangfuseEvaluator ,
584+ LangfuseEmbedding ,
585+ LangfuseGuardrail ,
586+ ],
587+ transform_fn : Optional [Callable [[Iterable ], str ]],
588+ ) -> None :
589+ self .generator = generator
590+ self .context = context
591+ self .items : List [Any ] = []
592+ self .span = span
593+ self .transform_fn = transform_fn
594+
595+ def __aiter__ (self ) -> "_ContextPreservedAsyncGeneratorWrapper" :
596+ return self
597+
598+ async def __anext__ (self ) -> Any :
599+ try :
600+ # Run the generator's __anext__ in the preserved context
601+ try :
602+ # Python 3.10+ approach with context parameter
603+ item = await asyncio .create_task (
604+ self .generator .__anext__ (), # type: ignore
605+ context = self .context ,
606+ ) # type: ignore
607+ except TypeError :
608+ # Python < 3.10 fallback - context parameter not supported
609+ item = await self .generator .__anext__ ()
610+
611+ self .items .append (item )
612+
613+ return item
614+
615+ except StopAsyncIteration :
616+ # Handle output and span cleanup when generator is exhausted
617+ output : Any = self .items
618+
619+ if self .transform_fn is not None :
620+ output = self .transform_fn (self .items )
621+
622+ elif all (isinstance (item , str ) for item in self .items ):
623+ output = "" .join (self .items )
624+
625+ self .span .update (output = output ).end ()
626+
627+ raise # Re-raise StopAsyncIteration
628+ except Exception as e :
629+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
630+
631+ raise e
0 commit comments