From 8d3e6ab7e123bd0d91b92a2c6cd61c17f267238b Mon Sep 17 00:00:00 2001 From: Peter Foldes Date: Tue, 28 Apr 2026 20:01:18 -0700 Subject: [PATCH] feat: add cancel_event param and HTTP timeouts to connect() Make connect() and connect_direct() interruptible from another thread via a threading.Event. Add 30s timeouts to all HTTP requests and the WebSocket handshake to prevent indefinite hangs. This enables callers (e.g. MCP server) to cleanly abort connection attempts when clients disconnect, preventing zombie thread accumulation and thread pool exhaustion. --- .github/workflows/test.yaml | 3 + CONTRIBUTING.md | 44 +++++++++++ README.md | 6 ++ pyproject.toml | 2 +- {tests => scripts}/smoke.py | 0 tests/test_driver.py | 146 ++++++++++++++++++++++++++++++++++++ wherobots/db/driver.py | 27 ++++++- 7 files changed, 226 insertions(+), 2 deletions(-) rename {tests => scripts}/smoke.py (100%) create mode 100644 tests/test_driver.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1e395ab..80b7342 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -38,6 +38,9 @@ jobs: cd tests python -c "import wherobots.db" + - name: Run tests + run: pytest tests/ + - name: Check build run: | uv build diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bfb983d..3ea2088 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -3,6 +3,50 @@ This project uses `uv`. Run `uv sync` after checking out the repository to initialize your virtualenv with the project's dependencies. +## Running tests + +Unit tests live in `tests/` and run with pytest: + +```bash +uv run pytest tests/ +``` + +The `scripts/` directory contains integration scripts that require a live +Wherobots environment and are not part of the automated test suite. + +### Smoke test + +`scripts/smoke.py` runs queries against a live Wherobots SQL session. +It requires an API key (or token) and supports most `connect()` options +via CLI flags. + +```bash +# Basic query with an API key +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + "SELECT 1" + +# Specify runtime, region, and version +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --runtime tiny --region aws-us-west-2 --version latest \ + "SELECT ST_AsText(ST_Point(1, 2))" + +# Connect directly to an existing session via WebSocket URL +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --ws-url wss://compute.example.com/sql/org/session-id \ + "SHOW TABLES" + +# Enable debug logging and execution progress +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --debug --progress \ + "SELECT * FROM wherobots_open_data.overture.places LIMIT 10" +``` + +Run `uv run python scripts/smoke.py --help` for all available options. + ## Publish package to PyPI When we are ready to release a new version `vx.y.z`, one of the maintainers should: diff --git a/README.md b/README.md index 88ffaef..ce968f6 100644 --- a/README.md +++ b/README.md @@ -261,3 +261,9 @@ users may find useful: your expected time between queries and effectively get a continuously running SQL session runtime without any complex connection management in your application. +* `cancel_event`: a `threading.Event` that, when set, causes the + connection attempt to abort promptly with an `InterfaceError`. This + is useful when `connect()` is running in a background thread and the + caller needs to interrupt it (e.g. on client disconnect or timeout). + The event is checked before each HTTP request, between retry + attempts, and before the WebSocket handshake. diff --git a/pyproject.toml b/pyproject.toml index 161e431..538d50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wherobots-python-dbapi" -version = "0.26.1" +version = "0.27.0" description = "Python DB-API driver for Wherobots DB" authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }] requires-python = ">=3.10, <4" diff --git a/tests/smoke.py b/scripts/smoke.py similarity index 100% rename from tests/smoke.py rename to scripts/smoke.py diff --git a/tests/test_driver.py b/tests/test_driver.py new file mode 100644 index 0000000..ef75cf2 --- /dev/null +++ b/tests/test_driver.py @@ -0,0 +1,146 @@ +"""Tests for the connect() and connect_direct() driver functions.""" + +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from wherobots.db.driver import ( + DEFAULT_HTTP_TIMEOUT, + _check_cancelled, + connect, + connect_direct, +) +from wherobots.db.errors import InterfaceError + + +class TestCheckCancelled: + def test_none_event_is_noop(self): + _check_cancelled(None) + + def test_unset_event_is_noop(self): + event = threading.Event() + _check_cancelled(event) + + def test_set_event_raises(self): + event = threading.Event() + event.set() + with pytest.raises(InterfaceError, match="cancelled by caller"): + _check_cancelled(event) + + +class TestConnectCancelEvent: + @patch("wherobots.db.driver.requests.post") + def test_cancel_before_post(self, mock_post): + """cancel_event set before connect() should raise immediately without making HTTP calls.""" + cancel = threading.Event() + cancel.set() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect(api_key="test-key", cancel_event=cancel) + + mock_post.assert_not_called() + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_cancel_during_polling(self, mock_post, mock_get): + """cancel_event set during session polling should abort the retry loop.""" + # POST succeeds with redirect + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + # GET returns INITIALIZING (triggers TryAgain) + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = {"status": "INITIALIZING"} + mock_get.return_value = get_resp + + cancel = threading.Event() + + # Set cancel after a short delay (during polling) + def set_cancel(): + time.sleep(0.1) + cancel.set() + + t = threading.Thread(target=set_cancel) + t.start() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect(api_key="test-key", cancel_event=cancel, wait_timeout=10) + + t.join() + + @patch("wherobots.db.driver.requests.post") + def test_http_timeout_on_post(self, mock_post): + """requests.post should be called with a timeout.""" + post_resp = MagicMock() + post_resp.status_code = 401 + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) + mock_post.return_value = post_resp + + with pytest.raises(InterfaceError, match="Failed to create SQL session"): + connect(api_key="test-key") + + _, kwargs = mock_post.call_args + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_http_timeout_on_get(self, mock_post, mock_get): + """requests.get in the polling loop should be called with a timeout.""" + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = { + "status": "READY", + "appMeta": {"url": "https://compute.example.com/sql/org/session-id"}, + } + mock_get.return_value = get_resp + + # Patch connect_direct to avoid actual WebSocket connection + with patch("wherobots.db.driver.connect_direct") as mock_cd: + mock_cd.return_value = MagicMock() + connect(api_key="test-key") + + _, kwargs = mock_get.call_args + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT + + @patch("wherobots.db.driver.requests.post") + def test_connect_without_cancel_event(self, mock_post): + """connect() without cancel_event should work as before (backward compat).""" + post_resp = MagicMock() + post_resp.status_code = 401 + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) + mock_post.return_value = post_resp + + with pytest.raises(InterfaceError): + connect(api_key="test-key") + + +class TestConnectDirectCancelEvent: + @patch("wherobots.db.driver.websockets.sync.client.connect") + def test_cancel_before_ws_connect(self, mock_ws): + cancel = threading.Event() + cancel.set() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect_direct( + uri="wss://compute.example.com/sql/org/session-id", + cancel_event=cancel, + ) + + mock_ws.assert_not_called() diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index b18ff84..a85c313 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -11,6 +11,7 @@ import platform import requests import tenacity +import threading from typing import Final, Union, Dict import urllib.parse import websockets.exceptions @@ -51,6 +52,9 @@ # This follows the industry-standard set used by urllib3.util.Retry's status_forcelist. TRANSIENT_HTTP_STATUS_CODES = {429, 502, 503, 504} +# Default timeout for individual HTTP requests (connect + read), in seconds. +DEFAULT_HTTP_TIMEOUT = 30 + def gen_user_agent_header(): try: @@ -79,6 +83,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + cancel_event: Union[threading.Event, None] = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -109,6 +114,8 @@ def connect( if not host.startswith("http:"): host = f"https://{host}" + _check_cancelled(cancel_event) + try: resp = requests.post( url=f"{host}/sql/session", @@ -120,6 +127,7 @@ def connect( "sessionType": session_type.value, }, headers=headers, + timeout=DEFAULT_HTTP_TIMEOUT, ) resp.raise_for_status() except requests.HTTPError as e: @@ -149,10 +157,12 @@ def connect( ) | tenacity.retry_if_exception_type(tenacity.TryAgain) ), + before_sleep=lambda _: _check_cancelled(cancel_event), reraise=True, ) def get_session_uri() -> str: - r = requests.get(session_id_url, headers=headers) + _check_cancelled(cancel_event) + r = requests.get(session_id_url, headers=headers, timeout=DEFAULT_HTTP_TIMEOUT) r.raise_for_status() payload = r.json() status = AppStatus(payload.get("status")) @@ -169,6 +179,8 @@ def get_session_uri() -> str: logging.info("Getting SQL session status from %s ...", session_id_url) session_uri = get_session_uri() logging.debug("SQL session URI from app status: %s", session_uri) + except InterfaceError: + raise except Exception as e: raise InterfaceError("Could not acquire SQL session!", e) @@ -179,9 +191,16 @@ def get_session_uri() -> str: results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + cancel_event=cancel_event, ) +def _check_cancelled(cancel_event: Union[threading.Event, None]) -> None: + """Raise InterfaceError if the cancel event is set.""" + if cancel_event is not None and cancel_event.is_set(): + raise InterfaceError("Connection cancelled by caller") + + def http_to_ws(uri: str) -> str: """Converts an HTTP URI to a WebSocket URI.""" parsed = urllib.parse.urlparse(uri) @@ -199,6 +218,7 @@ def connect_direct( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + cancel_event: Union[threading.Event, None] = None, ) -> Connection: uri_with_protocol = f"{uri}/{protocol}" ssl_context = ssl.create_default_context() @@ -215,19 +235,24 @@ def connect_direct( websockets.exceptions.InvalidHandshake, ) ), + before_sleep=lambda _: _check_cancelled(cancel_event), reraise=True, ) def ws_connect() -> websockets.sync.client.ClientConnection: + _check_cancelled(cancel_event) logging.info("Connecting to SQL session at %s ...", uri_with_protocol) return websockets.sync.client.connect( uri=uri_with_protocol, additional_headers=headers, max_size=MAX_MESSAGE_SIZE, + open_timeout=DEFAULT_HTTP_TIMEOUT, ssl=ssl_context, ) try: ws = ws_connect() + except InterfaceError: + raise except Exception as e: raise InterfaceError("Failed to connect to SQL session!") from e