Commit 03ede4a

committed: fix openai_services. add endpoint to use openAI to extract rules
1 parent 9231e22 commit 03ede4a

3 files changed: 134 additions & 62 deletions

server/api/services/openai_services.py

Lines changed: 60 additions & 45 deletions
@@ -6,49 +6,64 @@
 class openAIServices:
     @staticmethod
     def openAI(userMessage, prompt, model=None, temp=None, stream=False, raw_stream=False):
-        # Initialize the OpenAI client
-        try:
-            client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-
-            if model is None:
-                model = "gpt-4o-mini"
-            if temp is None:
-                temp = 0.2
-
-            if stream:
-
-                request_params = {
-                    "model": model,
-                    "temperature": temp,
-                    "messages": [
-                        {"role": "system", "content": prompt},
-                        {"role": "user", "content": userMessage}
-                    ],
-                    "stream": stream
-                }
-                response = client.chat.completions.create(**request_params)
-
-                for chunk in response:
-                    if raw_stream:
-                        # Return the entire chunk as JSON
-                        yield json.dumps(chunk.model_dump())
-                    else:
-                        # Extract only the content from the delta
-                        if chunk.choices and len(chunk.choices) > 0:
-                            delta = chunk.choices[0].delta
-                            if hasattr(delta, 'content') and delta.content:
-                                yield delta.content
+        if stream:
+            return openAIServices._openAI_streaming(userMessage, prompt, model, temp, raw_stream)
+        else:
+            return openAIServices._openAI_non_streaming(userMessage, prompt, model, temp)
+
+    @staticmethod
+    def _openAI_non_streaming(userMessage, prompt, model=None, temp=None):
+        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+        if model is None:
+            model = "gpt-4o-mini"
+        if temp is None:
+            temp = 0.2
+
+        request_params = {
+            "model": model,
+            "temperature": temp,
+            "messages": [
+                {"role": "system", "content": prompt},
+                {"role": "user", "content": userMessage}
+            ],
+        }
+
+        response = client.chat.completions.create(**request_params)
+        message_content = response.choices[0].message.content
+        print("OpenAI response content:", repr(message_content))
+
+        if not message_content:
+            raise ValueError("LLM returned empty content")
+
+        return message_content
+
+    @staticmethod
+    def _openAI_streaming(userMessage, prompt, model=None, temp=None, raw_stream=False):
+        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+        if model is None:
+            model = "gpt-4o-mini"
+        if temp is None:
+            temp = 0.2
+
+        request_params = {
+            "model": model,
+            "temperature": temp,
+            "messages": [
+                {"role": "system", "content": prompt},
+                {"role": "user", "content": userMessage}
+            ],
+            "stream": True
+        }
+
+        response = client.chat.completions.create(**request_params)
+
+        for chunk in response:
+            if raw_stream:
+                yield json.dumps(chunk.model_dump())
             else:
-                request_params = {
-                    "model": model,
-                    "temperature": temp,
-                    "messages": [
-                        {"role": "system", "content": prompt},
-                        {"role": "user", "content": userMessage}
-                    ],
-                }
-                response = client.chat.completions.create(**request_params)
-                return response.choices[0].message.content
-        except Exception as e:
-            print(f"Error: {e}")
-            raise
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'content') and delta.content:
+                        yield delta.content
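
The split into _openAI_non_streaming and _openAI_streaming fixes more than style: a Python function containing yield anywhere in its body is a generator function, so the old openAI returned a generator object even when stream=False, and its return response.choices[0].message.content never reached callers as a plain string. With the dispatch above, openAI is an ordinary function that returns either a string or a generator. A minimal caller sketch (the import path and prompt text are assumptions, not from this commit):

from api.services.openai_services import openAIServices  # path assumed from the repo layout

# Non-streaming: returns the message content as a string,
# or raises ValueError if the model returned empty content.
text = openAIServices.openAI(
    userMessage="List the mood stabilizers mentioned in this text.",
    prompt="You are a concise medical assistant.",
    stream=False,
)

# Streaming: returns the generator from _openAI_streaming, which yields
# content deltas (or whole chunks as JSON when raw_stream=True).
for piece in openAIServices.openAI(
    userMessage="List the mood stabilizers mentioned in this text.",
    prompt="You are a concise medical assistant.",
    stream=True,
):
    print(piece, end="", flush=True)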
Lines changed: 4 additions & 2 deletions

@@ -1,9 +1,11 @@
 from django.urls import path
-from .views import RuleExtractionAPIView
+from .views import RuleExtractionAPIView, RuleExtractionAPIOpenAIView


 urlpatterns = [

     path('v1/api/rule_extraction', RuleExtractionAPIView.as_view(),
-         name='rule_extraction')
+         name='rule_extraction'),
+    path('v1/api/rule_extraction_openai', RuleExtractionAPIOpenAIView.as_view(),
+         name='rule_extraction_openai')
 ]
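
With the second route registered, both endpoints are reachable by name. A quick sanity check, assuming these patterns are included at the project root with no prefix:

from django.urls import reverse

# Hypothetical check; passes only if this urls.py is mounted at the root.
assert reverse('rule_extraction') == '/v1/api/rule_extraction'
assert reverse('rule_extraction_openai') == '/v1/api/rule_extraction_openai'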

server/api/views/text_extraction/views.py

Lines changed: 70 additions & 15 deletions
@@ -1,5 +1,5 @@
 import os
-
+from ...services.openai_services import openAIServices
 from rest_framework.views import APIView
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
@@ -12,7 +12,7 @@


 # TODO: Add docstrings and type hints
-def anthropic_citations(client, content_chunks, user_prompt):
+def anthropic_citations(client, content_chunks, user_prompt):
     """
     """

@@ -31,7 +31,7 @@ def anthropic_citations(client, content_chunks, user_prompt):
                         },
                         "citations": {"enabled": True}
                     },
-
+
                     {
                         "type": "text",
                         "text": user_prompt
@@ -41,16 +41,17 @@ def anthropic_citations(client, content_chunks, user_prompt):
         ],
     )

-
     # Response Structure: https://docs.anthropic.com/en/docs/build-with-claude/citations#response-structure
-
+
     text = []
     cited_text = []
     for content in message.to_dict()['content']:
         text.append(content['text'])
         if 'citations' in content.keys():
-            text.append(" ".join([f"<{citation['start_block_index']} - {citation['end_block_index']}>" for citation in content['citations']]))
-            cited_text.append(" ".join([f"<{citation['start_block_index']} - {citation['end_block_index']}> {citation['cited_text']}" for citation in content['citations']]))
+            text.append(" ".join(
+                [f"<{citation['start_block_index']} - {citation['end_block_index']}>" for citation in content['citations']]))
+            cited_text.append(" ".join(
+                [f"<{citation['start_block_index']} - {citation['end_block_index']}> {citation['cited_text']}" for citation in content['citations']]))

     texts = " ".join(text)
     cited_texts = " ".join(cited_text)
@@ -66,22 +67,23 @@ class RuleExtractionAPIView(APIView):
     def get(self, request):
         try:

-            client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
-
+            client = anthropic.Anthropic(
+                api_key=os.getenv("ANTHROPIC_API_KEY"))
+
             user_prompt = """
             I'm creating a system to analyze medical research. It processes peer-reviewed papers to extract key details

             Act as a seasoned physician or medical professional who treat patients with bipolar disorder

-            Identify rules for medication inclusion or exclusion based on medical history or concerns
+            Identify rules for medication inclusion or exclusion based on medical history or concerns

             Return an output with the same structure as these examples:

-            The rule is history of suicide attempts. The type of rule is "INCLUDE". The reason is lithium is the
+            The rule is history of suicide attempts. The type of rule is "INCLUDE". The reason is lithium is the
             only medication on the market that has been proven to reduce suicidality in patients with bipolar disorder.
             The medications for this rule are lithium.

-            The rule is weight gain concerns. The type of rule is "EXCLUDE". The reason is Seroquel, Risperdal, Abilify, and
+            The rule is weight gain concerns. The type of rule is "EXCLUDE". The reason is Seroquel, Risperdal, Abilify, and
             Zyprexa are known for causing weight gain. The medications for this rule are Quetiapine, Aripiprazole, Olanzapine, Risperidone
             }
             """
@@ -92,10 +94,63 @@ def get(self, request):

             chunks = [{"type": "text", "text": chunk.text} for chunk in query]

-            texts, cited_texts = anthropic_citations(client, chunks, user_prompt)
-
+            texts, cited_texts = anthropic_citations(
+                client, chunks, user_prompt)

             return Response({"texts": texts, "cited_texts": cited_texts}, status=status.HTTP_200_OK)

         except Exception as e:
-            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# This is to use openai to extract the rules to save cost
+
+def openai_extraction(content_chunks, user_prompt):
+    """
+    Prepares the OpenAI input and returns the extracted text.
+    """
+
+    combined_text = "\n\n".join(chunk['text'] for chunk in content_chunks)
+
+    result = openAIServices.openAI(
+        userMessage=combined_text,
+        prompt=user_prompt,
+        model="gpt-4o-mini",
+        temp=0.0,
+        stream=False
+    )
+    return result
+
+
+@method_decorator(csrf_exempt, name='dispatch')
+class RuleExtractionAPIOpenAIView(APIView):
+    permission_classes = [IsAuthenticated]
+
+    def get(self, request):
+        try:
+            user_prompt = """
+            You're analyzing medical text from multiple sources. Each chunk is labeled [chunk-X].
+
+            Act as a seasoned physician or medical professional who treats patients with bipolar disorder.
+
+            Identify rules for medication inclusion or exclusion based on medical history or concerns.
+
+            Return each rule with this exact structure:
+            The rule is __. The type of rule is "__". The reason is __. The medications for this rule are __. Source: [chunk-X]
+
+            Only use chunks provided. If no rule is found in a chunk, skip it.
+            """
+
+            guid = request.query_params.get('guid')
+            query = Embeddings.objects.filter(upload_file__guid=guid)
+            chunks = [
+                {"type": "text", "text": f"[chunk-{i}] {chunk.text}"}
+                for i, chunk in enumerate(query)
+            ]
+
+            output_text = openai_extraction(chunks, user_prompt)
+
+            return Response({"text": output_text}, status=status.HTTP_200_OK)
+
+        except Exception as e:
+            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
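
The new view trades Anthropic's native citation support for cost, as the inline comment says: every chunk is concatenated into a single OpenAI request, and the [chunk-{i}] labels let the prompt ask the model to name its source chunk instead. A sketch of exercising the endpoint with DRF's test client (the user object and guid value are placeholders, not from this commit):

from rest_framework.test import APIClient

client = APIClient()                   # view requires IsAuthenticated
client.force_authenticate(user=user)   # `user` is an assumed existing account

resp = client.get('/v1/api/rule_extraction_openai', {'guid': 'some-upload-guid'})
assert resp.status_code == 200
# Response body: {"text": "The rule is __. ... Source: [chunk-X] ..."}
print(resp.json()['text'])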
