11from unittest .mock import MagicMock , patch
22
33from django .db .models import Q
4+ from pgvector .django import L2Distance
45
5- from api .services .embedding_services import build_query , evaluate_query , log_usage
6+ from api .services .embedding_services import (
7+ build_query ,
8+ evaluate_query ,
9+ get_closest_embeddings ,
10+ log_usage ,
11+ )
612
713# ---------------------------------------------------------------------------
814# build_query tests
9- #
10- # build_query only constructs a lazy Django QuerySet — it never evaluates it
11- # (no iteration, .get(), .exists(), etc.), so no database is needed.
12- #
13- # We patch Embeddings.objects so every chained ORM call (.filter, .annotate,
14- # .order_by, __getitem__) returns a MagicMock instead of hitting the DB.
15- # All assertions inspect which methods were called with which arguments.
1615# ---------------------------------------------------------------------------
1716
17+ # All assertions inspect which methods and arguments were called on Embeddings.objects
18+
1819# Only forwarded to L2Distance
1920EMBEDDING_VECTOR = [0.1 , 0.2 , 0.3 ]
2021
@@ -48,12 +49,6 @@ def test_build_query_unauthenticated_uses_superuser_only_filter(mock_objects):
4849
4950# Test application of annotate and order_by
5051
51- # TODO: Strengthen test_build_query_annotates_and_orders_by_distance to also
52- # assert the *arguments* to annotate — specifically that it receives
53- # distance=L2Distance("embedding_sentence_transformers", EMBEDDING_VECTOR).
54- # Currently only the call count is checked, so a wrong field name or a
55- # dropped vector would go undetected.
56-
5752@patch ("api.services.embedding_services.Embeddings.objects" )
5853def test_build_query_annotates_and_orders_by_distance (mock_objects ):
5954 # Regardless of other arguments, annotate(distance=L2Distance(...)) and
@@ -67,6 +62,12 @@ def test_build_query_annotates_and_orders_by_distance(mock_objects):
6762 filtered_qs .annotate .assert_called_once ()
6863 filtered_qs .annotate .return_value .order_by .assert_called_once_with ("distance" )
6964
65+ # L2Distance is a Django Func subclass, which implements __eq__ by comparing
66+ # class and source expressions — so we can assert the exact field name and
67+ # vector without patching L2Distance itself.
68+ actual_distance_expr = filtered_qs .annotate .call_args .kwargs ["distance" ]
69+ assert actual_distance_expr == L2Distance ("embedding_sentence_transformers" , EMBEDDING_VECTOR )
70+
7071# Test guid-over-document precedence logic
7172
7273@patch ("api.services.embedding_services.Embeddings.objects" )
@@ -165,7 +166,10 @@ def test_build_query_returns_unevaluated_queryset(mock_objects):
165166# evaluate_query tests
166167# ---------------------------------------------------------------------------
167168
168- # TODO: Add test for empty queryset — evaluate_query([]) should return [].
169+ def test_evaluate_query_empty_queryset ():
170+ # An empty iterable should return an empty list, not raise an exception.
171+ assert evaluate_query ([]) == []
172+
169173
170174def test_evaluate_query_maps_fields ():
171175 # Verify that each Embeddings model attribute is mapped to the correct
@@ -193,8 +197,8 @@ def test_evaluate_query_maps_fields():
193197
194198
195199def test_evaluate_query_none_upload_file ():
196- # When upload_file is None (e.g. the FK was deleted) , file_id must be None
197- # rather than raising an AttributeError on None.guid.
200+ # When upload_file is None, file_id must be None rather than raising
201+ # an AttributeError on None.guid.
198202 obj = MagicMock ()
199203 obj .name = "doc.pdf"
200204 obj .text = "some text"
@@ -211,17 +215,71 @@ def test_evaluate_query_none_upload_file():
211215# log_usage tests
212216# ---------------------------------------------------------------------------
213217
214- # TODO: Add test for empty results list — log_usage([]) hits the else branch and
215- # should call SemanticSearchUsage.objects.create with num_results_returned=0
216- # and max_distance=None, median_distance=None, min_distance=None.
218+ @patch ("api.services.embedding_services.SemanticSearchUsage.objects.create" )
219+ def test_log_usage_empty_results (mock_create ):
220+ # Empty results hits the else branch. The record should still be created
221+ # with num_results_returned=0 and all distance fields set to None.
222+ user = MagicMock (is_authenticated = True )
223+
224+ log_usage (
225+ [],
226+ message_data = "test query" ,
227+ user = user ,
228+ guid = None ,
229+ document_name = None ,
230+ num_results = 10 ,
231+ encoding_time = 0.1 ,
232+ db_query_time = 0.2 ,
233+ )
234+
235+ mock_create .assert_called_once ()
236+ kwargs = mock_create .call_args .kwargs
237+ assert kwargs ["num_results_returned" ] == 0
238+ assert kwargs ["max_distance" ] is None
239+ assert kwargs ["median_distance" ] is None
240+ assert kwargs ["min_distance" ] is None
241+
242+
243+ @patch ("api.services.embedding_services.SemanticSearchUsage.objects.create" )
244+ def test_log_usage_unauthenticated_user_stored_as_none (mock_create ):
245+ # An unauthenticated user should be stored as None in the DB record, not as
246+ # the user object itself, so the FK constraint is not violated.
247+ user = MagicMock (is_authenticated = False )
248+
249+ log_usage (
250+ [{"distance" : 1.0 }],
251+ message_data = "test query" ,
252+ user = user ,
253+ guid = None ,
254+ document_name = None ,
255+ num_results = 10 ,
256+ encoding_time = 0.1 ,
257+ db_query_time = 0.2 ,
258+ )
259+
260+ kwargs = mock_create .call_args .kwargs
261+ assert kwargs ["user" ] is None
262+
263+
264+ @patch ("api.services.embedding_services.SemanticSearchUsage.objects.create" )
265+ def test_log_usage_none_user_stored_as_none (mock_create ):
266+ # Passing user=None directly (e.g. from an anonymous request) should also
267+ # store None — the expression `user if (user and user.is_authenticated)`
268+ # short-circuits on the falsy None before accessing .is_authenticated.
269+ log_usage (
270+ [{"distance" : 1.0 }],
271+ message_data = "test query" ,
272+ user = None ,
273+ guid = None ,
274+ document_name = None ,
275+ num_results = 10 ,
276+ encoding_time = 0.1 ,
277+ db_query_time = 0.2 ,
278+ )
217279
218- # TODO: Add test for unauthenticated user — user.is_authenticated=False should
219- # result in user=None being stored in the SemanticSearchUsage record.
280+ kwargs = mock_create . call_args . kwargs
281+ assert kwargs [ " user" ] is None
220282
221- # TODO: Add test for user=None — passing None directly as the user argument
222- # should also store user=None (the expression `user if (user and
223- # user.is_authenticated) else None` handles both cases, but only the
224- # authenticated path is currently exercised).
225283
226284@patch ("api.services.embedding_services.SemanticSearchUsage.objects.create" )
227285def test_log_usage_computes_distance_stats (mock_create ):
@@ -276,8 +334,37 @@ def test_log_usage_swallows_exceptions(mock_create):
276334# get_closest_embeddings tests
277335# ---------------------------------------------------------------------------
278336
279- # TODO: Add smoke test for get_closest_embeddings verifying the wiring between
280- # its three steps: encode → build_query → evaluate_query → log_usage.
281- # Patch TransformerModel.get_instance, build_query, evaluate_query, and
282- # log_usage. Assert that evaluate_query receives the queryset returned by
283- # build_query, and that the function returns evaluate_query's result.
337+ @patch ("api.services.embedding_services.log_usage" )
338+ @patch ("api.services.embedding_services.evaluate_query" )
339+ @patch ("api.services.embedding_services.build_query" )
340+ @patch ("api.services.embedding_services.TransformerModel" )
341+ def test_get_closest_embeddings_wiring (mock_transformer , mock_build , mock_evaluate , mock_log ):
342+ # Smoke test verifying that get_closest_embeddings correctly wires together
343+ # encode → build_query → evaluate_query → log_usage and returns the results.
344+ user = MagicMock (is_authenticated = True )
345+
346+ # Simulate the model encoding the message to a vector.
347+ fake_vector = [0.1 , 0.2 , 0.3 ]
348+ mock_transformer .get_instance .return_value .model .encode .return_value = fake_vector
349+
350+ # build_query returns a queryset; evaluate_query turns it into a results list.
351+ fake_queryset = MagicMock ()
352+ mock_build .return_value = fake_queryset
353+ fake_results = [{"name" : "doc.pdf" , "distance" : 0.5 }]
354+ mock_evaluate .return_value = fake_results
355+
356+ result = get_closest_embeddings (user , "some query" , document_name = "doc.pdf" , guid = None , num_results = 5 )
357+
358+ # The encoded vector must be forwarded to build_query.
359+ mock_build .assert_called_once_with (user , fake_vector , "doc.pdf" , None , 5 )
360+
361+ # evaluate_query must receive the queryset that build_query returned.
362+ mock_evaluate .assert_called_once_with (fake_queryset )
363+
364+ # log_usage must be called with the results and original parameters.
365+ mock_log .assert_called_once ()
366+ log_kwargs = mock_log .call_args .args
367+ assert log_kwargs [0 ] is fake_results
368+
369+ # The function must return evaluate_query's result unchanged.
370+ assert result is fake_results
0 commit comments