Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions backend/examples/cli_research.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ def main() -> None:
messages = result.get("messages", [])
if messages:
print(messages[-1].content)

token_records = result.get("token_usage_records", [])
if token_records:
print("\n" + "=" * 80)
print("TOKEN USAGE SUMMARY")
print("=" * 80)

total_input = 0
total_output = 0

for record in token_records:
print(f"\n{record['node_name'].upper():<20} ({record['model']})")
print(f" Input tokens: {record['input_tokens']:,}")
print(f" Output tokens: {record['output_tokens']:,}")
total_input += record['input_tokens']
total_output += record['output_tokens']

print("\n" + "-" * 80)
print(f"{'TOTAL':<20}")
print(f" Input tokens: {total_input:,}")
print(f" Output tokens: {total_output:,}")
print(f" Total tokens: {(total_input + total_output):,}")
print("=" * 80)


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ dev = ["mypy>=1.11.1", "ruff>=0.6.1"]
requires = ["setuptools>=73.0.0", "wheel"]
build-backend = "setuptools.build_meta"

# Configure setuptools to find packages in src/ directory to avoid import conflicts with site-packages
# Fixes: ImportError: cannot import name 'extract_token_usage_from_langchain' from 'agent.utils'
[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[tool.ruff]
lint.select = [
"E", # pycodestyle
Expand Down
7 changes: 7 additions & 0 deletions backend/src/agent/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class Configuration(BaseModel):
metadata={"description": "The maximum number of research loops to perform."},
)

track_token_usage: bool = Field(
default=True,
metadata={
"description": "Enable token usage tracking for cost monitoring and optimization."
},
)

@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
Expand Down
62 changes: 58 additions & 4 deletions backend/src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
get_research_topic,
insert_citation_markers,
resolve_urls,
extract_token_usage_from_langchain,
extract_token_usage_from_genai_client,
)

load_dotenv()
Expand Down Expand Up @@ -78,7 +80,20 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerati
)
# Generate the search queries
result = structured_llm.invoke(formatted_prompt)
return {"search_query": result.query}

update = {"search_query": result.query}
if configurable.track_token_usage:
token_usage = extract_token_usage_from_langchain(result)
update["token_usage_records"] = [
{
"node_name": "generate_query",
"input_tokens": token_usage["input_tokens"],
"output_tokens": token_usage["output_tokens"],
"model": configurable.query_generator_model,
}
]

Comment thread
AbdulTawabJuly marked this conversation as resolved.
return update


def continue_to_web_research(state: QueryGenerationState):
Expand Down Expand Up @@ -129,11 +144,24 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
modified_text = insert_citation_markers(response.text, citations)
sources_gathered = [item for citation in citations for item in citation["segments"]]

return {
update = {
"sources_gathered": sources_gathered,
"search_query": [state["search_query"]],
"web_research_result": [modified_text],
}

if configurable.track_token_usage:
token_usage = extract_token_usage_from_genai_client(response)
update["token_usage_records"] = [
{
"node_name": "web_research",
"input_tokens": token_usage["input_tokens"],
"output_tokens": token_usage["output_tokens"],
"model": configurable.query_generator_model,
}
]

return update


def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
Expand Down Expand Up @@ -171,13 +199,26 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
)
result = llm.with_structured_output(Reflection).invoke(formatted_prompt)

