Skip to content

Commit c1c0084

Browse files
committed
Implements referenced documents in the /query endpoint response
1 parent 2cc494c commit c1c0084

3 files changed

Lines changed: 220 additions & 25 deletions

File tree

src/app/endpoints/query.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Handler for REST API call to provide answer to query."""
22

3+
import ast
34
from datetime import datetime, UTC
45
import json
56
import logging
67
import os
78
from pathlib import Path
9+
import re
810
from typing import Annotated, Any
911

1012
from llama_stack_client import APIConnectionError
@@ -41,6 +43,8 @@
4143
router = APIRouter(tags=["query"])
4244
auth_dependency = get_auth_dependency()
4345

46+
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
47+
4448
query_response: dict[int | str, dict[str, Any]] = {
4549
200: {
4650
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
@@ -189,7 +193,7 @@ async def query_endpoint_handler(
189193
user_conversation=user_conversation, query_request=query_request
190194
),
191195
)
192-
response, conversation_id = await retrieve_response(
196+
response, conversation_id, referenced_documents = await retrieve_response(
193197
client,
194198
llama_stack_model_id,
195199
query_request,
@@ -223,7 +227,11 @@ async def query_endpoint_handler(
223227
provider_id=provider_id,
224228
)
225229

226-
return QueryResponse(conversation_id=conversation_id, response=response)
230+
return QueryResponse(
231+
conversation_id=conversation_id,
232+
response=response,
233+
referenced_documents=referenced_documents
234+
)
227235

228236
# connection to Llama Stack server
229237
except APIConnectionError as e:
@@ -322,7 +330,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
322330
query_request: QueryRequest,
323331
token: str,
324332
mcp_headers: dict[str, dict[str, str]] | None = None,
325-
) -> tuple[str, str]:
333+
) -> tuple[str, str, list[dict[str, str]]]:
326334
"""Retrieve response from LLMs and agents."""
327335
available_input_shields = [
328336
shield.identifier
@@ -402,15 +410,42 @@ async def retrieve_response( # pylint: disable=too-many-locals
402410
toolgroups=toolgroups,
403411
)
404412

405-
# Check for validation errors in the response
413+
# Collect metadata from tool responses to extract referenced documents
414+
metadata_map: dict[str, dict[str, Any]] = {}
406415
steps = getattr(response, "steps", [])
407416
for step in steps:
408417
if step.step_type == "shield_call" and step.violation:
409418
# Metric for LLM validation errors
410419
metrics.llm_calls_validation_errors_total.inc()
411-
break
420+
elif step.step_type == "tool_execution" and hasattr(step, "tool_responses"):
421+
for tool_response in step.tool_responses:
422+
if tool_response.tool_name == "knowledge_search" and tool_response.content:
423+
for text_content_item in tool_response.content:
424+
if hasattr(text_content_item, 'text'):
425+
for match in METADATA_PATTERN.findall(text_content_item.text):
426+
try:
427+
meta = ast.literal_eval(match)
428+
if "document_id" in meta:
429+
metadata_map[meta["document_id"]] = meta
430+
except Exception: # pylint: disable=broad-except
431+
logger.debug(
432+
"An exception was thrown in processing %s",
433+
match,
434+
)
435+
436+
# Extract referenced documents from metadata
437+
referenced_documents = [
438+
{
439+
"doc_url": v["docs_url"],
440+
"doc_title": v["title"],
441+
}
442+
for v in filter(
443+
lambda v: ("docs_url" in v) and ("title" in v),
444+
metadata_map.values(),
445+
)
446+
]
412447

413-
return str(response.output_message.content), conversation_id # type: ignore[union-attr]
448+
return str(response.output_message.content), conversation_id, referenced_documents # type: ignore[union-attr]
414449

415450

416451
def validate_attachments_metadata(attachments: list[Attachment]) -> None:

src/models/responses.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class ModelsResponse(BaseModel):
3636

3737
# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
3838
# we are keeping it simple. The missing fields are:
39-
# - referenced_documents: The optional URLs and titles for the documents used
40-
# to generate the response.
4139
# - truncated: Set to True if conversation history was truncated to be within context window.
4240
# - input_tokens: Number of tokens sent to LLM
4341
# - output_tokens: Number of tokens received from LLM
@@ -51,6 +49,8 @@ class QueryResponse(BaseModel):
5149
Attributes:
5250
conversation_id: The optional conversation ID (UUID).
5351
response: The response.
52+
referenced_documents: The optional URLs and titles for the documents used
53+
to generate the response.
5454
"""
5555

5656
conversation_id: Optional[str] = Field(
@@ -65,6 +65,19 @@ class QueryResponse(BaseModel):
6565
"Kubernetes is an open-source container orchestration system for automating ..."
6666
],
6767
)
68+
69+
referenced_documents: list[dict[str, str]] = Field(
70+
default_factory=list,
71+
description="List of documents referenced in generating the response",
72+
examples=[
73+
[
74+
{
75+
"doc_url": "https://docs.openshift.com/container-platform/4.15/operators/olm/index.html",
76+
"doc_title": "Operator Lifecycle Manager (OLM)"
77+
}
78+
]
79+
],
80+
)
6881

6982
# provides examples for /docs endpoint
7083
model_config = {
@@ -73,6 +86,12 @@ class QueryResponse(BaseModel):
7386
{
7487
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
7588
"response": "Operator Lifecycle Manager (OLM) helps users install...",
89+
"referenced_documents": [
90+
{
91+
"doc_url": "https://docs.openshift.com/container-platform/4.15/operators/olm/index.html",
92+
"doc_title": "Operator Lifecycle Manager (OLM)"
93+
}
94+
]
7695
}
7796
]
7897
}

0 commit comments

Comments
 (0)