Skip to content

Commit 96d2d9c

Browse files
committed
feat(decorator): require langfuse_public_key only in top most decorated func
1 parent eba5f7f commit 96d2d9c

1 file changed

Lines changed: 219 additions & 1 deletion

File tree

tests/test_decorators.py

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import os
23
from collections import defaultdict
34
from concurrent.futures import ThreadPoolExecutor
45
from time import sleep
@@ -8,7 +9,9 @@
89
from langchain.prompts import ChatPromptTemplate
910
from langchain_openai import ChatOpenAI
1011

11-
from langfuse import get_client, observe
12+
from langfuse import Langfuse, get_client, observe
13+
from langfuse._client.environment_variables import LANGFUSE_PUBLIC_KEY
14+
from langfuse._client.resource_manager import LangfuseResourceManager
1215
from langfuse.langchain import CallbackHandler
1316
from langfuse.media import LangfuseMedia
1417
from tests.utils import get_api
@@ -1081,3 +1084,218 @@ def main():
10811084
assert trace_data.metadata["key2"] == "value2"
10821085

10831086
assert trace_data.tags == ["tag1", "tag2"]
1087+
1088+
1089+
# Multi-project context propagation tests
1090+
def test_multiproject_context_propagation_basic():
1091+
"""Test that nested decorated functions inherit langfuse_public_key from parent in multi-project setup"""
1092+
client1 = Langfuse() # Reads from environment
1093+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1094+
1095+
# Verify both instances are registered
1096+
assert len(LangfuseResourceManager._instances) == 2
1097+
1098+
mock_name = "test_multiproject_context_propagation_basic"
1099+
# Use known public key from environment
1100+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1101+
# In multi-project setup, must specify which client to use
1102+
langfuse = get_client(public_key=env_public_key)
1103+
mock_trace_id = langfuse.create_trace_id()
1104+
1105+
@observe(as_type="generation", capture_output=False)
1106+
def level_3_function():
1107+
# This function should inherit the public key from level_1_function
1108+
# and NOT need langfuse_public_key parameter
1109+
langfuse_client = get_client()
1110+
langfuse_client.update_current_generation(metadata={"level": "3"})
1111+
langfuse_client.update_current_trace(name=mock_name)
1112+
return "level_3"
1113+
1114+
@observe()
1115+
def level_2_function():
1116+
# This function should also inherit the public key
1117+
level_3_function()
1118+
langfuse_client = get_client()
1119+
langfuse_client.update_current_span(metadata={"level": "2"})
1120+
return "level_2"
1121+
1122+
@observe()
1123+
def level_1_function(*args, **kwargs):
1124+
# Only this top-level function receives langfuse_public_key
1125+
level_2_function()
1126+
langfuse_client = get_client()
1127+
langfuse_client.update_current_span(metadata={"level": "1"})
1128+
return "level_1"
1129+
1130+
result = level_1_function(
1131+
*mock_args,
1132+
**mock_kwargs,
1133+
langfuse_trace_id=mock_trace_id,
1134+
langfuse_public_key=env_public_key, # Only provided to top-level function
1135+
)
1136+
1137+
# Use the correct client for flushing
1138+
client1.flush()
1139+
1140+
assert result == "level_1"
1141+
1142+
# Verify trace was created properly
1143+
trace_data = get_api().trace.get(mock_trace_id)
1144+
assert len(trace_data.observations) == 3
1145+
assert trace_data.name == mock_name
1146+
1147+
1148+
def test_multiproject_context_propagation_deep_nesting():
1149+
client1 = Langfuse() # Reads from environment
1150+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1151+
1152+
# Verify both instances are registered
1153+
assert len(LangfuseResourceManager._instances) == 2
1154+
1155+
mock_name = "test_multiproject_context_propagation_deep_nesting"
1156+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1157+
langfuse = get_client(public_key=env_public_key)
1158+
mock_trace_id = langfuse.create_trace_id()
1159+
1160+
@observe(as_type="generation")
1161+
def level_4_function():
1162+
langfuse_client = get_client()
1163+
langfuse_client.update_current_generation(metadata={"level": "4"})
1164+
return "level_4"
1165+
1166+
@observe()
1167+
def level_3_function():
1168+
result = level_4_function()
1169+
langfuse_client = get_client()
1170+
langfuse_client.update_current_span(metadata={"level": "3"})
1171+
return result
1172+
1173+
@observe()
1174+
def level_2_function():
1175+
result = level_3_function()
1176+
langfuse_client = get_client()
1177+
langfuse_client.update_current_span(metadata={"level": "2"})
1178+
return result
1179+
1180+
@observe()
1181+
def level_1_function(*args, **kwargs):
1182+
langfuse_client = get_client()
1183+
langfuse_client.update_current_trace(name=mock_name)
1184+
result = level_2_function()
1185+
langfuse_client.update_current_span(metadata={"level": "1"})
1186+
return result
1187+
1188+
result = level_1_function(
1189+
langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key
1190+
)
1191+
client1.flush()
1192+
1193+
assert result == "level_4"
1194+
1195+
trace_data = get_api().trace.get(mock_trace_id)
1196+
assert len(trace_data.observations) == 4
1197+
assert trace_data.name == mock_name
1198+
1199+
# Verify all levels were captured
1200+
levels = [
1201+
obs.metadata.get("level") for obs in trace_data.observations if obs.metadata
1202+
]
1203+
assert set(levels) == {"1", "2", "3", "4"}
1204+
1205+
1206+
def test_multiproject_context_propagation_override():
1207+
# Initialize two separate Langfuse instances
1208+
client1 = Langfuse() # Reads from environment
1209+
client2 = Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1210+
1211+
# Verify both instances are registered
1212+
assert len(LangfuseResourceManager._instances) == 2
1213+
1214+
mock_name = "test_multiproject_context_propagation_override"
1215+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1216+
langfuse = get_client(public_key=env_public_key)
1217+
mock_trace_id = langfuse.create_trace_id()
1218+
1219+
primary_public_key = env_public_key
1220+
override_public_key = "pk-test-project2"
1221+
1222+
@observe(as_type="generation")
1223+
def level_3_function():
1224+
# This function explicitly overrides the inherited public key
1225+
langfuse_client = get_client(public_key=override_public_key)
1226+
langfuse_client.update_current_generation(metadata={"used_override": "true"})
1227+
return "level_3"
1228+
1229+
@observe()
1230+
def level_2_function():
1231+
# This function should use the overridden key when calling level_3
1232+
level_3_function(langfuse_public_key=override_public_key)
1233+
langfuse_client = get_client(public_key=primary_public_key)
1234+
langfuse_client.update_current_span(metadata={"level": "2"})
1235+
return "level_2"
1236+
1237+
@observe()
1238+
def level_1_function(*args, **kwargs):
1239+
langfuse_client = get_client(public_key=primary_public_key)
1240+
langfuse_client.update_current_trace(name=mock_name)
1241+
level_2_function()
1242+
return "level_1"
1243+
1244+
result = level_1_function(
1245+
langfuse_trace_id=mock_trace_id, langfuse_public_key=primary_public_key
1246+
)
1247+
client1.flush()
1248+
client2.flush()
1249+
1250+
assert result == "level_1"
1251+
1252+
trace_data = get_api().trace.get(mock_trace_id)
1253+
assert len(trace_data.observations) == 2
1254+
assert trace_data.name == mock_name
1255+
1256+
1257+
def test_multiproject_context_propagation_no_public_key():
1258+
# Initialize two separate Langfuse instances
1259+
client1 = Langfuse() # Reads from environment
1260+
Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2")
1261+
1262+
# Verify both instances are registered
1263+
assert len(LangfuseResourceManager._instances) == 2
1264+
1265+
mock_name = "test_multiproject_context_propagation_no_public_key"
1266+
env_public_key = os.environ[LANGFUSE_PUBLIC_KEY]
1267+
langfuse = get_client(public_key=env_public_key)
1268+
mock_trace_id = langfuse.create_trace_id()
1269+
1270+
@observe(as_type="generation")
1271+
def level_3_function():
1272+
# Should use default client since no public key provided
1273+
langfuse_client = get_client()
1274+
langfuse_client.update_current_generation(metadata={"level": "3"})
1275+
return "level_3"
1276+
1277+
@observe()
1278+
def level_2_function():
1279+
result = level_3_function()
1280+
langfuse_client = get_client()
1281+
langfuse_client.update_current_span(metadata={"level": "2"})
1282+
return result
1283+
1284+
@observe()
1285+
def level_1_function(*args, **kwargs):
1286+
langfuse_client = get_client()
1287+
langfuse_client.update_current_trace(name=mock_name)
1288+
result = level_2_function()
1289+
langfuse_client.update_current_span(metadata={"level": "1"})
1290+
return result
1291+
1292+
# No langfuse_public_key provided - should use default client
1293+
result = level_1_function(langfuse_trace_id=mock_trace_id)
1294+
client1.flush()
1295+
1296+
assert result == "level_3"
1297+
1298+
# Should still work with default client
1299+
trace_data = get_api().trace.get(mock_trace_id)
1300+
assert len(trace_data.observations) == 0
1301+
assert trace_data.name == mock_name

0 commit comments

Comments
 (0)