Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ jobs:
cd tests
python -c "import wherobots.db"

- name: Run tests
run: pytest tests/

- name: Check build
run: |
uv build
Expand Down
44 changes: 44 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,50 @@
This project uses `uv`. Run `uv sync` after checking out the repository
to initialize your virtualenv with the project's dependencies.

## Running tests

Unit tests live in `tests/` and run with pytest:

```bash
uv run pytest tests/
```

The `scripts/` directory contains integration scripts that require a live
Wherobots environment and are not part of the automated test suite.

### Smoke test

`scripts/smoke.py` runs queries against a live Wherobots SQL session.
It requires an API key (or token) and supports most `connect()` options
via CLI flags.

```bash
# Basic query with an API key
uv run python scripts/smoke.py \
--api-key-file ~/.wherobots/api-key \
"SELECT 1"

# Specify runtime, region, and version
uv run python scripts/smoke.py \
--api-key-file ~/.wherobots/api-key \
--runtime tiny --region aws-us-west-2 --version latest \
"SELECT ST_AsText(ST_Point(1, 2))"

# Connect directly to an existing session via WebSocket URL
uv run python scripts/smoke.py \
--api-key-file ~/.wherobots/api-key \
--ws-url wss://compute.example.com/sql/org/session-id \
"SHOW TABLES"

# Enable debug logging and execution progress
uv run python scripts/smoke.py \
--api-key-file ~/.wherobots/api-key \
--debug --progress \
"SELECT * FROM wherobots_open_data.overture.places LIMIT 10"
```

Run `uv run python scripts/smoke.py --help` for all available options.

## Publish package to PyPI

