Skip to content

Commit 64a19ef

Browse files
committed
Fill test gaps in test_embedding_services
1 parent 46e9969 commit 64a19ef

2 files changed

Lines changed: 119 additions & 32 deletions

File tree

server/api/services/embedding_services.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from statistics import median
44

5-
# filter() only does ADD logic
5+
# Django filter() only does ADD logic
66
from django.db.models import Q
77
from pgvector.django import L2Distance
88

server/api/services/test_embedding_services.py

Lines changed: 118 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from unittest.mock import MagicMock, patch
22

33
from django.db.models import Q
4+
from pgvector.django import L2Distance
45

5-
from api.services.embedding_services import build_query, evaluate_query, log_usage
6+
from api.services.embedding_services import (
7+
build_query,
8+
evaluate_query,
9+
get_closest_embeddings,
10+
log_usage,
11+
)
612

713
# ---------------------------------------------------------------------------
814
# build_query tests
9-
#
10-
# build_query only constructs a lazy Django QuerySet — it never evaluates it
11-
# (no iteration, .get(), .exists(), etc.), so no database is needed.
12-
#
13-
# We patch Embeddings.objects so every chained ORM call (.filter, .annotate,
14-
# .order_by, __getitem__) returns a MagicMock instead of hitting the DB.
15-
# All assertions inspect which methods were called with which arguments.
1615
# ---------------------------------------------------------------------------
1716

17+
# All assertions inspect which methods and arguments were called on Embeddings.objects
18+
1819
# Only forwarded to L2Distance
1920
EMBEDDING_VECTOR = [0.1, 0.2, 0.3]
2021

@@ -48,12 +49,6 @@ def test_build_query_unauthenticated_uses_superuser_only_filter(mock_objects):
4849

4950
# Test application of annotate and order_by
5051

51-
# TODO: Strengthen test_build_query_annotates_and_orders_by_distance to also
52-
# assert the *arguments* to annotate — specifically that it receives
53-
# distance=L2Distance("embedding_sentence_transformers", EMBEDDING_VECTOR).
54-
# Currently only the call count is checked, so a wrong field name or a
55-
# dropped vector would go undetected.
56-
5752
@patch("api.services.embedding_services.Embeddings.objects")
5853
def test_build_query_annotates_and_orders_by_distance(mock_objects):
5954
# Regardless of other arguments, annotate(distance=L2Distance(...)) and
@@ -67,6 +62,12 @@ def test_build_query_annotates_and_orders_by_distance(mock_objects):
6762
filtered_qs.annotate.assert_called_once()
6863
filtered_qs.annotate.return_value.order_by.assert_called_once_with("distance")
6964

65+
# L2Distance is a Django Func subclass, which implements __eq__ by comparing
66+
# class and source expressions — so we can assert the exact field name and
67+
# vector without patching L2Distance itself.
68+
actual_distance_expr = filtered_qs.annotate.call_args.kwargs["distance"]
69+
assert actual_distance_expr == L2Distance("embedding_sentence_transformers", EMBEDDING_VECTOR)
70+
7071
# Test guid-over-document precedence logic
7172

7273
@patch("api.services.embedding_services.Embeddings.objects")
@@ -165,7 +166,10 @@ def test_build_query_returns_unevaluated_queryset(mock_objects):
165166
# evaluate_query tests
166167
# ---------------------------------------------------------------------------
167168

168-
# TODO: Add test for empty queryset — evaluate_query([]) should return [].
169+
def test_evaluate_query_empty_queryset():
170+
# An empty iterable should return an empty list, not raise an exception.
171+
assert evaluate_query([]) == []
172+
169173

170174
def test_evaluate_query_maps_fields():
171175
# Verify that each Embeddings model attribute is mapped to the correct
@@ -193,8 +197,8 @@ def test_evaluate_query_maps_fields():
193197

194198

195199
def test_evaluate_query_none_upload_file():
196-
# When upload_file is None (e.g. the FK was deleted), file_id must be None
197-
# rather than raising an AttributeError on None.guid.
200+
# When upload_file is None, file_id must be None rather than raising
201+
# an AttributeError on None.guid.
198202
obj = MagicMock()
199203
obj.name = "doc.pdf"
200204
obj.text = "some text"
@@ -211,17 +215,71 @@ def test_evaluate_query_none_upload_file():
211215
# log_usage tests
212216
# ---------------------------------------------------------------------------
213217

