Skip to content

Commit 61d5d27

Browse files
committed
push
1 parent d9f0eb9 commit 61d5d27

2 files changed

Lines changed: 352 additions & 1 deletion

File tree

langfuse/_client/observe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from langfuse._client.environment_variables import (
2626
LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
2727
)
28-
from langfuse._client.get_client import get_client, _set_current_public_key
28+
from langfuse._client.get_client import _set_current_public_key, get_client
2929
from langfuse._client.span import LangfuseGeneration, LangfuseSpan
3030
from langfuse.types import TraceContext
3131

tests/test_decorators.py

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,3 +1306,354 @@ def level_1_function(*args, **kwargs):
13061306
except Exception:
13071307
# Trace not found is also expected - tracing was completely disabled
13081308
pass
1309+
1310+
1311+
@pytest.mark.asyncio
1312+
async def test_multiproject_async_context_propagation_basic():
1313+
"""Test that nested async decorated functions inherit langfuse_public_key from parent in multi-project setup"""
1314+
LangfuseResourceManager.reset()
1315+
client1 = Langfuse() # Reads from environment
1316+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1317+
1318+
# Verify both instances are registered
1319+
assert len(LangfuseResourceManager._instances) == 2
1320+
1321+
mock_name = "test_multiproject_async_context_propagation_basic"
1322+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1323+
langfuse = get_client(public_key=env_public_key)
1324+
mock_trace_id = langfuse.create_trace_id()
1325+
1326+
@observe(as_type="generation", capture_output=False)
1327+
async def async_level_3_function():
1328+
# This function should inherit the public key from level_1_function
1329+
# and NOT need langfuse_public_key parameter
1330+
await asyncio.sleep(0.01) # Simulate async work
1331+
langfuse_client = get_client()
1332+
langfuse_client.update_current_generation(
1333+
metadata={"level": "3", "async": True}
1334+
)
1335+
langfuse_client.update_current_trace(name=mock_name)
1336+
return "async_level_3"
1337+
1338+
@observe()
1339+
async def async_level_2_function():
1340+
# This function should also inherit the public key
1341+
result = await async_level_3_function()
1342+
langfuse_client = get_client()
1343+
langfuse_client.update_current_span(metadata={"level": "2", "async": True})
1344+
return result
1345+
1346+
@observe()
1347+
async def async_level_1_function(*args, **kwargs):
1348+
# Only this top-level function receives langfuse_public_key
1349+
result = await async_level_2_function()
1350+
langfuse_client = get_client()
1351+
langfuse_client.update_current_span(metadata={"level": "1", "async": True})
1352+
return result
1353+
1354+
result = await async_level_1_function(
1355+
*mock_args,
1356+
**mock_kwargs,
1357+
langfuse_trace_id=mock_trace_id,
1358+
langfuse_public_key=env_public_key, # Only provided to top-level function
1359+
)
1360+
1361+
# Use the correct client for flushing
1362+
client1.flush()
1363+
1364+
assert result == "async_level_3"
1365+
1366+
# Verify trace was created properly
1367+
trace_data = get_api().trace.get(mock_trace_id)
1368+
assert len(trace_data.observations) == 3
1369+
assert trace_data.name == mock_name
1370+
1371+
# Verify all observations have async metadata
1372+
async_flags = [
1373+
obs.metadata.get("async") for obs in trace_data.observations if obs.metadata
1374+
]
1375+
assert all(async_flags)
1376+
1377+
1378+
@pytest.mark.asyncio
1379+
async def test_multiproject_mixed_sync_async_context_propagation():
1380+
"""Test context propagation between sync and async decorated functions in multi-project setup"""
1381+
LangfuseResourceManager.reset()
1382+
client1 = Langfuse() # Reads from environment
1383+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1384+
1385+
# Verify both instances are registered
1386+
assert len(LangfuseResourceManager._instances) == 2
1387+
1388+
mock_name = "test_multiproject_mixed_sync_async_context_propagation"
1389+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1390+
langfuse = get_client(public_key=env_public_key)
1391+
mock_trace_id = langfuse.create_trace_id()
1392+
1393+
@observe(as_type="generation")
1394+
def sync_level_4_function():
1395+
# Sync function called from async should inherit context
1396+
langfuse_client = get_client()
1397+
langfuse_client.update_current_generation(
1398+
metadata={"level": "4", "type": "sync"}
1399+
)
1400+
return "sync_level_4"
1401+
1402+
@observe()
1403+
async def async_level_3_function():
1404+
# Async function calls sync function
1405+
await asyncio.sleep(0.01)
1406+
result = sync_level_4_function()
1407+
langfuse_client = get_client()
1408+
langfuse_client.update_current_span(metadata={"level": "3", "type": "async"})
1409+
return result
1410+
1411+
@observe()
1412+
async def async_level_2_function():
1413+
# Changed to async to avoid event loop issues
1414+
result = await async_level_3_function()
1415+
langfuse_client = get_client()
1416+
langfuse_client.update_current_span(metadata={"level": "2", "type": "async"})
1417+
return result
1418+
1419+
@observe()
1420+
async def async_level_1_function(*args, **kwargs):
1421+
# Top-level async function
1422+
langfuse_client = get_client()
1423+
langfuse_client.update_current_trace(name=mock_name)
1424+
result = await async_level_2_function()
1425+
langfuse_client.update_current_span(metadata={"level": "1", "type": "async"})
1426+
return result
1427+
1428+
result = await async_level_1_function(
1429+
langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key
1430+
)
1431+
client1.flush()
1432+
1433+
assert result == "sync_level_4"
1434+
1435+
trace_data = get_api().trace.get(mock_trace_id)
1436+
assert len(trace_data.observations) == 4
1437+
assert trace_data.name == mock_name
1438+
1439+
# Verify mixed sync/async execution
1440+
types = [
1441+
obs.metadata.get("type") for obs in trace_data.observations if obs.metadata
1442+
]
1443+
assert "sync" in types
1444+
assert "async" in types
1445+
1446+
1447+
@pytest.mark.asyncio
1448+
async def test_multiproject_concurrent_async_context_isolation():
1449+
"""Test that concurrent async executions don't interfere with each other's context in multi-project setup"""
1450+
LangfuseResourceManager.reset()
1451+
client1 = Langfuse() # Reads from environment
1452+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1453+
1454+
# Verify both instances are registered
1455+
assert len(LangfuseResourceManager._instances) == 2
1456+
1457+
mock_name = "test_multiproject_concurrent_async_context_isolation"
1458+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1459+
langfuse = get_client(public_key=env_public_key)
1460+
1461+
trace_id_1 = langfuse.create_trace_id()
1462+
trace_id_2 = langfuse.create_trace_id()
1463+
1464+
# Use the same valid public key for both tasks to avoid credential issues
1465+
# The isolation test is about trace contexts, not different projects
1466+
public_key_1 = env_public_key
1467+
public_key_2 = env_public_key
1468+
1469+
@observe(as_type="generation")
1470+
async def async_level_3_function(task_id):
1471+
# Simulate work and ensure contexts don't leak
1472+
await asyncio.sleep(0.1) # Ensure concurrency overlap
1473+
langfuse_client = get_client()
1474+
langfuse_client.update_current_generation(
1475+
metadata={"task_id": task_id, "level": "3"}
1476+
)
1477+
return f"async_level_3_task_{task_id}"
1478+
1479+
@observe()
1480+
async def async_level_2_function(task_id):
1481+
result = await async_level_3_function(task_id)
1482+
langfuse_client = get_client()
1483+
langfuse_client.update_current_span(metadata={"task_id": task_id, "level": "2"})
1484+
return result
1485+
1486+
@observe()
1487+
async def async_level_1_function(task_id, *args, **kwargs):
1488+
langfuse_client = get_client()
1489+
langfuse_client.update_current_trace(name=f"{mock_name}_task_{task_id}")
1490+
result = await async_level_2_function(task_id)
1491+
langfuse_client.update_current_span(metadata={"task_id": task_id, "level": "1"})
1492+
return result
1493+
1494+
# Run two concurrent async tasks with the same public key but different trace contexts
1495+
task1 = async_level_1_function(
1496+
"1", langfuse_trace_id=trace_id_1, langfuse_public_key=public_key_1
1497+
)
1498+
task2 = async_level_1_function(
1499+
"2", langfuse_trace_id=trace_id_2, langfuse_public_key=public_key_2
1500+
)
1501+
1502+
result1, result2 = await asyncio.gather(task1, task2)
1503+
1504+
client1.flush()
1505+
1506+
assert result1 == "async_level_3_task_1"
1507+
assert result2 == "async_level_3_task_2"
1508+
1509+
# Verify both traces were created correctly and didn't interfere
1510+
trace_data_1 = get_api().trace.get(trace_id_1)
1511+
trace_data_2 = get_api().trace.get(trace_id_2)
1512+
1513+
assert trace_data_1.name == f"{mock_name}_task_1"
1514+
assert trace_data_2.name == f"{mock_name}_task_2"
1515+
1516+
# Verify that both traces have the expected number of observations (context propagation worked)
1517+
assert (
1518+
len(trace_data_1.observations) == 3
1519+
) # All 3 levels should be captured for task 1
1520+
assert (
1521+
len(trace_data_2.observations) == 3
1522+
) # All 3 levels should be captured for task 2
1523+
1524+
# Verify traces are properly isolated (no cross-contamination)
1525+
trace_1_names = [obs.name for obs in trace_data_1.observations]
1526+
trace_2_names = [obs.name for obs in trace_data_2.observations]
1527+
assert "async_level_1_function" in trace_1_names
1528+
assert "async_level_2_function" in trace_1_names
1529+
assert "async_level_3_function" in trace_1_names
1530+
assert "async_level_1_function" in trace_2_names
1531+
assert "async_level_2_function" in trace_2_names
1532+
assert "async_level_3_function" in trace_2_names
1533+
1534+
1535+
@pytest.mark.asyncio
1536+
async def test_multiproject_async_generator_context_propagation():
1537+
"""Test context propagation with async generators in multi-project setup"""
1538+
LangfuseResourceManager.reset()
1539+
client1 = Langfuse() # Reads from environment
1540+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1541+
1542+
# Verify both instances are registered
1543+
assert len(LangfuseResourceManager._instances) == 2
1544+
1545+
mock_name = "test_multiproject_async_generator_context_propagation"
1546+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1547+
langfuse = get_client(public_key=env_public_key)
1548+
mock_trace_id = langfuse.create_trace_id()
1549+
1550+
@observe(capture_output=True)
1551+
async def async_generator_function():
1552+
# Async generator should inherit context from parent
1553+
await asyncio.sleep(0.01)
1554+
yield "Hello"
1555+
await asyncio.sleep(0.01)
1556+
yield ", "
1557+
await asyncio.sleep(0.01)
1558+
yield "Async"
1559+
await asyncio.sleep(0.01)
1560+
yield " World!"
1561+
1562+
@observe()
1563+
async def async_consumer_function():
1564+
langfuse_client = get_client()
1565+
langfuse_client.update_current_trace(name=mock_name)
1566+
1567+
result = ""
1568+
async for item in async_generator_function():
1569+
result += item
1570+
1571+
langfuse_client.update_current_span(
1572+
metadata={"type": "consumer", "result": result}
1573+
)
1574+
return result
1575+
1576+
result = await async_consumer_function(
1577+
langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key
1578+
)
1579+
client1.flush()
1580+
1581+
assert result == "Hello, Async World!"
1582+
1583+
trace_data = get_api().trace.get(mock_trace_id)
1584+
assert len(trace_data.observations) == 2
1585+
assert trace_data.name == mock_name
1586+
1587+
# Verify both generator and consumer were captured by name (most reliable test)
1588+
observation_names = [obs.name for obs in trace_data.observations]
1589+
assert "async_generator_function" in observation_names
1590+
assert "async_consumer_function" in observation_names
1591+
1592+
# Verify that context propagation worked - both functions should be in the same trace
1593+
# This confirms that the async generator inherited the public key context
1594+
assert len(trace_data.observations) == 2
1595+
1596+
1597+
@pytest.mark.asyncio
1598+
async def test_multiproject_async_context_exception_handling():
1599+
"""Test that async context is properly restored even when exceptions occur in multi-project setup"""
1600+
LangfuseResourceManager.reset()
1601+
client1 = Langfuse() # Reads from environment
1602+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1603+
1604+
# Verify both instances are registered
1605+
assert len(LangfuseResourceManager._instances) == 2
1606+
1607+
mock_name = "test_multiproject_async_context_exception_handling"
1608+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1609+
langfuse = get_client(public_key=env_public_key)
1610+
mock_trace_id = langfuse.create_trace_id()
1611+
1612+
@observe(as_type="generation")
1613+
async def async_failing_function():
1614+
# This function should inherit context but will raise an exception
1615+
await asyncio.sleep(0.01)
1616+
langfuse_client = get_client()
1617+
langfuse_client.update_current_generation(metadata={"will_fail": True})
1618+
langfuse_client.update_current_trace(name=mock_name)
1619+
raise ValueError("Async function failed")
1620+
1621+
@observe()
1622+
async def async_caller_function():
1623+
try:
1624+
await async_failing_function()
1625+
except ValueError:
1626+
# Context should still be available here
1627+
langfuse_client = get_client()
1628+
langfuse_client.update_current_span(metadata={"caught_exception": True})
1629+
return "exception_handled"
1630+
1631+
@observe()
1632+
async def async_root_function(*args, **kwargs):
1633+
result = await async_caller_function()
1634+
# Context should still be available after exception
1635+
langfuse_client = get_client()
1636+
langfuse_client.update_current_span(metadata={"root": True})
1637+
return result
1638+
1639+
result = await async_root_function(
1640+
langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key
1641+
)
1642+
client1.flush()
1643+
1644+
assert result == "exception_handled"
1645+
1646+
trace_data = get_api().trace.get(mock_trace_id)
1647+
assert len(trace_data.observations) == 3
1648+
assert trace_data.name == mock_name
1649+
1650+
# Verify exception was properly handled and context maintained
1651+
exception_obs = next(obs for obs in trace_data.observations if obs.level == "ERROR")
1652+
assert exception_obs.status_message == "Async function failed"
1653+
1654+
caught_obs = next(
1655+
obs
1656+
for obs in trace_data.observations
1657+
if obs.metadata and obs.metadata.get("caught_exception")
1658+
)
1659+
assert caught_obs is not None

0 commit comments

Comments
 (0)