Commit 6d3d287

langchain: fix token usage from EM endpoints (#265)
Signed-off-by: Bryan Qiu <bryan.qiu@databricks.com>
Co-authored-by: Jake Steelman <jake.steelman@insurica.com>
1 parent 9ddaef8 commit 6d3d287

3 files changed: 280 additions and 26 deletions


integrations/langchain/src/databricks_langchain/chat_models.py

Lines changed: 12 additions & 0 deletions
@@ -690,6 +690,18 @@ def _stream(
                     logprobs=generation_info.get("logprobs"),
                 )
                 yield generation_chunk
+            elif chunk.usage and stream_usage:
+                # Some models send a final chunk that does not have
+                # a delta or choices, but does have usage info
+                if not usage_chunk_emitted:
+                    input_tokens = getattr(chunk.usage, "prompt_tokens", None)
+                    output_tokens = getattr(chunk.usage, "completion_tokens", None)
+                    if input_tokens is not None and output_tokens is not None:
+                        final_usage = {
+                            "input_tokens": input_tokens,
+                            "output_tokens": output_tokens,
+                            "total_tokens": input_tokens + output_tokens,
+                        }
 
         # Emit special usage chunk at end of stream
         if stream_usage and final_usage and not usage_chunk_emitted:
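
For reference, the branch added above can be exercised in isolation. Below is a minimal, self-contained sketch of the same mapping; the helper name extract_final_usage is ours for illustration, not part of the change. It reads prompt_tokens and completion_tokens off an OpenAI-style usage object and builds the input/output/total dict that the stream later emits, returning None when either count is missing so no partial usage chunk is produced.

# Illustrative sketch only; mirrors the guard logic in the hunk above.
from types import SimpleNamespace
from typing import Optional


def extract_final_usage(usage: object) -> Optional[dict]:
    """Map an OpenAI-style usage object to LangChain-style usage metadata."""
    input_tokens = getattr(usage, "prompt_tokens", None)
    output_tokens = getattr(usage, "completion_tokens", None)
    # Mirror the diff's guard: emit nothing unless both counts are present.
    if input_tokens is None or output_tokens is None:
        return None
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }


# A GPT-5-style final chunk carries usage but no choices; the counts here
# are taken from the example in the integration test below.
usage = SimpleNamespace(prompt_tokens=4861, completion_tokens=267)
assert extract_final_usage(usage) == {
    "input_tokens": 4861,
    "output_tokens": 267,
    "total_tokens": 5128,
}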

integrations/langchain/tests/integration_tests/test_chat_models.py

Lines changed: 81 additions & 26 deletions
@@ -13,7 +13,6 @@
 from typing import Annotated
 
 import pytest
-from langchain.agents import AgentExecutor, create_tool_calling_agent
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages import (
     AIMessage,
@@ -378,30 +377,6 @@ def multiply(a: int, b: int) -> int:
     return a * b
 
 
-@pytest.mark.foundation_models
-@pytest.mark.parametrize("model", _FOUNDATION_MODELS)
-def test_chat_databricks_agent_executor(model):
-    model = ChatDatabricks(
-        model=model,
-        temperature=0,
-        max_tokens=100,
-    )
-    tools = [add, multiply]
-    prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", "You are a helpful assistant"),
-            ("human", "{input}"),
-            ("placeholder", "{agent_scratchpad}"),
-        ]
-    )
-
-    agent = create_tool_calling_agent(model, tools, prompt)
-    agent_executor = AgentExecutor(agent=agent, tools=tools)
-
-    response = agent_executor.invoke({"input": "What is (10 + 5) * 3?"})
-    assert "45" in response["output"]
-
-
 @pytest.mark.foundation_models
 @pytest.mark.parametrize("model", _FOUNDATION_MODELS)
 def test_chat_databricks_langgraph(model):
@@ -784,7 +759,7 @@ def test_chat_databricks_custom_outputs():
 
 
 def test_chat_databricks_custom_outputs_stream():
-    llm = ChatDatabricks(model="agents_ml-bbqiu-codegen", use_responses_api=True)
+    llm = ChatDatabricks(model="agents_ml-bbqiu-mcp-openai", use_responses_api=True)
     response = llm.stream(
         "What is the 10th fibonacci number?",
         custom_inputs={"key": "value"},
@@ -820,3 +795,83 @@ def test_chat_databricks_token_count():
         last_chunk.usage_metadata["total_tokens"]
         == last_chunk.usage_metadata["input_tokens"] + last_chunk.usage_metadata["output_tokens"]
     )
+
+
+def test_chat_databricks_gpt5_stream_with_usage():
+    """
+    Test GPT-5 streaming with usage metadata.
+
+    GPT-5 sends a final chunk with only usage data (no choices/delta).
+    This test verifies that the usage metadata is correctly extracted from that final chunk.
+
+    Example final chunk from GPT-5:
+        ChatCompletionChunk(
+            id='chatcmpl-...',
+            choices=[],  # Empty choices array
+            created=...,
+            model='gpt-5-2025-08-07',
+            object='chat.completion.chunk',
+            usage=CompletionUsage(
+                completion_tokens=267,
+                prompt_tokens=4861,
+                total_tokens=5128,
+                ...
+            )
+        )
+    """
+    from databricks.sdk import WorkspaceClient
+
+    # Use dogfood profile to access GPT-5
+    workspace_client = WorkspaceClient(profile=DATABRICKS_CLI_PROFILE)
+
+    llm = ChatDatabricks(
+        endpoint="gpt-5",
+        workspace_client=workspace_client,
+        max_tokens=100,
+        stream_usage=True,
+    )
+
+    # Stream a simple query
+    chunks = list(llm.stream("hello"))
+
+    # Verify we get chunks
+    assert len(chunks) > 0, "Expected at least one chunk from GPT-5 streaming"
+
+    # Find content chunks (non-empty content)
+    content_chunks = [chunk for chunk in chunks if chunk.content != ""]
+    assert len(content_chunks) > 0, "Expected at least one content chunk"
+
+    # Find usage chunks (empty content with usage_metadata)
+    usage_chunks = [
+        chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
+    ]
+
+    # Should have exactly ONE usage chunk from the final usage-only chunk
+    assert len(usage_chunks) == 1, (
+        f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}"
+    )
+
+    # Verify usage chunk has correct metadata structure
+    usage_chunk = usage_chunks[0]
+    assert isinstance(usage_chunk, AIMessageChunk)
+    assert usage_chunk.content == ""
+    assert "input_tokens" in usage_chunk.usage_metadata
+    assert "output_tokens" in usage_chunk.usage_metadata
+    assert "total_tokens" in usage_chunk.usage_metadata
+
+    # Verify token counts are positive
+    assert usage_chunk.usage_metadata["input_tokens"] > 0, (
+        f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}"
+    )
+    assert usage_chunk.usage_metadata["output_tokens"] > 0, (
+        f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}"
+    )
+
+    # Verify total_tokens equals sum of input and output
+    expected_total = (
+        usage_chunk.usage_metadata["input_tokens"] + usage_chunk.usage_metadata["output_tokens"]
+    )
+    assert usage_chunk.usage_metadata["total_tokens"] == expected_total, (
+        f"Expected total_tokens ({usage_chunk.usage_metadata['total_tokens']}) "
+        f"to equal input_tokens + output_tokens ({expected_total})"
+    )
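
From the caller's side, the behavior this test pins down looks roughly like the sketch below: a plain streaming loop in which the final chunk arrives with empty content and usage_metadata attached. It assumes a reachable serving endpoint named "gpt-5" and default Databricks workspace authentication; the endpoint name is taken from the test above and may differ in other workspaces.

# Sketch of consuming streamed token usage (assumes a reachable "gpt-5"
# endpoint and default Databricks auth; not part of the commit itself).
from databricks_langchain import ChatDatabricks

llm = ChatDatabricks(endpoint="gpt-5", max_tokens=100, stream_usage=True)

total_tokens = None
for chunk in llm.stream("hello"):
    if chunk.content:
        print(chunk.content, end="", flush=True)
    elif chunk.usage_metadata is not None:
        # The final usage-only chunk: empty content, token counts attached.
        total_tokens = chunk.usage_metadata["total_tokens"]

print(f"\ntotal tokens: {total_tokens}")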

integrations/langchain/tests/unit_tests/test_chat_models.py

Lines changed: 187 additions & 0 deletions
@@ -350,6 +350,193 @@ def test_chat_model_stream_no_duplicate_usage_chunks():
     assert len(usage_chunks) == 1, f"Expected exactly 1 usage chunk, got {len(usage_chunks)}"
 
 
+def test_chat_model_stream_usage_only_final_chunk():
+    """Test that a final chunk with only usage data (no choices) correctly emits usage metadata."""
+    from unittest.mock import Mock, patch
+
+    mock_usage = Mock()
+    mock_usage.prompt_tokens = 15
+    mock_usage.completion_tokens = 10
+
+    # Simulate GPT-5 streaming behavior: content chunks followed by usage-only chunk
+    mock_chunks = [
+        Mock(
+            choices=[
+                Mock(
+                    delta=Mock(
+                        role="assistant",
+                        content="Hello",
+                        model_dump=Mock(return_value={"role": "assistant", "content": "Hello"}),
+                    ),
+                    finish_reason=None,
+                    logprobs=None,
+                )
+            ],
+            usage=None,
+        ),
+        Mock(
+            choices=[
+                Mock(
+                    delta=Mock(
+                        role="assistant",
+                        content=" world",
+                        model_dump=Mock(return_value={"role": "assistant", "content": " world"}),
+                    ),
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+            usage=None,
+        ),
+        # Final chunk with ONLY usage data, no choices/delta
+        Mock(
+            choices=[],
+            usage=mock_usage,
+        ),
+    ]
+
+    # Verify mock structure matches GPT-5 behavior
+    # Final chunk has empty choices list and usage data (no delta)
+    assert len(mock_chunks[2].choices) == 0
+    assert mock_chunks[2].usage is not None
+
+    with patch("databricks_langchain.chat_models.get_openai_client") as mock_get_client:
+        mock_client = Mock()
+        mock_get_client.return_value = mock_client
+        mock_client.chat.completions.create.return_value = iter(mock_chunks)
+
+        llm = ChatDatabricks(model="test-model")
+        messages = [HumanMessage(content="Hello")]
+
+        chunks = list(llm.stream(messages, stream_usage=True))
+
+        # Should get content chunks plus one usage chunk
+        content_chunks = [chunk for chunk in chunks if chunk.content != ""]
+        assert len(content_chunks) == 2
+        assert content_chunks[0].content == "Hello"
+        assert content_chunks[1].content == " world"
+
+        # Should emit exactly ONE usage chunk
+        usage_chunks = [
+            chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
+        ]
+        assert len(usage_chunks) == 1, f"Expected exactly 1 usage chunk, got {len(usage_chunks)}"
+
+        # Verify usage chunk has correct metadata
+        usage_chunk = usage_chunks[0]
+        assert isinstance(usage_chunk, AIMessageChunk)
+        assert usage_chunk.content == ""
+        assert usage_chunk.usage_metadata["input_tokens"] == 15
+        assert usage_chunk.usage_metadata["output_tokens"] == 10
+        assert usage_chunk.usage_metadata["total_tokens"] == 25
+
+
+def test_chat_model_stream_usage_only_chunk_missing_tokens():
+    """Test that a usage-only chunk with missing token data doesn't emit usage metadata."""
+    from unittest.mock import Mock, patch
+
+    mock_usage = Mock()
+    mock_usage.prompt_tokens = None  # Missing prompt_tokens
+    mock_usage.completion_tokens = 10
+
+    mock_chunks = [
+        Mock(
+            choices=[
+                Mock(
+                    delta=Mock(
+                        role="assistant",
+                        content="Hello",
+                        model_dump=Mock(return_value={"role": "assistant", "content": "Hello"}),
+                    ),
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+            usage=None,
+        ),
+        # Final chunk with usage data but missing prompt_tokens
+        Mock(
+            choices=[],
+            usage=mock_usage,
+        ),
+    ]
+
+    with patch("databricks_langchain.chat_models.get_openai_client") as mock_get_client:
+        mock_client = Mock()
+        mock_get_client.return_value = mock_client
+        mock_client.chat.completions.create.return_value = iter(mock_chunks)
+
+        llm = ChatDatabricks(model="test-model")
+        messages = [HumanMessage(content="Hello")]
+
+        chunks = list(llm.stream(messages, stream_usage=True))
+
+        # Should get content chunks but NO usage chunk (due to missing tokens)
+        content_chunks = [chunk for chunk in chunks if chunk.content != ""]
+        assert len(content_chunks) == 1
+
+        # Should NOT emit a usage chunk when tokens are missing
+        usage_chunks = [
+            chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
+        ]
+        assert len(usage_chunks) == 0, (
+            f"Expected 0 usage chunks when tokens are missing, got {len(usage_chunks)}"
+        )
+
+
+def test_chat_model_stream_usage_only_chunk_stream_usage_false():
+    """Test that a usage-only chunk is ignored when stream_usage=False."""
+    from unittest.mock import Mock, patch
+
+    mock_usage = Mock()
+    mock_usage.prompt_tokens = 15
+    mock_usage.completion_tokens = 10
+
+    mock_chunks = [
+        Mock(
+            choices=[
+                Mock(
+                    delta=Mock(
+                        role="assistant",
+                        content="Hello",
+                        model_dump=Mock(return_value={"role": "assistant", "content": "Hello"}),
+                    ),
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+            usage=None,
+        ),
+        # Final chunk with usage data
+        Mock(
+            choices=[],
+            usage=mock_usage,
+        ),
+    ]
+
+    with patch("databricks_langchain.chat_models.get_openai_client") as mock_get_client:
+        mock_client = Mock()
+        mock_get_client.return_value = mock_client
+        mock_client.chat.completions.create.return_value = iter(mock_chunks)
+
+        llm = ChatDatabricks(model="test-model")
+        messages = [HumanMessage(content="Hello")]
+
+        chunks = list(llm.stream(messages, stream_usage=False))
+
+        # Should get content chunks only
+        content_chunks = [chunk for chunk in chunks if chunk.content != ""]
+        assert len(content_chunks) == 1
+
+        # Should NOT emit a usage chunk when stream_usage=False
+        usage_chunks = [
+            chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None
+        ]
+        assert len(usage_chunks) == 0, (
+            f"Expected 0 usage chunks when stream_usage=False, got {len(usage_chunks)}"
+        )
+
+
 class GetWeather(BaseModel):
     """Get the current weather in a given location"""
