Skip to content

Commit 678a420

Browse files
committed
fix tests
1 parent 5841a59 commit 678a420

2 files changed

Lines changed: 129 additions & 138 deletions

File tree

langfuse/_client/observe.py

Lines changed: 127 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Dict,
1212
Generator,
1313
Iterable,
14+
List,
1415
Optional,
1516
Tuple,
1617
TypeVar,
@@ -468,64 +469,11 @@ def _wrap_sync_generator_result(
468469
generator: Generator,
469470
transform_to_string: Optional[Callable[[Iterable], str]] = None,
470471
) -> Any:
471-
# Capture the current context while the span is still active
472472
preserved_context = contextvars.copy_context()
473-
items: list[Any] = []
474-
475-
class ContextPreservedSyncGeneratorWrapper:
476-
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
477-
478-
def __init__(
479-
self,
480-
generator: Generator,
481-
context: contextvars.Context,
482-
items: list[Any],
483-
span: Union[
484-
LangfuseSpan,
485-
LangfuseGeneration,
486-
LangfuseAgent,
487-
LangfuseTool,
488-
LangfuseChain,
489-
LangfuseRetriever,
490-
LangfuseEvaluator,
491-
LangfuseEmbedding,
492-
LangfuseGuardrail,
493-
],
494-
transform_fn: Optional[Callable[[Iterable], str]],
495-
) -> None:
496-
self.generator = generator
497-
self.context = context
498-
self.items = items
499-
self.span = span
500-
self.transform_fn = transform_fn
501-
502-
def __iter__(self) -> "ContextPreservedSyncGeneratorWrapper":
503-
return self
504-
505-
def __next__(self) -> Any:
506-
try:
507-
# Run the generator's __next__ in the preserved context
508-
item = self.context.run(next, self.generator)
509-
self.items.append(item)
510-
return item
511-
512-
except StopIteration:
513-
# Handle output and span cleanup when generator is exhausted
514-
output: Any = self.items
515-
516-
if self.transform_fn is not None:
517-
output = self.transform_fn(self.items)
518-
elif all(isinstance(item, str) for item in self.items):
519-
output = "".join(self.items)
520-
521-
self.span.update(output=output)
522-
self.span.end()
523-
raise # Re-raise StopIteration
524-
525-
return ContextPreservedSyncGeneratorWrapper(
473+
474+
return _ContextPreservedSyncGeneratorWrapper(
526475
generator,
527476
preserved_context,
528-
items,
529477
langfuse_span_or_generation,
530478
transform_to_string,
531479
)
@@ -546,75 +494,11 @@ def _wrap_async_generator_result(
546494
generator: AsyncGenerator,
547495
transform_to_string: Optional[Callable[[Iterable], str]] = None,
548496
) -> Any:
549-
import asyncio
550-
551-
# Capture the current context while the span is still active
552497
preserved_context = contextvars.copy_context()
553-
items: list[Any] = []
554-
555-
class ContextPreservedAsyncGeneratorWrapper:
556-
"""Async generator wrapper that ensures each iteration runs in preserved context."""
557-
558-
def __init__(
559-
self,
560-
generator: AsyncGenerator,
561-
context: contextvars.Context,
562-
items: list[Any],
563-
span: Union[
564-
LangfuseSpan,
565-
LangfuseGeneration,
566-
LangfuseAgent,
567-
LangfuseTool,
568-
LangfuseChain,
569-
LangfuseRetriever,
570-
LangfuseEvaluator,
571-
LangfuseEmbedding,
572-
LangfuseGuardrail,
573-
],
574-
transform_fn: Optional[Callable[[Iterable], str]],
575-
) -> None:
576-
self.generator = generator
577-
self.context = context
578-
self.items = items
579-
self.span = span
580-
self.transform_fn = transform_fn
581-
582-
def __aiter__(self) -> "ContextPreservedAsyncGeneratorWrapper":
583-
return self
584-
585-
async def __anext__(self) -> Any:
586-
try:
587-
# Run the generator's __anext__ in the preserved context
588-
try:
589-
# Python 3.10+ approach with context parameter
590-
item = await asyncio.create_task(
591-
self.generator.__anext__(), # type: ignore
592-
context=self.context,
593-
) # type: ignore
594-
except TypeError:
595-
# Python < 3.10 fallback - context parameter not supported
596-
item = await self.generator.__anext__()
597-
598-
self.items.append(item)
599-
return item
600-
601-
except StopAsyncIteration:
602-
# Handle output and span cleanup when generator is exhausted
603-
output: Any = self.items
604-
605-
if self.transform_fn is not None:
606-
output = self.transform_fn(self.items)
607-
elif all(isinstance(item, str) for item in self.items):
608-
output = "".join(self.items)
609-
610-
self.span.update(output=output)
611-
self.span.end()
612-
raise # Re-raise StopAsyncIteration
613-
614-
return ContextPreservedAsyncGeneratorWrapper(
498+
499+
return _ContextPreservedAsyncGeneratorWrapper(
615500
generator,
616501
preserved_context,
617-
items,
618502
langfuse_span_or_generation,
619503
transform_to_string,
620504
)
@@ -623,3 +507,125 @@ async def __anext__(self) -> Any:
623507
_decorator = LangfuseDecorator()
624508

625509
observe = _decorator.observe
510+
511+
512+
class _ContextPreservedSyncGeneratorWrapper:
513+
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
514+
515+
def __init__(
516+
self,
517+
generator: Generator,
518+
context: contextvars.Context,
519+
span: Union[
520+
LangfuseSpan,
521+
LangfuseGeneration,
522+
LangfuseAgent,
523+
LangfuseTool,
524+
LangfuseChain,
525+
LangfuseRetriever,
526+
LangfuseEvaluator,
527+
LangfuseEmbedding,
528+
LangfuseGuardrail,
529+
],
530+
transform_fn: Optional[Callable[[Iterable], str]],
531+
) -> None:
532+
self.generator = generator
533+
self.context = context
534+
self.items: List[Any] = []
535+
self.span = span
536+
self.transform_fn = transform_fn
537+
538+
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
539+
return self
540+
541+
def __next__(self) -> Any:
542+
try:
543+
# Run the generator's __next__ in the preserved context
544+
item = self.context.run(next, self.generator)
545+
self.items.append(item)
546+
547+
return item
548+
549+
except StopIteration:
550+
# Handle output and span cleanup when generator is exhausted
551+
output: Any = self.items
552+
553+
if self.transform_fn is not None:
554+
output = self.transform_fn(self.items)
555+
556+
elif all(isinstance(item, str) for item in self.items):
557+
output = "".join(self.items)
558+
559+
self.span.update(output=output).end()
560+
561+
raise # Re-raise StopIteration
562+
563+
except Exception as e:
564+
self.span.update(level="ERROR", status_message=str(e)).end()
565+
566+
raise e
567+
568+
569+
class _ContextPreservedAsyncGeneratorWrapper:
570+
"""Async generator wrapper that ensures each iteration runs in preserved context."""
571+
572+
def __init__(
573+
self,
574+
generator: AsyncGenerator,
575+
context: contextvars.Context,
576+
span: Union[
577+
LangfuseSpan,
578+
LangfuseGeneration,
579+
LangfuseAgent,
580+
LangfuseTool,
581+
LangfuseChain,
582+
LangfuseRetriever,
583+
LangfuseEvaluator,
584+
LangfuseEmbedding,
585+
LangfuseGuardrail,
586+
],
587+
transform_fn: Optional[Callable[[Iterable], str]],
588+
) -> None:
589+
self.generator = generator
590+
self.context = context
591+
self.items: List[Any] = []
592+
self.span = span
593+
self.transform_fn = transform_fn
594+
595+
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
596+
return self
597+
598+
async def __anext__(self) -> Any:
599+
try:
600+
# Run the generator's __anext__ in the preserved context
601+
try:
602+
# Python 3.10+ approach with context parameter
603+
item = await asyncio.create_task(
604+
self.generator.__anext__(), # type: ignore
605+
context=self.context,
606+
) # type: ignore
607+
except TypeError:
608+
# Python < 3.10 fallback - context parameter not supported
609+
item = await self.generator.__anext__()
610+
611+
self.items.append(item)
612+
613+
return item
614+
615+
except StopAsyncIteration:
616+
# Handle output and span cleanup when generator is exhausted
617+
output: Any = self.items
618+
619+
if self.transform_fn is not None:
620+
output = self.transform_fn(self.items)
621+
622+
elif all(isinstance(item, str) for item in self.items):
623+
output = "".join(self.items)
624+
625+
self.span.update(output=output).end()
626+
627+
raise # Re-raise StopAsyncIteration
628+
except Exception as e:
629+
self.span.update(level="ERROR", status_message=str(e)).end()
630+
631+
raise e

tests/test_decorators.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
from langchain.prompts import ChatPromptTemplate
1010
from langchain_openai import ChatOpenAI
11+
from opentelemetry import trace
1112

1213
from langfuse import Langfuse, get_client, observe
1314
from langfuse._client.environment_variables import LANGFUSE_PUBLIC_KEY
@@ -1698,8 +1699,6 @@ def test_sync_generator_context_preservation():
16981699

16991700
@observe(name="sync_generator")
17001701
def create_generator():
1701-
from opentelemetry import trace
1702-
17031702
current_span = trace.get_current_span()
17041703
span_info["generator_span_id"] = trace.format_span_id(
17051704
current_span.get_span_context().span_id
@@ -1710,8 +1709,6 @@ def create_generator():
17101709

17111710
@observe(name="root")
17121711
def root_function():
1713-
from opentelemetry import trace
1714-
17151712
current_span = trace.get_current_span()
17161713
span_info["root_span_id"] = trace.format_span_id(
17171714
current_span.get_span_context().span_id
@@ -1764,8 +1761,6 @@ async def test_async_generator_context_preservation():
17641761

17651762
@observe(name="async_generator")
17661763
async def create_async_generator():
1767-
from opentelemetry import trace
1768-
17691764
current_span = trace.get_current_span()
17701765
span_info["generator_span_id"] = trace.format_span_id(
17711766
current_span.get_span_context().span_id
@@ -1777,8 +1772,6 @@ async def create_async_generator():
17771772

17781773
@observe(name="root")
17791774
async def root_function():
1780-
from opentelemetry import trace
1781-
17821775
current_span = trace.get_current_span()
17831776
span_info["root_span_id"] = trace.format_span_id(
17841777
current_span.get_span_context().span_id
@@ -1833,8 +1826,6 @@ async def test_async_generator_context_preservation_with_trace_hierarchy():
18331826

18341827
@observe(name="child_stream")
18351828
async def child_generator():
1836-
from opentelemetry import trace
1837-
18381829
current_span = trace.get_current_span()
18391830
span_context = current_span.get_span_context()
18401831
span_info["child_span_id"] = trace.format_span_id(span_context.span_id)
@@ -1846,8 +1837,6 @@ async def child_generator():
18461837

18471838
@observe(name="parent_root")
18481839
async def parent_function():
1849-
from opentelemetry import trace
1850-
18511840
current_span = trace.get_current_span()
18521841
span_context = current_span.get_span_context()
18531842
span_info["parent_span_id"] = trace.format_span_id(span_context.span_id)
@@ -1896,8 +1885,6 @@ async def test_async_generator_exception_handling_with_context():
18961885

18971886
@observe(name="failing_generator")
18981887
async def failing_generator():
1899-
from opentelemetry import trace
1900-
19011888
current_span = trace.get_current_span()
19021889
# Verify we have valid context even when exception occurs
19031890
assert (
@@ -1946,8 +1933,6 @@ def test_sync_generator_empty_context_preservation():
19461933

19471934
@observe(name="empty_generator")
19481935
def empty_generator():
1949-
from opentelemetry import trace
1950-
19511936
current_span = trace.get_current_span()
19521937
# Should have valid context even for empty generator
19531938
assert (
@@ -1977,4 +1962,4 @@ def root_function():
19771962
empty_obs = next(
19781963
obs for obs in trace_data.observations if obs.name == "empty_generator"
19791964
)
1980-
assert empty_obs.output == ""
1965+
assert empty_obs.output is None

0 commit comments

Comments
 (0)