Skip to content

Commit baa4438

Browse files
authored
Merge pull request #461 from sahilds1/441-embedding-models
[#441] [IMPROVE] Preload embedding model at startup
2 parents 13a0a21 + fe1eeca commit baa4438

6 files changed

Lines changed: 561 additions & 52 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ df = pd.read_sql(query, engine)
7373

7474
#### Django REST
7575
- The email and password are set in `server/api/management/commands/createsu.py`
76+
- Backend tests can be run using `pytest` by running the below command inside the running backend container:
77+
78+
```
79+
docker compose exec backend pytest api/ -v
80+
```
7681

7782
## API Documentation
7883

server/api/apps.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,38 @@
44
class ApiConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'api'

    def ready(self):
        """Preload the sentence-transformer embedding model at server startup.

        Best-effort: any failure is logged and swallowed so management
        commands and server startup are never blocked by the model.
        """
        try:
            import os
            import sys

            # ready() fires for EVERY management command (migrate, test,
            # shell, createsu, populatedb, ...). Only `runserver` actually
            # serves requests, so only it should pay the model-loading cost.
            launched_runserver = sys.argv[1:2] == ['runserver']
            if not launched_runserver:
                return

            # The dev autoreloader spawns two processes: a file-watcher
            # parent and a serving child (which has RUN_MAIN=true). Prod
            # uses `runserver --noreload`, a single process where RUN_MAIN
            # is never set. Load the model only in the process that will
            # actually serve traffic.
            is_serving_process = (
                os.environ.get('RUN_MAIN') == 'true'
                or '--noreload' in sys.argv
            )
            if not is_serving_process:
                return

            # First call downloads paraphrase-MiniLM-L6-v2 (~80MB) from
            # HuggingFace into ~/.cache/torch/sentence_transformers/ inside
            # the container; that cache is ephemeral across rebuilds unless
            # a volume is mounted there.
            from .services.sentencetTransformer_model import TransformerModel
            TransformerModel.get_instance()
        except Exception:
            # On failure the singleton stays unset, so the first request
            # that calls get_instance() will retry the load.
            import logging
            logging.getLogger(__name__).exception(
                "Failed to preload the embedding model at startup"
            )

server/api/services/embedding_services.py

Lines changed: 115 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from statistics import median
44

5+
# Use Q objects to express OR conditions in Django queries
56
from django.db.models import Q
67
from pgvector.django import L2Distance
78

@@ -11,18 +12,17 @@
1112

1213
logger = logging.getLogger(__name__)
1314

def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
    """
    Build an unevaluated QuerySet for the closest embeddings.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    embedding_vector : array-like
        Pre-computed embedding vector to compare against
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    QuerySet
        Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
    """
    # Visibility rule: everyone may search superuser-uploaded files;
    # authenticated users additionally search their own uploads.
    # Django QuerySets are lazily evaluated, so nothing hits the DB here.
    visible = Q(upload_file__uploaded_by__is_superuser=True)
    if user.is_authenticated:
        visible = visible | Q(upload_file__uploaded_by=user)

    qs = Embeddings.objects.filter(visible).annotate(
        distance=L2Distance("embedding_sentence_transformers", embedding_vector)
    ).order_by("distance")

    # A GUID filter takes precedence over a document-name filter.
    if guid:
        qs = qs.filter(upload_file__guid=guid)
    elif document_name:
        qs = qs.filter(name=document_name)

    # Slicing translates to SQL's LIMIT clause without evaluating the query.
    return qs[:num_results]
def evaluate_query(queryset):
    """
    Evaluate a QuerySet and return a list of result dicts.

    Parameters
    ----------
    queryset : iterable
        Iterable of Embeddings objects (or any objects with the expected attributes)

    Returns
    -------
    list[dict]
        List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
    """
    def _as_dict(obj):
        # Flatten one embedding row; file_id is None for orphaned rows.
        upload = obj.upload_file
        return {
            "name": obj.name,
            "text": obj.text,
            "page_number": obj.page_number,
            "chunk_number": obj.chunk_number,
            "distance": obj.distance,
            "file_id": upload.guid if upload else None,
        }

    # Iterating evaluates the QuerySet and hits the database.
    # TODO: Research improving the query evaluation performance
    return [_as_dict(obj) for obj in queryset]

99-
db_query_time = time.time() - db_query_start
10092

93+
def log_usage(
94+
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
95+
):
96+
"""
97+
Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
98+
99+
Parameters
100+
----------
101+
results : list[dict]
102+
The search results, each containing a "distance" key
103+
message_data : str
104+
The original search query text
105+
user : User
106+
The user who performed the search
107+
guid : str or None
108+
Document GUID filter used in the search
109+
document_name : str or None
110+
Document name filter used in the search
111+
num_results : int
112+
Number of results requested
113+
encoding_time : float
114+
Time in seconds to encode the query
115+
db_query_time : float
116+
Time in seconds for the database query
117+
"""
101118
try:
102-
# Handle user having no uploaded docs or doc filtering returning no matches
103119
if results:
104120
distances = [r["distance"] for r in results]
105121
SemanticSearchUsage.objects.create(
@@ -113,11 +129,10 @@ def get_closest_embeddings(
113129
num_results_returned=len(results),
114130
max_distance=max(distances),
115131
median_distance=median(distances),
116-
min_distance=min(distances)
132+
min_distance=min(distances),
117133
)
118134
else:
119135
logger.warning("Semantic search returned no results")
120-
121136
SemanticSearchUsage.objects.create(
122137
query_text=message_data,
123138
user=user if (user and user.is_authenticated) else None,
@@ -129,9 +144,58 @@ def get_closest_embeddings(
129144
num_results_returned=0,
130145
max_distance=None,
131146
median_distance=None,
132-
min_distance=None
147+
min_distance=None,
133148
)
134-
except Exception as e:
135-
logger.error(f"Failed to create semantic search usage database record: {e}")
149+
except Exception:
150+
logger.exception("Failed to create semantic search usage database record")
151+
def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Find the closest embeddings to a given message for a specific user.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
    """
    # Encode the query text into a vector, timing the model call.
    encode_started = time.time()
    encoder = TransformerModel.get_instance().model
    query_vector = encoder.encode(message_data)
    encode_seconds = time.time() - encode_started

    # Build and evaluate the similarity query, timing the DB round trip.
    query_started = time.time()
    matches = evaluate_query(
        build_query(user, query_vector, document_name, guid, num_results)
    )
    query_seconds = time.time() - query_started

    # Best-effort usage logging; log_usage never raises into the search path.
    log_usage(
        matches, message_data, user, guid, document_name, num_results,
        encode_seconds, query_seconds,
    )

    return matches

0 commit comments

Comments
 (0)