|
from langgraph.graph import StateGraph, START, END
# NOTE(review): HumanMessage is imported but unused in this module —
# presumably kept for callers or forthcoming code; confirm before removing.
from langchain_core.messages import AIMessage, HumanMessage

from agent.state import OverallState
from agent.tools_and_schemas import SearchQueryList, Reflection
from agent.llm.groq import GroqLLM

# Module-level LLM client shared by every node below.
# NOTE(review): constructed at import time — assumes GroqLLM() performs no
# heavyweight I/O in its constructor; confirm against agent.llm.groq.
llm = GroqLLM()


# ---------- NODES ----------
def generate_search_queries(state: OverallState):
    """Turn the latest user message into a structured list of search queries.

    Reads the last message in ``state["messages"]`` and asks the LLM for a
    ``SearchQueryList``; the extracted queries are written to the
    ``search_query`` state channel.
    """
    question = state["messages"][-1].content

    structured = llm.invoke_structured(
        f"Generate search queries for: {question}",
        SearchQueryList,
    )

    return {"search_query": structured.query}
93 | 22 |
|
94 | 23 |
|
def web_research(state: OverallState):
    """Produce one research result per pending search query.

    ❗ Temporary fake research: fabricates a placeholder string for each
    query instead of calling a real search backend.
    """
    fake_results = [f"Result for: {q}" for q in state["search_query"]]

    return {"web_research_result": fake_results}
def reflect(state: OverallState):
    """Ask the LLM whether the gathered research is enough to answer.

    Returns a state update with the structured reflection fields
    (``is_sufficient``, ``knowledge_gap``, ``follow_up_queries``). When the
    evidence is judged insufficient, the follow-up queries are additionally
    written to ``search_query`` so that a subsequent research pass can act on
    them — previously they were computed but never fed anywhere.
    """
    prompt = (
        f"Question: {state['messages'][-1].content}\n\n"
        f"Research results:\n" + "\n".join(state["web_research_result"])
    )

    reflection = llm.invoke_structured(prompt, Reflection)

    update = {
        "is_sufficient": reflection.is_sufficient,
        "knowledge_gap": reflection.knowledge_gap,
        "follow_up_queries": reflection.follow_up_queries,
    }
    if not reflection.is_sufficient and reflection.follow_up_queries:
        # Feed the identified gaps back as the next round's queries. This is
        # a no-op while the graph routes straight to finalize.
        # NOTE(review): assumes OverallState.search_query tolerates being
        # rewritten/appended between loops — confirm its reducer in agent.state.
        update["search_query"] = reflection.follow_up_queries
    return update
def finalize(state: OverallState):
    """Compose the final answer from the question and the research results."""
    question = state["messages"][-1].content
    evidence = "\n".join(state["web_research_result"])

    final_prompt = (
        f"Answer the question:\n"
        f"{question}\n\n"
        f"Using:\n" + evidence
    )

    return {"messages": [AIMessage(content=llm.invoke(final_prompt))]}
# ---------- GRAPH ----------

def _route_after_reflection(state: OverallState) -> str:
    """Decide the next node after ``reflect``.

    Loops back to research only while the reflection reports the evidence as
    insufficient AND produced concrete follow-up queries; otherwise finalizes.
    LangGraph's recursion limit is the backstop against a non-terminating loop.
    """
    if state.get("is_sufficient", True) or not state.get("follow_up_queries"):
        return "finalize"
    return "research"


def build_graph():
    """Wire the agent: generate_queries -> research -> reflect -> (loop | finalize)."""
    graph = StateGraph(OverallState)

    graph.add_node("generate_queries", generate_search_queries)
    graph.add_node("research", web_research)
    graph.add_node("reflect", reflect)
    graph.add_node("finalize", finalize)

    graph.add_edge(START, "generate_queries")
    graph.add_edge("generate_queries", "research")
    graph.add_edge("research", "reflect")

    # BUG FIX: this edge previously used `lambda state: "finalize"`, which
    # unconditionally ended the run and made the reflection output
    # (is_sufficient / follow_up_queries) dead state. Route on it instead.
    graph.add_conditional_edges(
        "reflect",
        _route_after_reflection,
        ["research", "finalize"],
    )

    graph.add_edge("finalize", END)

    return graph.compile()


graph = build_graph()
0 commit comments