@@ -51,14 +51,16 @@ def test_callback_generated_from_trace_chat():
5151
5252 assert trace .id == trace_id
5353
54- assert len (trace .observations ) == 3
54+ assert len (trace .observations ) >= 2
55+ assert any (observation .name == "parent" for observation in trace .observations )
5556
56- langchain_generation_span = list (
57- filter (
58- lambda o : o .type == "GENERATION" and o .name == "ChatOpenAI" ,
59- trace .observations ,
60- )
61- )[0 ]
57+ generation_observations = [
58+ observation
59+ for observation in trace .observations
60+ if observation .type == "GENERATION" and observation .name == "ChatOpenAI"
61+ ]
62+ assert len (generation_observations ) == 1
63+ langchain_generation_span = generation_observations [0 ]
6264
6365 assert langchain_generation_span .usage_details ["input" ] > 0
6466 assert langchain_generation_span .usage_details ["output" ] > 0
@@ -294,19 +296,26 @@ def test_openai_instruct_usage():
294296
295297 observations = get_api ().trace .get (trace_id ).observations
296298
297- # Add 1 to account for the wrapping span
298- assert len (observations ) == 4
299+ assert len (observations ) >= 3
300+ assert any (
301+ observation .name == "openai_instruct_usage_test" and observation .type == "SPAN"
302+ for observation in observations
303+ )
299304
300- for observation in observations :
301- if observation .type == "GENERATION" :
302- assert observation .output is not None
303- assert observation .output != ""
304- assert observation .input is not None
305- assert observation .input != ""
306- assert observation .usage is not None
307- assert observation .usage_details ["input" ] is not None
308- assert observation .usage_details ["output" ] is not None
309- assert observation .usage_details ["total" ] is not None
305+ generation_observations = [
306+ observation for observation in observations if observation .type == "GENERATION"
307+ ]
308+ assert len (generation_observations ) == len (input_list )
309+
310+ for observation in generation_observations :
311+ assert observation .output is not None
312+ assert observation .output != ""
313+ assert observation .input is not None
314+ assert observation .input != ""
315+ assert observation .usage is not None
316+ assert observation .usage_details ["input" ] is not None
317+ assert observation .usage_details ["output" ] is not None
318+ assert observation .usage_details ["total" ] is not None
310319
311320
312321def test_get_langchain_prompt_with_jinja2 ():
@@ -869,7 +878,10 @@ def test_multimodal():
869878
870879 trace = get_api ().trace .get (trace_id = trace_id )
871880
872- assert len (trace .observations ) == 3
881+ assert len (trace .observations ) >= 2
882+ assert any (
883+ observation .name == "test_multimodal" for observation in trace .observations
884+ )
873885 # Filter for the observation with type GENERATION
874886 generation_observation = next (
875887 (obs for obs in trace .observations if obs .type == "GENERATION" ), None
0 commit comments