|
1 | 1 | """Handler for REST API call to provide answer to query.""" |
2 | 2 |
|
| 3 | +import ast |
3 | 4 | from datetime import datetime, UTC |
4 | 5 | import json |
5 | 6 | import logging |
6 | 7 | import os |
7 | 8 | from pathlib import Path |
| 9 | +import re |
8 | 10 | from typing import Annotated, Any |
9 | 11 |
|
10 | 12 | from llama_stack_client import APIConnectionError |
|
41 | 43 | router = APIRouter(tags=["query"]) |
42 | 44 | auth_dependency = get_auth_dependency() |
43 | 45 |
|
| 46 | +METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") |
| 47 | + |
44 | 48 | query_response: dict[int | str, dict[str, Any]] = { |
45 | 49 | 200: { |
46 | 50 | "conversation_id": "123e4567-e89b-12d3-a456-426614174000", |
@@ -189,7 +193,7 @@ async def query_endpoint_handler( |
189 | 193 | user_conversation=user_conversation, query_request=query_request |
190 | 194 | ), |
191 | 195 | ) |
192 | | - response, conversation_id = await retrieve_response( |
| 196 | + response, conversation_id, referenced_documents = await retrieve_response( |
193 | 197 | client, |
194 | 198 | llama_stack_model_id, |
195 | 199 | query_request, |
@@ -223,7 +227,11 @@ async def query_endpoint_handler( |
223 | 227 | provider_id=provider_id, |
224 | 228 | ) |
225 | 229 |
|
226 | | - return QueryResponse(conversation_id=conversation_id, response=response) |
| 230 | + return QueryResponse( |
| 231 | + conversation_id=conversation_id, |
| 232 | + response=response, |
| 233 | + referenced_documents=referenced_documents |
| 234 | + ) |
227 | 235 |
|
228 | 236 | # connection to Llama Stack server |
229 | 237 | except APIConnectionError as e: |
@@ -322,7 +330,7 @@ async def retrieve_response( # pylint: disable=too-many-locals |
322 | 330 | query_request: QueryRequest, |
323 | 331 | token: str, |
324 | 332 | mcp_headers: dict[str, dict[str, str]] | None = None, |
325 | | -) -> tuple[str, str]: |
| 333 | +) -> tuple[str, str, list[dict[str, str]]]: |
326 | 334 | """Retrieve response from LLMs and agents.""" |
327 | 335 | available_input_shields = [ |
328 | 336 | shield.identifier |
@@ -402,15 +410,42 @@ async def retrieve_response( # pylint: disable=too-many-locals |
402 | 410 | toolgroups=toolgroups, |
403 | 411 | ) |
404 | 412 |
|
405 | | - # Check for validation errors in the response |
| 413 | + # Collect metadata from tool responses to extract referenced documents |
| 414 | + metadata_map: dict[str, dict[str, Any]] = {} |
406 | 415 | steps = getattr(response, "steps", []) |
407 | 416 | for step in steps: |
408 | 417 | if step.step_type == "shield_call" and step.violation: |
409 | 418 | # Metric for LLM validation errors |
410 | 419 | metrics.llm_calls_validation_errors_total.inc() |
411 | | - break |
| 420 | + elif step.step_type == "tool_execution" and hasattr(step, "tool_responses"): |
| 421 | + for tool_response in step.tool_responses: |
| 422 | + if tool_response.tool_name == "knowledge_search" and tool_response.content: |
| 423 | + for text_content_item in tool_response.content: |
| 424 | + if hasattr(text_content_item, 'text'): |
| 425 | + for match in METADATA_PATTERN.findall(text_content_item.text): |
| 426 | + try: |
| 427 | + meta = ast.literal_eval(match) |
| 428 | + if "document_id" in meta: |
| 429 | + metadata_map[meta["document_id"]] = meta |
| 430 | + except Exception: # pylint: disable=broad-except |
| 431 | + logger.debug( |
| 432 | + "An exception was thrown in processing %s", |
| 433 | + match, |
| 434 | + ) |
| 435 | + |
| 436 | + # Extract referenced documents from metadata |
| 437 | + referenced_documents = [ |
| 438 | + { |
| 439 | + "doc_url": v["docs_url"], |
| 440 | + "doc_title": v["title"], |
| 441 | + } |
| 442 | + for v in filter( |
| 443 | + lambda v: ("docs_url" in v) and ("title" in v), |
| 444 | + metadata_map.values(), |
| 445 | + ) |
| 446 | + ] |
412 | 447 |
|
413 | | - return str(response.output_message.content), conversation_id # type: ignore[union-attr] |
| 448 | + return str(response.output_message.content), conversation_id, referenced_documents # type: ignore[union-attr] |
414 | 449 |
|
415 | 450 |
|
416 | 451 | def validate_attachments_metadata(attachments: list[Attachment]) -> None: |
|
0 commit comments