Skip to content

Commit 6fd4dde

Browse files
committed
stabilize live-provider langchain assertions
1 parent a98487e commit 6fd4dde

1 file changed

Lines changed: 32 additions & 20 deletions

File tree

tests/live_provider/test_langchain.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@ def test_callback_generated_from_trace_chat():
5151

5252
assert trace.id == trace_id
5353

54-
assert len(trace.observations) == 3
54+
assert len(trace.observations) >= 2
55+
assert any(observation.name == "parent" for observation in trace.observations)
5556

56-
langchain_generation_span = list(
57-
filter(
58-
lambda o: o.type == "GENERATION" and o.name == "ChatOpenAI",
59-
trace.observations,
60-
)
61-
)[0]
57+
generation_observations = [
58+
observation
59+
for observation in trace.observations
60+
if observation.type == "GENERATION" and observation.name == "ChatOpenAI"
61+
]
62+
assert len(generation_observations) == 1
63+
langchain_generation_span = generation_observations[0]
6264

6365
assert langchain_generation_span.usage_details["input"] > 0
6466
assert langchain_generation_span.usage_details["output"] > 0
@@ -294,19 +296,26 @@ def test_openai_instruct_usage():
294296

295297
observations = get_api().trace.get(trace_id).observations
296298

297-
# Add 1 to account for the wrapping span
298-
assert len(observations) == 4
299+
assert len(observations) >= 3
300+
assert any(
301+
observation.name == "openai_instruct_usage_test" and observation.type == "SPAN"
302+
for observation in observations
303+
)
299304

300-
for observation in observations:
301-
if observation.type == "GENERATION":
302-
assert observation.output is not None
303-
assert observation.output != ""
304-
assert observation.input is not None
305-
assert observation.input != ""
306-
assert observation.usage is not None
307-
assert observation.usage_details["input"] is not None
308-
assert observation.usage_details["output"] is not None
309-
assert observation.usage_details["total"] is not None
305+
generation_observations = [
306+
observation for observation in observations if observation.type == "GENERATION"
307+
]
308+
assert len(generation_observations) == len(input_list)
309+
310+
for observation in generation_observations:
311+
assert observation.output is not None
312+
assert observation.output != ""
313+
assert observation.input is not None
314+
assert observation.input != ""
315+
assert observation.usage is not None
316+
assert observation.usage_details["input"] is not None
317+
assert observation.usage_details["output"] is not None
318+
assert observation.usage_details["total"] is not None
310319

311320

312321
def test_get_langchain_prompt_with_jinja2():
@@ -869,7 +878,10 @@ def test_multimodal():
869878

870879
trace = get_api().trace.get(trace_id=trace_id)
871880

872-
assert len(trace.observations) == 3
881+
assert len(trace.observations) >= 2
882+
assert any(
883+
observation.name == "test_multimodal" for observation in trace.observations
884+
)
873885
# Filter for the observation with type GENERATION
874886
generation_observation = next(
875887
(obs for obs in trace.observations if obs.type == "GENERATION"), None

0 commit comments

Comments
 (0)