|
1 | 1 | import asyncio |
| 2 | +import os |
2 | 3 | from collections import defaultdict |
3 | 4 | from concurrent.futures import ThreadPoolExecutor |
4 | 5 | from time import sleep |
|
8 | 9 | from langchain.prompts import ChatPromptTemplate |
9 | 10 | from langchain_openai import ChatOpenAI |
10 | 11 |
|
11 | | -from langfuse import get_client, observe |
| 12 | +from langfuse import Langfuse, get_client, observe |
| 13 | +from langfuse._client.environment_variables import LANGFUSE_PUBLIC_KEY |
| 14 | +from langfuse._client.resource_manager import LangfuseResourceManager |
12 | 15 | from langfuse.langchain import CallbackHandler |
13 | 16 | from langfuse.media import LangfuseMedia |
14 | 17 | from tests.utils import get_api |
@@ -1081,3 +1084,218 @@ def main(): |
1081 | 1084 | assert trace_data.metadata["key2"] == "value2" |
1082 | 1085 |
|
1083 | 1086 | assert trace_data.tags == ["tag1", "tag2"] |
| 1087 | + |
| 1088 | + |
| 1089 | +# Multi-project context propagation tests |
| 1090 | +def test_multiproject_context_propagation_basic(): |
| 1091 | + """Test that nested decorated functions inherit langfuse_public_key from parent in multi-project setup""" |
| 1092 | + client1 = Langfuse() # Reads from environment |
| 1093 | + Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") |
| 1094 | + |
| 1095 | + # Verify both instances are registered |
| 1096 | + assert len(LangfuseResourceManager._instances) == 2 |
| 1097 | + |
| 1098 | + mock_name = "test_multiproject_context_propagation_basic" |
| 1099 | + # Use known public key from environment |
| 1100 | + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] |
| 1101 | + # In multi-project setup, must specify which client to use |
| 1102 | + langfuse = get_client(public_key=env_public_key) |
| 1103 | + mock_trace_id = langfuse.create_trace_id() |
| 1104 | + |
| 1105 | + @observe(as_type="generation", capture_output=False) |
| 1106 | + def level_3_function(): |
| 1107 | + # This function should inherit the public key from level_1_function |
| 1108 | + # and NOT need langfuse_public_key parameter |
| 1109 | + langfuse_client = get_client() |
| 1110 | + langfuse_client.update_current_generation(metadata={"level": "3"}) |
| 1111 | + langfuse_client.update_current_trace(name=mock_name) |
| 1112 | + return "level_3" |
| 1113 | + |
| 1114 | + @observe() |
| 1115 | + def level_2_function(): |
| 1116 | + # This function should also inherit the public key |
| 1117 | + level_3_function() |
| 1118 | + langfuse_client = get_client() |
| 1119 | + langfuse_client.update_current_span(metadata={"level": "2"}) |
| 1120 | + return "level_2" |
| 1121 | + |
| 1122 | + @observe() |
| 1123 | + def level_1_function(*args, **kwargs): |
| 1124 | + # Only this top-level function receives langfuse_public_key |
| 1125 | + level_2_function() |
| 1126 | + langfuse_client = get_client() |
| 1127 | + langfuse_client.update_current_span(metadata={"level": "1"}) |
| 1128 | + return "level_1" |
| 1129 | + |
| 1130 | + result = level_1_function( |
| 1131 | + *mock_args, |
| 1132 | + **mock_kwargs, |
| 1133 | + langfuse_trace_id=mock_trace_id, |
| 1134 | + langfuse_public_key=env_public_key, # Only provided to top-level function |
| 1135 | + ) |
| 1136 | + |
| 1137 | + # Use the correct client for flushing |
| 1138 | + client1.flush() |
| 1139 | + |
| 1140 | + assert result == "level_1" |
| 1141 | + |
| 1142 | + # Verify trace was created properly |
| 1143 | + trace_data = get_api().trace.get(mock_trace_id) |
| 1144 | + assert len(trace_data.observations) == 3 |
| 1145 | + assert trace_data.name == mock_name |
| 1146 | + |
| 1147 | + |
| 1148 | +def test_multiproject_context_propagation_deep_nesting(): |
| 1149 | + client1 = Langfuse() # Reads from environment |
| 1150 | + Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") |
| 1151 | + |
| 1152 | + # Verify both instances are registered |
| 1153 | + assert len(LangfuseResourceManager._instances) == 2 |
| 1154 | + |
| 1155 | + mock_name = "test_multiproject_context_propagation_deep_nesting" |
| 1156 | + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] |
| 1157 | + langfuse = get_client(public_key=env_public_key) |
| 1158 | + mock_trace_id = langfuse.create_trace_id() |
| 1159 | + |
| 1160 | + @observe(as_type="generation") |
| 1161 | + def level_4_function(): |
| 1162 | + langfuse_client = get_client() |
| 1163 | + langfuse_client.update_current_generation(metadata={"level": "4"}) |
| 1164 | + return "level_4" |
| 1165 | + |
| 1166 | + @observe() |
| 1167 | + def level_3_function(): |
| 1168 | + result = level_4_function() |
| 1169 | + langfuse_client = get_client() |
| 1170 | + langfuse_client.update_current_span(metadata={"level": "3"}) |
| 1171 | + return result |
| 1172 | + |
| 1173 | + @observe() |
| 1174 | + def level_2_function(): |
| 1175 | + result = level_3_function() |
| 1176 | + langfuse_client = get_client() |
| 1177 | + langfuse_client.update_current_span(metadata={"level": "2"}) |
| 1178 | + return result |
| 1179 | + |
| 1180 | + @observe() |
| 1181 | + def level_1_function(*args, **kwargs): |
| 1182 | + langfuse_client = get_client() |
| 1183 | + langfuse_client.update_current_trace(name=mock_name) |
| 1184 | + result = level_2_function() |
| 1185 | + langfuse_client.update_current_span(metadata={"level": "1"}) |
| 1186 | + return result |
| 1187 | + |
| 1188 | + result = level_1_function( |
| 1189 | + langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key |
| 1190 | + ) |
| 1191 | + client1.flush() |
| 1192 | + |
| 1193 | + assert result == "level_4" |
| 1194 | + |
| 1195 | + trace_data = get_api().trace.get(mock_trace_id) |
| 1196 | + assert len(trace_data.observations) == 4 |
| 1197 | + assert trace_data.name == mock_name |
| 1198 | + |
| 1199 | + # Verify all levels were captured |
| 1200 | + levels = [ |
| 1201 | + obs.metadata.get("level") for obs in trace_data.observations if obs.metadata |
| 1202 | + ] |
| 1203 | + assert set(levels) == {"1", "2", "3", "4"} |
| 1204 | + |
| 1205 | + |
| 1206 | +def test_multiproject_context_propagation_override(): |
| 1207 | + # Initialize two separate Langfuse instances |
| 1208 | + client1 = Langfuse() # Reads from environment |
| 1209 | + client2 = Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") |
| 1210 | + |
| 1211 | + # Verify both instances are registered |
| 1212 | + assert len(LangfuseResourceManager._instances) == 2 |
| 1213 | + |
| 1214 | + mock_name = "test_multiproject_context_propagation_override" |
| 1215 | + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] |
| 1216 | + langfuse = get_client(public_key=env_public_key) |
| 1217 | + mock_trace_id = langfuse.create_trace_id() |
| 1218 | + |
| 1219 | + primary_public_key = env_public_key |
| 1220 | + override_public_key = "pk-test-project2" |
| 1221 | + |
| 1222 | + @observe(as_type="generation") |
| 1223 | + def level_3_function(): |
| 1224 | + # This function explicitly overrides the inherited public key |
| 1225 | + langfuse_client = get_client(public_key=override_public_key) |
| 1226 | + langfuse_client.update_current_generation(metadata={"used_override": "true"}) |
| 1227 | + return "level_3" |
| 1228 | + |
| 1229 | + @observe() |
| 1230 | + def level_2_function(): |
| 1231 | + # This function should use the overridden key when calling level_3 |
| 1232 | + level_3_function(langfuse_public_key=override_public_key) |
| 1233 | + langfuse_client = get_client(public_key=primary_public_key) |
| 1234 | + langfuse_client.update_current_span(metadata={"level": "2"}) |
| 1235 | + return "level_2" |
| 1236 | + |
| 1237 | + @observe() |
| 1238 | + def level_1_function(*args, **kwargs): |
| 1239 | + langfuse_client = get_client(public_key=primary_public_key) |
| 1240 | + langfuse_client.update_current_trace(name=mock_name) |
| 1241 | + level_2_function() |
| 1242 | + return "level_1" |
| 1243 | + |
| 1244 | + result = level_1_function( |
| 1245 | + langfuse_trace_id=mock_trace_id, langfuse_public_key=primary_public_key |
| 1246 | + ) |
| 1247 | + client1.flush() |
| 1248 | + client2.flush() |
| 1249 | + |
| 1250 | + assert result == "level_1" |
| 1251 | + |
| 1252 | + trace_data = get_api().trace.get(mock_trace_id) |
| 1253 | + assert len(trace_data.observations) == 2 |
| 1254 | + assert trace_data.name == mock_name |
| 1255 | + |
| 1256 | + |
| 1257 | +def test_multiproject_context_propagation_no_public_key(): |
| 1258 | + # Initialize two separate Langfuse instances |
| 1259 | + client1 = Langfuse() # Reads from environment |
| 1260 | + Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") |
| 1261 | + |
| 1262 | + # Verify both instances are registered |
| 1263 | + assert len(LangfuseResourceManager._instances) == 2 |
| 1264 | + |
| 1265 | + mock_name = "test_multiproject_context_propagation_no_public_key" |
| 1266 | + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] |
| 1267 | + langfuse = get_client(public_key=env_public_key) |
| 1268 | + mock_trace_id = langfuse.create_trace_id() |
| 1269 | + |
| 1270 | + @observe(as_type="generation") |
| 1271 | + def level_3_function(): |
| 1272 | + # Should use default client since no public key provided |
| 1273 | + langfuse_client = get_client() |
| 1274 | + langfuse_client.update_current_generation(metadata={"level": "3"}) |
| 1275 | + return "level_3" |
| 1276 | + |
| 1277 | + @observe() |
| 1278 | + def level_2_function(): |
| 1279 | + result = level_3_function() |
| 1280 | + langfuse_client = get_client() |
| 1281 | + langfuse_client.update_current_span(metadata={"level": "2"}) |
| 1282 | + return result |
| 1283 | + |
| 1284 | + @observe() |
| 1285 | + def level_1_function(*args, **kwargs): |
| 1286 | + langfuse_client = get_client() |
| 1287 | + langfuse_client.update_current_trace(name=mock_name) |
| 1288 | + result = level_2_function() |
| 1289 | + langfuse_client.update_current_span(metadata={"level": "1"}) |
| 1290 | + return result |
| 1291 | + |
| 1292 | + # No langfuse_public_key provided - should use default client |
| 1293 | + result = level_1_function(langfuse_trace_id=mock_trace_id) |
| 1294 | + client1.flush() |
| 1295 | + |
| 1296 | + assert result == "level_3" |
| 1297 | + |
| 1298 | + # Should still work with default client |
| 1299 | + trace_data = get_api().trace.get(mock_trace_id) |
| 1300 | + assert len(trace_data.observations) == 0 |
| 1301 | + assert trace_data.name == mock_name |
0 commit comments