When we are ready to release a new version `vx.y.z`, one of the maintainers should:
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,9 @@ users may find useful:
your expected time between queries and effectively get a continuously
running SQL session runtime without any complex connection management
in your application.
* `cancel_event`: a `threading.Event` that, when set, causes the
connection attempt to abort promptly with an `InterfaceError`. This
is useful when `connect()` is running in a background thread and the
caller needs to interrupt it (e.g. on client disconnect or timeout).
The event is checked before each HTTP request, between retry
attempts, and before the WebSocket handshake.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "wherobots-python-dbapi"
version = "0.26.1"
version = "0.27.0"
description = "Python DB-API driver for Wherobots DB"
authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }]
requires-python = ">=3.10, <4"
Expand Down
File renamed without changes.
146 changes: 146 additions & 0 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Tests for the connect() and connect_direct() driver functions."""

import threading
import time
from unittest.mock import MagicMock, patch

import pytest
import requests

from wherobots.db.driver import (
DEFAULT_HTTP_TIMEOUT,
_check_cancelled,
connect,
connect_direct,
)
from wherobots.db.errors import InterfaceError


class TestCheckCancelled:
def test_none_event_is_noop(self):
_check_cancelled(None)

def test_unset_event_is_noop(self):
event = threading.Event()
_check_cancelled(event)

def test_set_event_raises(self):
event = threading.Event()
event.set()
with pytest.raises(InterfaceError, match="cancelled by caller"):
_check_cancelled(event)


class TestConnectCancelEvent:
@patch("wherobots.db.driver.requests.post")
def test_cancel_before_post(self, mock_post):
"""cancel_event set before connect() should raise immediately without making HTTP calls."""
cancel = threading.Event()
cancel.set()

with pytest.raises(InterfaceError, match="cancelled by caller"):
connect(api_key="test-key", cancel_event=cancel)

mock_post.assert_not_called()

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_cancel_during_polling(self, mock_post, mock_get):
"""cancel_event set during session polling should abort the retry loop."""
# POST succeeds with redirect
post_resp = MagicMock()
post_resp.status_code = 200
post_resp.url = "https://api.example.com/sql/session/test-id"
post_resp.raise_for_status = MagicMock()
mock_post.return_value = post_resp

# GET returns INITIALIZING (triggers TryAgain)
get_resp = MagicMock()
get_resp.status_code = 200
get_resp.raise_for_status = MagicMock()
get_resp.json.return_value = {"status": "INITIALIZING"}
mock_get.return_value = get_resp

cancel = threading.Event()

# Set cancel after a short delay (during polling)
def set_cancel():
time.sleep(0.1)
cancel.set()

t = threading.Thread(target=set_cancel)
t.start()

with pytest.raises(InterfaceError, match="cancelled by caller"):
connect(api_key="test-key", cancel_event=cancel, wait_timeout=10)

t.join()

@patch("wherobots.db.driver.requests.post")
def test_http_timeout_on_post(self, mock_post):
"""requests.post should be called with a timeout."""
post_resp = MagicMock()
post_resp.status_code = 401
post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp)
post_resp.json.side_effect = requests.JSONDecodeError("", "", 0)
mock_post.return_value = post_resp

with pytest.raises(InterfaceError, match="Failed to create SQL session"):
connect(api_key="test-key")

_, kwargs = mock_post.call_args
assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_http_timeout_on_get(self, mock_post, mock_get):
"""requests.get in the polling loop should be called with a timeout."""
post_resp = MagicMock()
post_resp.status_code = 200
post_resp.url = "https://api.example.com/sql/session/test-id"
post_resp.raise_for_status = MagicMock()
mock_post.return_value = post_resp

get_resp = MagicMock()
get_resp.status_code = 200
get_resp.raise_for_status = MagicMock()
get_resp.json.return_value = {
"status": "READY",
"appMeta": {"url": "https://compute.example.com/sql/org/session-id"},
}
mock_get.return_value = get_resp

# Patch connect_direct to avoid actual WebSocket connection
with patch("wherobots.db.driver.connect_direct") as mock_cd:
mock_cd.return_value = MagicMock()
connect(api_key="test-key")

_, kwargs = mock_get.call_args
assert kwargs["timeout"] == DEFAULT_HTTP_TIMEOUT

@patch("wherobots.db.driver.requests.post")
def test_connect_without_cancel_event(self, mock_post):
"""connect() without cancel_event should work as before (backward compat)."""
post_resp = MagicMock()
post_resp.status_code = 401
post_resp.raise_for_status.side_effect = requests.HTTPError(response=post_resp)
post_resp.json.side_effect = requests.JSONDecodeError("", "", 0)
mock_post.return_value = post_resp

with pytest.raises(InterfaceError):
connect(api_key="test-key")


class TestConnectDirectCancelEvent:
@patch("wherobots.db.driver.websockets.sync.client.connect")
def test_cancel_before_ws_connect(self, mock_ws):
cancel = threading.Event()
cancel.set()

with pytest.raises(InterfaceError, match="cancelled by caller"):
connect_direct(
uri="wss://compute.example.com/sql/org/session-id",
cancel_event=cancel,
)

mock_ws.assert_not_called()
27 changes: 26 additions & 1 deletion wherobots/db/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import platform
import requests
import tenacity
import threading
from typing import Final, Union, Dict
import urllib.parse
import websockets.exceptions
Expand Down Expand Up @@ -51,6 +52,9 @@
# This follows the industry-standard set used by urllib3.util.Retry's status_forcelist.
TRANSIENT_HTTP_STATUS_CODES = {429, 502, 503, 504}

# Default timeout for individual HTTP requests (connect + read), in seconds.
DEFAULT_HTTP_TIMEOUT = 30


def gen_user_agent_header():
try:
Expand Down Expand Up @@ -79,6 +83,7 @@ def connect(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
cancel_event: Union[threading.Event, None] = None,
) -> Connection:
if not token and not api_key:
raise ValueError("At least one of `token` or `api_key` is required")
Expand Down Expand Up @@ -109,6 +114,8 @@ def connect(
if not host.startswith("http:"):
host = f"https://{host}"

_check_cancelled(cancel_event)

try:
resp = requests.post(
url=f"{host}/sql/session",
Expand All @@ -120,6 +127,7 @@ def connect(
"sessionType": session_type.value,
},
headers=headers,
timeout=DEFAULT_HTTP_TIMEOUT,
)
resp.raise_for_status()
except requests.HTTPError as e:
Expand Down Expand Up @@ -149,10 +157,12 @@ def connect(
)
| tenacity.retry_if_exception_type(tenacity.TryAgain)
),
before_sleep=lambda _: _check_cancelled(cancel_event),
reraise=True,
)
def get_session_uri() -> str:
r = requests.get(session_id_url, headers=headers)
_check_cancelled(cancel_event)
r = requests.get(session_id_url, headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
r.raise_for_status()
payload = r.json()
status = AppStatus(payload.get("status"))
Expand All @@ -169,6 +179,8 @@ def get_session_uri() -> str:
logging.info("Getting SQL session status from %s ...", session_id_url)
session_uri = get_session_uri()
logging.debug("SQL session URI from app status: %s", session_uri)
except InterfaceError:
raise
except Exception as e:
raise InterfaceError("Could not acquire SQL session!", e)

Expand All @@ -179,9 +191,16 @@ def get_session_uri() -> str:
results_format=results_format,
data_compression=data_compression,
geometry_representation=geometry_representation,
cancel_event=cancel_event,
)


def _check_cancelled(cancel_event: Union[threading.Event, None]) -> None:
"""Raise InterfaceError if the cancel event is set."""
if cancel_event is not None and cancel_event.is_set():
raise InterfaceError("Connection cancelled by caller")


def http_to_ws(uri: str) -> str:
"""Converts an HTTP URI to a WebSocket URI."""
parsed = urllib.parse.urlparse(uri)
Expand All @@ -199,6 +218,7 @@ def connect_direct(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
cancel_event: Union[threading.Event, None] = None,
) -> Connection:
uri_with_protocol = f"{uri}/{protocol}"
ssl_context = ssl.create_default_context()
Expand All @@ -215,19 +235,24 @@ def connect_direct(
websockets.exceptions.InvalidHandshake,
)
),
before_sleep=lambda _: _check_cancelled(cancel_event),
reraise=True,
)
def ws_connect() -> websockets.sync.client.ClientConnection:
_check_cancelled(cancel_event)
logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
return websockets.sync.client.connect(
uri=uri_with_protocol,
additional_headers=headers,
max_size=MAX_MESSAGE_SIZE,
open_timeout=DEFAULT_HTTP_TIMEOUT,
ssl=ssl_context,
)

try:
ws = ws_connect()
except InterfaceError:
raise
except Exception as e:
raise InterfaceError("Failed to connect to SQL session!") from e

Expand Down
Loading