214-
# TODO: Add test for empty results list — log_usage([]) hits the else branch and
215-
# should call SemanticSearchUsage.objects.create with num_results_returned=0
216-
# and max_distance=None, median_distance=None, min_distance=None.
218+
@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
219+
def test_log_usage_empty_results(mock_create):
220+
# Empty results hits the else branch. The record should still be created
221+
# with num_results_returned=0 and all distance fields set to None.
222+
user = MagicMock(is_authenticated=True)
223+
224+
log_usage(
225+
[],
226+
message_data="test query",
227+
user=user,
228+
guid=None,
229+
document_name=None,
230+
num_results=10,
231+
encoding_time=0.1,
232+
db_query_time=0.2,
233+
)
234+
235+
mock_create.assert_called_once()
236+
kwargs = mock_create.call_args.kwargs
237+
assert kwargs["num_results_returned"] == 0
238+
assert kwargs["max_distance"] is None
239+
assert kwargs["median_distance"] is None
240+
assert kwargs["min_distance"] is None
241+
242+
243+
@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
244+
def test_log_usage_unauthenticated_user_stored_as_none(mock_create):
245+
# An unauthenticated user should be stored as None in the DB record, not as
246+
# the user object itself, so the FK constraint is not violated.
247+
user = MagicMock(is_authenticated=False)
248+
249+
log_usage(
250+
[{"distance": 1.0}],
251+
message_data="test query",
252+
user=user,
253+
guid=None,
254+
document_name=None,
255+
num_results=10,
256+
encoding_time=0.1,
257+
db_query_time=0.2,
258+
)
259+
260+
kwargs = mock_create.call_args.kwargs
261+
assert kwargs["user"] is None
262+
263+
264+
@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
265+
def test_log_usage_none_user_stored_as_none(mock_create):
266+
# Passing user=None directly (e.g. from an anonymous request) should also
267+
# store None — the expression `user if (user and user.is_authenticated)`
268+
# short-circuits on the falsy None before accessing .is_authenticated.
269+
log_usage(
270+
[{"distance": 1.0}],
271+
message_data="test query",
272+
user=None,
273+
guid=None,
274+
document_name=None,
275+
num_results=10,
276+
encoding_time=0.1,
277+
db_query_time=0.2,
278+
)
217279

218-
# TODO: Add test for unauthenticated user — user.is_authenticated=False should
219-
# result in user=None being stored in the SemanticSearchUsage record.
280+
kwargs = mock_create.call_args.kwargs
281+
assert kwargs["user"] is None
220282

221-
# TODO: Add test for user=None — passing None directly as the user argument
222-
# should also store user=None (the expression `user if (user and
223-
# user.is_authenticated) else None` handles both cases, but only the
224-
# authenticated path is currently exercised).
225283

226284
@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
227285
def test_log_usage_computes_distance_stats(mock_create):
@@ -276,8 +334,37 @@ def test_log_usage_swallows_exceptions(mock_create):
276334
# get_closest_embeddings tests
277335
# ---------------------------------------------------------------------------
278336

279-
# TODO: Add smoke test for get_closest_embeddings verifying the wiring between
280-
# its three steps: encode → build_query → evaluate_query → log_usage.
281-
# Patch TransformerModel.get_instance, build_query, evaluate_query, and
282-
# log_usage. Assert that evaluate_query receives the queryset returned by
283-
# build_query, and that the function returns evaluate_query's result.
337+
@patch("api.services.embedding_services.log_usage")
338+
@patch("api.services.embedding_services.evaluate_query")
339+
@patch("api.services.embedding_services.build_query")
340+
@patch("api.services.embedding_services.TransformerModel")
341+
def test_get_closest_embeddings_wiring(mock_transformer, mock_build, mock_evaluate, mock_log):
342+
# Smoke test verifying that get_closest_embeddings correctly wires together
343+
# encode → build_query → evaluate_query → log_usage and returns the results.
344+
user = MagicMock(is_authenticated=True)
345+
346+
# Simulate the model encoding the message to a vector.
347+
fake_vector = [0.1, 0.2, 0.3]
348+
mock_transformer.get_instance.return_value.model.encode.return_value = fake_vector
349+
350+
# build_query returns a queryset; evaluate_query turns it into a results list.
351+
fake_queryset = MagicMock()
352+
mock_build.return_value = fake_queryset
353+
fake_results = [{"name": "doc.pdf", "distance": 0.5}]
354+
mock_evaluate.return_value = fake_results
355+
356+
result = get_closest_embeddings(user, "some query", document_name="doc.pdf", guid=None, num_results=5)
357+
358+
# The encoded vector must be forwarded to build_query.
359+
mock_build.assert_called_once_with(user, fake_vector, "doc.pdf", None, 5)
360+
361+
# evaluate_query must receive the queryset that build_query returned.
362+
mock_evaluate.assert_called_once_with(fake_queryset)
363+
364+
# log_usage must be called with the results and original parameters.
365+
mock_log.assert_called_once()
366+
log_kwargs = mock_log.call_args.args
367+
assert log_kwargs[0] is fake_results
368+
369+
# The function must return evaluate_query's result unchanged.
370+
assert result is fake_results

0 commit comments

Comments
 (0)