Skip to content

Commit ea74b7a

Browse files
committed
fix tests
1 parent cd29a04 commit ea74b7a

2 files changed

Lines changed: 526 additions & 250 deletions

File tree

src/databricks_ai_bridge/genie.py

Lines changed: 97 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ def is_too_big(n):
150150
return truncated_result
151151

152152

153+
def _end_current_span(client, parent_trace_id, current_span, final_state, error=None):
    """Close ``current_span`` through ``client`` without ever raising.

    Nothing is done when ``current_span`` is ``None``. Otherwise
    ``client.end_span`` is called with ``parent_trace_id``, the span's
    ``span_id`` and an attributes dict that records ``final_state`` (plus
    ``error`` when one is supplied). Any exception raised while ending the
    span is logged as a warning and swallowed, so a tracing failure does not
    propagate to the caller.

    Returns:
        Always ``None``, which lets a caller rebind its span variable to the
        result in order to clear it.
    """
    if current_span is not None:
        span_attrs = {"final_state": final_state}
        try:
            # Only attach the error attribute when one was actually given.
            if error is not None:
                span_attrs["error"] = error

            client.end_span(
                trace_id=parent_trace_id,
                span_id=current_span.span_id,
                attributes=span_attrs,
            )
        except Exception as exc:
            # Best-effort cleanup: report the failure and carry on.
            logging.warning(f"Failed to end span for {final_state}: {exc}")

    return None
172+
173+
153174
def _parse_genie_mcp_response(
154175
mcp_result, truncate_results: bool, return_pandas: bool, conversation_id: Optional[str] = None
155176
) -> GenieResponse:
@@ -262,37 +283,84 @@ def poll_for_result(self, conversation_id, message_id):
262283
f"The MCP server must expose a poll tool to use poll_for_result()."
263284
)
264285

265-
iteration_count = 0
266-
while iteration_count < MAX_ITERATIONS:
267-
iteration_count += 1
268-
269-
args = {"conversation_id": conversation_id, "message_id": message_id}
270-
mcp_result = self._mcp_client.call_tool(self._poll_tool_name, args)
271-
272-
try:
273-
if not mcp_result.content or len(mcp_result.content) == 0:
274-
return GenieResponse(
275-
result="No content returned from Genie poll",
276-
conversation_id=conversation_id,
286+
# Use MLflow client for manual span management to track status transitions
287+
client = mlflow.tracking.MlflowClient()
288+
with mlflow.start_span(name="genie_timeline", span_type="CHAIN") as parent:
289+
parent_trace_id = parent.trace_id if parent else None
290+
parent_span_id = parent.span_id if parent else None
291+
292+
# Track last status and current child span
293+
last_status = None
294+
current_span = None
295+
296+
iteration_count = 0
297+
while iteration_count < MAX_ITERATIONS:
298+
iteration_count += 1
299+
300+
args = {"conversation_id": conversation_id, "message_id": message_id}
301+
mcp_result = self._mcp_client.call_tool(self._poll_tool_name, args)
302+
303+
try:
304+
if not mcp_result.content or len(mcp_result.content) == 0:
305+
# End any active span before returning
306+
_end_current_span(client, parent_trace_id, current_span, last_status)
307+
return GenieResponse(
308+
result="No content returned from Genie poll",
309+
conversation_id=conversation_id,
310+
)
311+
312+
content_block = mcp_result.content[0]
313+
content_text = content_block.text if hasattr(content_block, "text") else "{}"
314+
genie_response = json.loads(content_text)
315+
status = genie_response.get("status", "")
316+
except (json.JSONDecodeError, AttributeError, KeyError):
317+
# End any active span before returning
318+
_end_current_span(client, parent_trace_id, current_span, last_status)
319+
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
320+
321+
# On status change: end previous span, start new one
322+
if status != last_status:
323+
# END previous span
324+
current_span = _end_current_span(
325+
client, parent_trace_id, current_span, last_status
277326
)
278327

279-
content_block = mcp_result.content[0]
280-
content_text = content_block.text if hasattr(content_block, "text") else "{}"
281-
genie_response = json.loads(content_text)
282-
status = genie_response.get("status", "")
283-
except (json.JSONDecodeError, AttributeError, KeyError):
284-
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
285-
286-
if status in ["COMPLETED", "FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"]:
287-
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
288-
289-
logging.debug(f"Polling: status={status}, iteration={iteration_count}")
290-
time.sleep(ITERATION_FREQUENCY)
291-
292-
return GenieResponse(
293-
result=f"Genie query timed out after {MAX_ITERATIONS * ITERATION_FREQUENCY} seconds",
294-
conversation_id=conversation_id,
295-
)
328+
# START new span for non-terminal states
329+
if status not in TERMINAL_STATES:
330+
try:
331+
current_span = client.start_span(
332+
name=status.lower(),
333+
trace_id=parent_trace_id,
334+
parent_id=parent_span_id,
335+
span_type="CHAIN",
336+
attributes={
337+
"state": status,
338+
"conversation_id": conversation_id,
339+
"message_id": message_id,
340+
},
341+
)
342+
except Exception as e:
343+
logging.warning(f"Failed to create span for {status}: {e}")
344+
current_span = None
345+
346+
logging.debug(f"Status: {last_status}{status}")
347+
last_status = status
348+
349+
# Check for terminal states
350+
if status in TERMINAL_STATES:
351+
# End any active span before returning
352+
_end_current_span(client, parent_trace_id, current_span, last_status)
353+
return _parse_genie_mcp_response(mcp_result, self.truncate_results, self.return_pandas, conversation_id)
354+
355+
logging.debug(f"Polling: status={status}, iteration={iteration_count}")
356+
time.sleep(ITERATION_FREQUENCY)
357+
358+
# Timeout - end any active span
359+
_end_current_span(client, parent_trace_id, current_span, last_status)
360+
return GenieResponse(
361+
result=f"Genie query timed out after {MAX_ITERATIONS} iterations of {ITERATION_FREQUENCY} seconds",
362+
conversation_id=conversation_id,
363+
)
296364

297365

298366
@mlflow.trace()

0 commit comments

Comments
 (0)