Skip to content

Commit cd29a04

Browse files
committed
poll_for_result
1 parent f17ac98 commit cd29a04

2 files changed

Lines changed: 213 additions & 388 deletions

File tree

src/databricks_ai_bridge/genie.py

Lines changed: 94 additions & 245 deletions
Original file line numberDiff line numberDiff line change
@@ -150,26 +150,55 @@ def is_too_big(n):
150150
return truncated_result
151151

152152

153-
def _end_current_span(client, parent_trace_id, current_span, final_state, error=None):
154-
"""helper function to safely end a span with exception handling."""
153+
def _parse_genie_mcp_response(
154+
mcp_result, truncate_results: bool, return_pandas: bool, conversation_id: Optional[str] = None
155+
) -> GenieResponse:
156+
if not mcp_result.content or len(mcp_result.content) == 0:
157+
return GenieResponse(
158+
result="No content returned from Genie",
159+
conversation_id=conversation_id,
160+
)
155161

156-
if current_span is None:
157-
return None
162+
# Genie backend always returns 1 content block with JSON
163+
content_block = mcp_result.content[0]
164+
content_text = content_block.text if hasattr(content_block, "text") else "{}"
158165

159166
try:
160-
attributes = {"final_state": final_state}
161-
if error is not None:
162-
attributes["error"] = error
163-
164-
client.end_span(
165-
trace_id=parent_trace_id,
166-
span_id=current_span.span_id,
167-
attributes=attributes,
167+
genie_response = json.loads(content_text)
168+
except json.JSONDecodeError:
169+
return GenieResponse(
170+
result=f"Failed to parse response: {content_text}",
171+
conversation_id=conversation_id,
168172
)
169-
except mlflow.exceptions.MlflowTracingException as e:
170-
logging.warning(f"Failed to end span for {final_state}: {e}")
171173

172-
return None
174+
content = genie_response.get("content", "")
175+
conv_id = genie_response.get("conversationId", conversation_id)
176+
query_str = ""
177+
description = ""
178+
179+
try:
180+
content_data = json.loads(content)
181+
query_str = content_data.get("query", "")
182+
description = content_data.get("description", "")
183+
statement_response = content_data.get("statement_response")
184+
185+
if (
186+
statement_response
187+
and statement_response.get("status", {}).get("state") == "SUCCEEDED"
188+
):
189+
result = _parse_query_result(statement_response, truncate_results, return_pandas)
190+
else:
191+
result = content
192+
193+
except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
194+
result = content
195+
196+
return GenieResponse(
197+
result=result,
198+
query=query_str,
199+
description=description,
200+
conversation_id=conv_id,
201+
)
173202

174203

175204
class Genie:
@@ -187,7 +216,16 @@ def __init__(
187216

188217
server_url = f"{workspace_client.config.host}/api/2.0/mcp/genie/{space_id}"
189218
self._mcp_client = DatabricksMCPClient(server_url, workspace_client)
190-
self._tool_name = f"query_space_{space_id}"
219+
220+
tools = self._mcp_client.list_tools()
221+
if not tools:
222+
raise ValueError(f"No tools found in Genie MCP server for space {space_id}")
223+
224+
query_tools = [tool for tool in tools if "query" in tool.name.lower()]
225+
poll_tools = [tool for tool in tools if "poll" in tool.name.lower()]
226+
227+
self._query_tool_name = query_tools[0].name if query_tools else None
228+
self._poll_tool_name = poll_tools[0].name if poll_tools else None
191229

192230
self.headers = {
193231
"Accept": "application/json",
@@ -198,18 +236,6 @@ def __init__(
198236

199237
@mlflow.trace()
200238
def start_conversation(self, content):
201-
"""Start a conversation with the Genie space.
202-
203-
.. deprecated::
204-
This method is deprecated and will be removed in a future release.
205-
Use :meth:`ask_question` instead, which uses the MCP protocol.
206-
"""
207-
warnings.warn(
208-
"start_conversation() is deprecated and will be removed in a future release. "
209-
"Use ask_question() instead.",
210-
DeprecationWarning,
211-
stacklevel=2,
212-
)
213239
resp = self.genie._api.do(
214240
"POST",
215241
f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
@@ -220,18 +246,6 @@ def start_conversation(self, content):
220246

221247
@mlflow.trace()
222248
def create_message(self, conversation_id, content):
223-
"""Create a message in an existing conversation.
224-
225-
.. deprecated::
226-
This method is deprecated and will be removed in a future release.
227-
Use :meth:`ask_question` instead, which uses MCP protocol.
228-
"""
229-
warnings.warn(
230-
"create_message() is deprecated and will be removed in a future release. "
231-
"Use ask_question(question, conversation_id=...) instead, which uses MCP protocol.",
232-
DeprecationWarning,
233-
stacklevel=2,
234-
)
235249
resp = self.genie._api.do(
236250
"POST",
237251
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
@@ -242,221 +256,56 @@ def create_message(self, conversation_id, content):
242256

243257
@mlflow.trace()
244258
def poll_for_result(self, conversation_id, message_id):
245-
"""Poll for the result of a Genie query.
246-
247-
.. deprecated::
248-
This method is deprecated and will be removed in a future release.
249-
Use :meth:`ask_question` instead, which uses MCP protocol.
250-
"""
251-
warnings.warn(
252-
"poll_for_result() is deprecated and will be removed in a future release. "
253-
"Use ask_question() instead, which uses MCP protocol.",
254-
DeprecationWarning,
255-
stacklevel=2,
256-
)
257-
258-
@mlflow.trace()
259-
def poll_query_results(
260-
attachment_id, query_str, description, conversation_id=conversation_id
261-
):
262-
iteration_count = 0
263-
while iteration_count < MAX_ITERATIONS:
264-
iteration_count += 1
265-
resp = self.genie._api.do(
266-
"GET",
267-
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
268-
headers=self.headers,
269-
)["statement_response"]
270-
state = resp["status"]["state"]
271-
returned_conversation_id = resp.get("conversation_id", None)
272-
if state == "SUCCEEDED":
273-
result = _parse_query_result(resp, self.truncate_results, self.return_pandas)
274-
return GenieResponse(result, query_str, description, returned_conversation_id)
275-
elif state in ["RUNNING", "PENDING"]:
276-
logging.debug("Waiting for query result...")
277-
time.sleep(ITERATION_FREQUENCY)
278-
else:
279-
return GenieResponse(
280-
f"No query result: {resp['state']}",
281-
query_str,
282-
description,
283-
returned_conversation_id,
284-
)
285-
return GenieResponse(
286-
f"Genie query for result timed out after {MAX_ITERATIONS} iterations of 5 seconds",
287-
query_str,
288-
description,
289-
conversation_id,
259+
if not self._poll_tool_name:
260+
raise ValueError(
261+
f"Poll tool not available for Genie space {self.space_id}. "
262+
f"The MCP server must expose a poll tool to use poll_for_result()."
290263
)
291264

292-
@mlflow.trace()
293-
def poll_result():
294-
iteration_count = 0
295-
296-
# use MLflow client to get parent of any new spans we create from the current active span
297-
# (parenting keeps spans in the same trace)
298-
client = mlflow.tracking.MlflowClient()
299-
with mlflow.start_span(name="genie_timeline", span_type="CHAIN") as parent:
300-
parent_trace_id = parent.trace_id if parent else None
301-
parent_span_id = parent.span_id if parent else None
302-
303-
# Track last status from API and the current child span
304-
last_status = None
305-
current_span = None
306-
307-
while iteration_count < MAX_ITERATIONS:
308-
iteration_count += 1
309-
resp = self.genie._api.do(
310-
"GET",
311-
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
312-
headers=self.headers,
313-
)
314-
returned_conversation_id = resp.get("conversation_id", None)
315-
316-
# get current status from API response
317-
current_status = resp["status"]
318-
319-
# On status change: end previous span, start a new one (excluding terminal states)
320-
if current_status != last_status:
321-
# END previous span
322-
current_span = _end_current_span(
323-
client, parent_trace_id, current_span, last_status
324-
)
325-
326-
# START new span for non-terminal states
327-
if current_status not in TERMINAL_STATES:
328-
# START new span
329-
try:
330-
current_span = client.start_span(
331-
name=current_status.lower(),
332-
trace_id=parent_trace_id,
333-
parent_id=parent_span_id,
334-
span_type="CHAIN",
335-
attributes={
336-
"state": current_status,
337-
"conversation_id": conversation_id,
338-
"message_id": message_id,
339-
},
340-
)
341-
except mlflow.exceptions.MlflowTracingException as e:
342-
logging.warning(f"Failed to create span for {current_status}: {e}")
343-
current_span = None
344-
345-
logging.debug(f"Status: {last_status}{current_status}")
346-
last_status = current_status
347-
348-
if current_status == "COMPLETED":
349-
attachment = next((r for r in resp["attachments"] if "query" in r), None)
350-
if attachment:
351-
query_obj = attachment["query"]
352-
description = query_obj.get("description", "")
353-
query_str = query_obj.get("query", "")
354-
attachment_id = attachment["attachment_id"]
355-
return poll_query_results(
356-
attachment_id,
357-
query_str,
358-
description,
359-
returned_conversation_id,
360-
)
361-
if current_status == "COMPLETED":
362-
text_content = next(r for r in resp["attachments"] if "text" in r)[
363-
"text"
364-
]["content"]
365-
return GenieResponse(
366-
result=text_content,
367-
conversation_id=returned_conversation_id,
368-
)
369-
370-
elif current_status in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
371-
return GenieResponse(result=f"Genie query {current_status.lower()}.")
372-
373-
elif current_status == "FAILED":
374-
return GenieResponse(
375-
result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}"
376-
)
377-
# includes EXECUTING_QUERY, Genie can retry after this status
378-
else:
379-
logging.debug(f"Status: {current_status}")
380-
time.sleep(ITERATION_FREQUENCY) # faster poll rate
381-
382-
# timeout path / end of while loop — close any open spans
383-
current_span = _end_current_span(
384-
client,
385-
parent_trace_id,
386-
current_span,
387-
last_status,
388-
)
389-
return GenieResponse(
390-
f"Genie query timed out after {MAX_ITERATIONS} iterations of {ITERATION_FREQUENCY} seconds (total {MAX_ITERATIONS * ITERATION_FREQUENCY} seconds)",
391-
conversation_id=conversation_id,
392-
)
393-
394-
return poll_result()
265+
iteration_count = 0
266+
while iteration_count < MAX_ITERATIONS:
267+
iteration_count += 1
395268

396-
@mlflow.trace()
397-
def ask_question(self, question, conversation_id: Optional[str] = None):
398-
"""Ask a question to the Genie space using MCP protocol.
269+
args = {"conversation_id": conversation_id, "message_id": message_id}
270+
mcp_result = self._mcp_client.call_tool(self._poll_tool_name, args)
399271

400-
Args:
401-
question: The question to ask the Genie space
402-
conversation_id: Optional conversation ID to continue an existing conversation
272+
try:
273+
if not mcp_result.content or len(mcp_result.content) == 0:
274+
return GenieResponse(
275+
result="No content returned from Genie poll",
276+
conversation_id=conversation_id,
277+
)
403278

404-
Returns:
405-
GenieResponse with result, query, description, and conversation_id
406-
"""
279+
content_block = mcp_result.content[0]
280+
content_text = content_block.text if hasattr(content_block, "text") else "{}"
281+
genie_response = json.loads(content_text)
282+
status = genie_response.get("status", "")
283+
except (json.JSONDecodeError, AttributeError, KeyError):
284+
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
407285

408-
args = {"query": question}
409-
if conversation_id:
410-
args["conversation_id"] = conversation_id
286+
if status in ["COMPLETED", "FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"]:
287+
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
411288

412-
mcp_result = self._mcp_client.call_tool(self._tool_name, args)
289+
logging.debug(f"Polling: status={status}, iteration={iteration_count}")
290+
time.sleep(ITERATION_FREQUENCY)
413291

414-
if not mcp_result.content or len(mcp_result.content) == 0:
415-
return GenieResponse(
416-
result="No content returned from Genie",
417-
conversation_id=conversation_id,
418-
)
292+
return GenieResponse(
293+
result=f"Genie query timed out after {MAX_ITERATIONS * ITERATION_FREQUENCY} seconds",
294+
conversation_id=conversation_id,
295+
)
419296

420-
# Genie backend always returns 1 content block with JSON
421-
content_block = mcp_result.content[0]
422-
content_text = content_block.text if hasattr(content_block, "text") else "{}"
423297

424-
try:
425-
genie_response = json.loads(content_text)
426-
except json.JSONDecodeError:
427-
return GenieResponse(
428-
result=f"Failed to parse response: {content_text}",
429-
conversation_id=conversation_id,
298+
@mlflow.trace()
299+
def ask_question(self, question, conversation_id: Optional[str] = None):
300+
if not self._query_tool_name:
301+
raise ValueError(
302+
f"Query tool not available for Genie space {self.space_id}. "
303+
f"The MCP server must expose a query tool to use ask_question()."
430304
)
431305

432-
content = genie_response.get("content", "")
433-
conv_id = genie_response.get("conversationId", conversation_id)
434-
status = genie_response.get("status", "")
435-
query_str = ""
436-
description = ""
437-
438-
try:
439-
content_data = json.loads(content)
440-
query_str = content_data.get("query", "")
441-
description = content_data.get("description", "")
442-
statement_response = content_data.get("statement_response")
443-
444-
if (
445-
statement_response
446-
and statement_response.get("status", {}).get("state") == "SUCCEEDED"
447-
):
448-
result = _parse_query_result(
449-
statement_response, self.truncate_results, self.return_pandas
450-
)
451-
else:
452-
result = content
453-
454-
except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
455-
result = content
306+
args = {"query": question}
307+
if conversation_id:
308+
args["conversation_id"] = conversation_id
456309

457-
return GenieResponse(
458-
result=result,
459-
query=query_str,
460-
description=description,
461-
conversation_id=conv_id,
462-
)
310+
mcp_result = self._mcp_client.call_tool(self._query_tool_name, args)
311+
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)

0 commit comments

Comments
 (0)