return {
update = {
"is_sufficient": result.is_sufficient,
"knowledge_gap": result.knowledge_gap,
"follow_up_queries": result.follow_up_queries,
"research_loop_count": state["research_loop_count"],
"number_of_ran_queries": len(state["search_query"]),
}

if configurable.track_token_usage:
token_usage = extract_token_usage_from_langchain(result)
update["token_usage_records"] = [
{
"node_name": "reflection",
"input_tokens": token_usage["input_tokens"],
"output_tokens": token_usage["output_tokens"],
"model": reasoning_model,
}
]

return update


def evaluate_research(
Expand Down Expand Up @@ -259,10 +300,23 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
)
unique_sources.append(source)

return {
update = {
"messages": [AIMessage(content=result.content)],
"sources_gathered": unique_sources,
}

if configurable.track_token_usage:
token_usage = extract_token_usage_from_langchain(result)
update["token_usage_records"] = [
{
"node_name": "finalize_answer",
"input_tokens": token_usage["input_tokens"],
"output_tokens": token_usage["output_tokens"],
"model": reasoning_model,
}
]

return update


# Create our Agent Graph
Expand Down
10 changes: 10 additions & 0 deletions backend/src/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@
import operator


class TokenUsageRecord(TypedDict):
    """Record of token usage for a single node execution.

    One record is emitted per graph-node invocation and accumulated into
    ``OverallState.token_usage_records`` (reducer: ``operator.add``), so the
    final state holds the full per-node usage history for a run.
    """

    # Name of the graph node that made the LLM call
    # (e.g. "generate_query", "web_research", "reflection", "finalize_answer").
    node_name: str
    # Number of prompt tokens sent to the model for this call.
    input_tokens: int
    # Number of completion tokens generated by the model for this call.
    output_tokens: int
    # Identifier of the model that served the call.
    model: str


class OverallState(TypedDict):
messages: Annotated[list, add_messages]
search_query: Annotated[list, operator.add]
web_research_result: Annotated[list, operator.add]
sources_gathered: Annotated[list, operator.add]
token_usage_records: Annotated[list, operator.add]
initial_search_query_count: int
max_research_loops: int
research_loop_count: int
Expand Down
37 changes: 37 additions & 0 deletions backend/src/agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,43 @@
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage


def extract_token_usage_from_langchain(response: Any) -> Dict[str, int]:
    """
    Extract token usage from a LangChain ChatGoogleGenerativeAI response.

    Checks two locations, in order:

    1. The provider-agnostic ``usage_metadata`` attribute that LangChain
       attaches to AI messages (keys ``input_tokens`` / ``output_tokens``).
    2. The Gemini-specific ``response_metadata["usage_metadata"]`` payload
       (keys ``prompt_token_count`` / ``candidates_token_count``).

    Args:
        response: The response object from LangChain's ChatGoogleGenerativeAI

    Returns:
        Dictionary with 'input_tokens' and 'output_tokens' keys; both values
        are 0 when no usage information can be found on the response.
    """
    # Preferred: LangChain's standardized AIMessage.usage_metadata dict.
    usage = getattr(response, "usage_metadata", None)
    if isinstance(usage, dict) and ("input_tokens" in usage or "output_tokens" in usage):
        # `or 0` guards against explicit None values in the payload.
        return {
            "input_tokens": int(usage.get("input_tokens") or 0),
            "output_tokens": int(usage.get("output_tokens") or 0),
        }

    # Fallback: Gemini-specific counters nested under response_metadata.
    # isinstance guard: a non-dict response_metadata must not raise here.
    metadata = getattr(response, "response_metadata", None)
    if isinstance(metadata, dict):
        usage = metadata.get("usage_metadata", {})
        if isinstance(usage, dict):
            return {
                "input_tokens": int(usage.get("prompt_token_count") or 0),
                "output_tokens": int(usage.get("candidates_token_count") or 0),
            }

    # No usage information anywhere on the response.
    return {"input_tokens": 0, "output_tokens": 0}


def extract_token_usage_from_genai_client(response: Any) -> Dict[str, int]:
    """
    Extract token usage from native google.genai.Client response.

    Args:
        response: The response object from google.genai.Client

    Returns:
        Dictionary with 'input_tokens' and 'output_tokens' keys; both values
        are 0 when usage information is missing or only partially reported.
    """
    # usage_metadata may be absent or None, and its counter attributes may
    # themselves be None (the API omits candidate counts on some responses),
    # so every access is defensive to guarantee an int result.
    usage = getattr(response, "usage_metadata", None)
    if usage is not None:
        return {
            "input_tokens": int(getattr(usage, "prompt_token_count", 0) or 0),
            "output_tokens": int(getattr(usage, "candidates_token_count", 0) or 0),
        }
    return {"input_tokens": 0, "output_tokens": 0}


def get_research_topic(messages: List[AnyMessage]) -> str:
"""
Get the research topic from the messages.
Expand Down
27 changes: 26 additions & 1 deletion frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useStream } from "@langchain/langgraph-sdk/react";
import type { Message } from "@langchain/langgraph-sdk";
import { useState, useEffect, useRef, useCallback } from "react";
import { ProcessedEvent } from "@/components/ActivityTimeline";
import { TokenUsageRecord } from "@/components/TokenUsageDisplay";
import { WelcomeScreen } from "@/components/WelcomeScreen";
import { ChatMessagesView } from "@/components/ChatMessagesView";
import { Button } from "@/components/ui/button";
Expand All @@ -13,6 +14,12 @@ export default function App() {
const [historicalActivities, setHistoricalActivities] = useState<
Record<string, ProcessedEvent[]>
>({});
const [tokenUsageTimeline, setTokenUsageTimeline] = useState<
TokenUsageRecord[]
>([]);
const [historicalTokenUsage, setHistoricalTokenUsage] = useState<
Record<string, TokenUsageRecord[]>
>({});
const scrollAreaRef = useRef<HTMLDivElement>(null);
const hasFinalizeEventOccurredRef = useRef(false);
const [error, setError] = useState<string | null>(null);
Expand All @@ -21,6 +28,7 @@ export default function App() {
initial_search_query_count: number;
max_research_loops: number;
reasoning_model: string;
token_usage_records?: TokenUsageRecord[];
}>({
apiUrl: import.meta.env.DEV
? "http://localhost:2024"
Expand Down Expand Up @@ -65,6 +73,16 @@ export default function App() {
processedEvent!,
]);
}

const nodeNames = ['generate_query', 'web_research', 'reflection', 'finalize_answer'];
Comment thread
AbdulTawabJuly marked this conversation as resolved.
Outdated
for (const nodeName of nodeNames) {
if (event[nodeName]?.token_usage_records) {
const newRecords = event[nodeName].token_usage_records;
if (Array.isArray(newRecords) && newRecords.length > 0) {
setTokenUsageTimeline((prev) => [...prev, ...newRecords]);
}
}
}
},
onError: (error: any) => {
setError(error.message);
Expand Down Expand Up @@ -94,15 +112,20 @@ export default function App() {
...prev,
[lastMessage.id!]: [...processedEventsTimeline],
}));
setHistoricalTokenUsage((prev) => ({
...prev,
[lastMessage.id!]: [...tokenUsageTimeline],
}));
}
hasFinalizeEventOccurredRef.current = false;
}
}, [thread.messages, thread.isLoading, processedEventsTimeline]);
}, [thread.messages, thread.isLoading, processedEventsTimeline, tokenUsageTimeline]);

const handleSubmit = useCallback(
(submittedInputValue: string, effort: string, model: string) => {
if (!submittedInputValue.trim()) return;
setProcessedEventsTimeline([]);
setTokenUsageTimeline([]);
hasFinalizeEventOccurredRef.current = false;

// convert effort to, initial_search_query_count and max_research_loops
Expand Down Expand Up @@ -181,6 +204,8 @@ export default function App() {
onCancel={handleCancel}
liveActivityEvents={processedEventsTimeline}
historicalActivities={historicalActivities}
liveTokenUsage={tokenUsageTimeline}
historicalTokenUsage={historicalTokenUsage}
/>
)}
</main>
Expand Down
Loading