diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1e395ab..80b7342 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -38,6 +38,9 @@ jobs: cd tests python -c "import wherobots.db" + - name: Run tests + run: pytest tests/ + - name: Check build run: | uv build diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bfb983d..3ea2088 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -3,6 +3,50 @@ This project uses `uv`. Run `uv sync` after checking out the repository to initialize your virtualenv with the project's dependencies. +## Running tests + +Unit tests live in `tests/` and run with pytest: + +```bash +uv run pytest tests/ +``` + +The `scripts/` directory contains integration scripts that require a live +Wherobots environment and are not part of the automated test suite. + +### Smoke test + +`scripts/smoke.py` runs queries against a live Wherobots SQL session. +It requires an API key (or token) and supports most `connect()` options +via CLI flags. + +```bash +# Basic query with an API key +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + "SELECT 1" + +# Specify runtime, region, and version +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --runtime tiny --region aws-us-west-2 --version latest \ + "SELECT ST_AsText(ST_Point(1, 2))" + +# Connect directly to an existing session via WebSocket URL +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --ws-url wss://compute.example.com/sql/org/session-id \ + "SHOW TABLES" + +# Enable debug logging and execution progress +uv run python scripts/smoke.py \ + --api-key-file ~/.wherobots/api-key \ + --debug --progress \ + "SELECT * FROM wherobots_open_data.overture.places LIMIT 10" +``` + +Run `uv run python scripts/smoke.py --help` for all available options. + ## Publish package to PyPI When we are ready to release a new version `vx.y.z`, one of the maintainers should: diff --git a/README.md b/README.md index 88ffaef..ce968f6 100644 --- a/README.md +++ b/README.md @@ -261,3 +261,9 @@ users may find useful: your expected time between queries and effectively get a continuously running SQL session runtime without any complex connection management in your application. +* `cancel_event`: a `threading.Event` that, when set, causes the + connection attempt to abort promptly with an `InterfaceError`. This + is useful when `connect()` is running in a background thread and the + caller needs to interrupt it (e.g. on client disconnect or timeout). + The event is checked before each HTTP request, between retry + attempts, and before the WebSocket handshake. diff --git a/pyproject.toml b/pyproject.toml index 161e431..538d50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wherobots-python-dbapi" -version = "0.26.1" +version = "0.27.0" description = "Python DB-API driver for Wherobots DB" authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }] requires-python = ">=3.10, <4" diff --git a/tests/smoke.py b/scripts/smoke.py similarity index 100% rename from tests/smoke.py rename to scripts/smoke.py diff --git a/tests/test_driver.py b/tests/test_driver.py new file mode 100644 index 0000000..ef75cf2 --- /dev/null +++ b/tests/test_driver.py @@ -0,0 +1,146 @@ +"""Tests for the connect() and connect_direct() driver functions.""" + +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from wherobots.db.driver import ( + DEFAULT_HTTP_TIMEOUT, + _check_cancelled, + connect, + connect_direct, +) +from wherobots.db.errors import InterfaceError + + +class TestCheckCancelled: + def test_none_event_is_noop(self): + _check_cancelled(None) + + def test_unset_event_is_noop(self): + event = threading.Event() + _check_cancelled(event) + + def test_set_event_raises(self): + event = threading.Event() + event.set() + with pytest.raises(InterfaceError, match="cancelled by caller"): + _check_cancelled(event) + + +class TestConnectCancelEvent: + @patch("wherobots.db.driver.requests.post") + def test_cancel_before_post(self, mock_post): + """cancel_event set before connect() should raise immediately without making HTTP calls.""" + cancel = threading.Event() + cancel.set() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect(api_key="test-key", cancel_event=cancel) + + mock_post.assert_not_called() + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_cancel_during_polling(self, mock_post, mock_get): + """cancel_event set during session polling should abort the retry loop.""" + # POST succeeds with redirect + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + # GET returns INITIALIZING (triggers TryAgain) + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = {"status": "INITIALIZING"} + mock_get.return_value = get_resp + + cancel = threading.Event() + + # Set cancel after a short delay (during polling) + def set_cancel(): + time.sleep(0.1) + cancel.set() + + t = threading.Thread(target=set_cancel) + t.start() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect(api_key="test-key", cancel_event=cancel, wait_timeout=10) + + t.join() + + @patch("wherobots.db.driver.requests.post") + def test_http_timeout_on_post(self, mock_post): + """requests.post should be called with a timeout.""" + post_resp = MagicMock() + post_resp.status_code = 401 + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) + mock_post.return_value = post_resp + + with pytest.raises(InterfaceError, match="Failed to create SQL session"): + connect(api_key="test-key") + + _, kwargs = mock_post.call_args + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_http_timeout_on_get(self, mock_post, mock_get): + """requests.get in the polling loop should be called with a timeout.""" + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = { + "status": "READY", + "appMeta": {"url": "https://compute.example.com/sql/org/session-id"}, + } + mock_get.return_value = get_resp + + # Patch connect_direct to avoid actual WebSocket connection + with patch("wherobots.db.driver.connect_direct") as mock_cd: + mock_cd.return_value = MagicMock() + connect(api_key="test-key") + + _, kwargs = mock_get.call_args + assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT + + @patch("wherobots.db.driver.requests.post") + def test_connect_without_cancel_event(self, mock_post): + """connect() without cancel_event should work as before (backward compat).""" + post_resp = MagicMock() + post_resp.status_code = 401 + post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp) + post_resp.json.side_effect = requests.JSONDecodeError("", "", 0) + mock_post.return_value = post_resp + + with pytest.raises(InterfaceError): + connect(api_key="test-key") + + +class TestConnectDirectCancelEvent: + @patch("wherobots.db.driver.websockets.sync.client.connect") + def test_cancel_before_ws_connect(self, mock_ws): + cancel = threading.Event() + cancel.set() + + with pytest.raises(InterfaceError, match="cancelled by caller"): + connect_direct( + uri="wss://compute.example.com/sql/org/session-id", + cancel_event=cancel, + ) + + mock_ws.assert_not_called() diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index b18ff84..a85c313 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -11,6 +11,7 @@ import platform import requests import tenacity +import threading from typing import Final, Union, Dict import urllib.parse import websockets.exceptions @@ -51,6 +52,9 @@ # This follows the industry-standard set used by urllib3.util.Retry's status_forcelist. TRANSIENT_HTTP_STATUS_CODES = {429, 502, 503, 504} +# Default timeout for individual HTTP requests (connect + read), in seconds. +DEFAULT_HTTP_TIMEOUT = 30 + def gen_user_agent_header(): try: @@ -79,6 +83,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + cancel_event: Union[threading.Event, None] = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -109,6 +114,8 @@ def connect( if not host.startswith("http:"): host = f"https://{host}" + _check_cancelled(cancel_event) + try: resp = requests.post( url=f"{host}/sql/session", @@ -120,6 +127,7 @@ def connect( "sessionType": session_type.value, }, headers=headers, + timeout=DEFAULT_HTTP_TIMEOUT, ) resp.raise_for_status() except requests.HTTPError as e: @@ -149,10 +157,12 @@ def connect( ) | tenacity.retry_if_exception_type(tenacity.TryAgain) ), + before_sleep=lambda _: _check_cancelled(cancel_event), reraise=True, ) def get_session_uri() -> str: - r = requests.get(session_id_url, headers=headers) + _check_cancelled(cancel_event) + r = requests.get(session_id_url, headers=headers, timeout=DEFAULT_HTTP_TIMEOUT) r.raise_for_status() payload = r.json() status = AppStatus(payload.get("status")) @@ -169,6 +179,8 @@ def get_session_uri() -> str: logging.info("Getting SQL session status from %s ...", session_id_url) session_uri = get_session_uri() logging.debug("SQL session URI from app status: %s", session_uri) + except InterfaceError: + raise except Exception as e: raise InterfaceError("Could not acquire SQL session!", e) @@ -179,9 +191,16 @@ def get_session_uri() -> str: results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + cancel_event=cancel_event, ) +def _check_cancelled(cancel_event: Union[threading.Event, None]) -> None: + """Raise InterfaceError if the cancel event is set.""" + if cancel_event is not None and cancel_event.is_set(): + raise InterfaceError("Connection cancelled by caller") + + def http_to_ws(uri: str) -> str: """Converts an HTTP URI to a WebSocket URI.""" parsed = urllib.parse.urlparse(uri) @@ -199,6 +218,7 @@ def connect_direct( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + cancel_event: Union[threading.Event, None] = None, ) -> Connection: uri_with_protocol = f"{uri}/{protocol}" ssl_context = ssl.create_default_context() @@ -215,19 +235,24 @@ def connect_direct( websockets.exceptions.InvalidHandshake, ) ), + before_sleep=lambda _: _check_cancelled(cancel_event), reraise=True, ) def ws_connect() -> websockets.sync.client.ClientConnection: + _check_cancelled(cancel_event) logging.info("Connecting to SQL session at %s ...", uri_with_protocol) return websockets.sync.client.connect( uri=uri_with_protocol, additional_headers=headers, max_size=MAX_MESSAGE_SIZE, + open_timeout=DEFAULT_HTTP_TIMEOUT, ssl=ssl_context, ) try: ws = ws_connect() + except InterfaceError: + raise except Exception as e: raise InterfaceError("Failed to connect to SQL session!") from e