22import logging
33from statistics import median
44
5+ # Use Q objects to express OR conditions in Django queries
56from django .db .models import Q
67from pgvector .django import L2Distance
78
1112
1213logger = logging .getLogger (__name__ )
1314
14- def get_closest_embeddings (
15- user , message_data , document_name = None , guid = None , num_results = 10
16- ):
15+
def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
    """
    Build an unevaluated QuerySet for the closest embeddings.

    Parameters
    ----------
    user : User or None
        The user whose uploaded documents will be searched. ``None`` is
        treated the same as an unauthenticated user.
    embedding_vector : array-like
        Pre-computed embedding vector to compare against
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    QuerySet
        Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
    """
    # Everyone can see files uploaded by superusers.
    visibility = Q(upload_file__uploaded_by__is_superuser=True)

    # Authenticated users additionally see their own files. The explicit None
    # check keeps a missing user from raising AttributeError; it is handled
    # like an unauthenticated user instead.
    if user is not None and user.is_authenticated:
        visibility = Q(upload_file__uploaded_by=user) | visibility

    # Django QuerySets are lazily evaluated: nothing below hits the database
    queryset = (
        Embeddings.objects.filter(visibility)
        .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
        .order_by("distance")
    )

    # Filtering to a document GUID takes precedence over a document name
    if guid:
        queryset = queryset.filter(upload_file__guid=guid)
    elif document_name:
        queryset = queryset.filter(name=document_name)

    # Slicing is equivalent to SQL's LIMIT clause
    return queryset[:num_results]
62+
63+
64+ def evaluate_query (queryset ):
65+ """
66+ Evaluate a QuerySet and return a list of result dicts.
67+
68+ Parameters
69+ ----------
70+ queryset : iterable
71+ Iterable of Embeddings objects (or any objects with the expected attributes)
8472
73+ Returns
74+ -------
75+ list[dict]
76+ List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
77+ """
8578 # Iterating evaluates the QuerySet and hits the database
8679 # TODO: Research improving the query evaluation performance
87- results = [
80+ return [
8881 {
8982 "name" : obj .name ,
9083 "text" : obj .text ,
@@ -93,13 +86,36 @@ def get_closest_embeddings(
9386 "distance" : obj .distance ,
9487 "file_id" : obj .upload_file .guid if obj .upload_file else None ,
9588 }
96- for obj in closest_embeddings_query
89+ for obj in queryset
9790 ]
9891
99- db_query_time = time .time () - db_query_start
10092
93+ def log_usage (
94+ results , message_data , user , guid , document_name , num_results , encoding_time , db_query_time
95+ ):
96+ """
97+ Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
98+
99+ Parameters
100+ ----------
101+ results : list[dict]
102+ The search results, each containing a "distance" key
103+ message_data : str
104+ The original search query text
105+ user : User
106+ The user who performed the search
107+ guid : str or None
108+ Document GUID filter used in the search
109+ document_name : str or None
110+ Document name filter used in the search
111+ num_results : int
112+ Number of results requested
113+ encoding_time : float
114+ Time in seconds to encode the query
115+ db_query_time : float
116+ Time in seconds for the database query
117+ """
101118 try :
102- # Handle user having no uploaded docs or doc filtering returning no matches
103119 if results :
104120 distances = [r ["distance" ] for r in results ]
105121 SemanticSearchUsage .objects .create (
@@ -113,11 +129,10 @@ def get_closest_embeddings(
113129 num_results_returned = len (results ),
114130 max_distance = max (distances ),
115131 median_distance = median (distances ),
116- min_distance = min (distances )
132+ min_distance = min (distances ),
117133 )
118134 else :
119135 logger .warning ("Semantic search returned no results" )
120-
121136 SemanticSearchUsage .objects .create (
122137 query_text = message_data ,
123138 user = user if (user and user .is_authenticated ) else None ,
@@ -129,9 +144,58 @@ def get_closest_embeddings(
129144 num_results_returned = 0 ,
130145 max_distance = None ,
131146 median_distance = None ,
132- min_distance = None
147+ min_distance = None ,
133148 )
134- except Exception as e :
135- logger .error (f"Failed to create semantic search usage database record: { e } " )
149+ except Exception :
150+ logger .exception ("Failed to create semantic search usage database record" )
151+
152+
def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Find the closest embeddings to a given message for a specific user.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Creates a SemanticSearchUsage record via ``log_usage``. Failures while
    writing that record are logged and swallowed so search isn't interrupted;
    errors raised while encoding the message or querying the database are not
    caught here and propagate to the caller.
    """
    # perf_counter is monotonic, so the recorded durations cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-request.
    encoding_start = time.perf_counter()
    model = TransformerModel.get_instance().model
    embedding_vector = model.encode(message_data)
    encoding_time = time.perf_counter() - encoding_start

    # Building the QuerySet is lazy; evaluate_query is what hits the database,
    # so both steps are timed together as the database query time.
    db_query_start = time.perf_counter()
    queryset = build_query(
        user,
        embedding_vector,
        document_name=document_name,
        guid=guid,
        num_results=num_results,
    )
    results = evaluate_query(queryset)
    db_query_time = time.perf_counter() - db_query_start

    log_usage(
        results=results,
        message_data=message_data,
        user=user,
        guid=guid,
        document_name=document_name,
        num_results=num_results,
        encoding_time=encoding_time,
        db_query_time=db_query_time,
    )

    return results
0 commit comments