Skip to content

Commit 05035c7

Browse files
authored
Consolidate OpenAI autoscaling tests into test_memory_session.py (#389)
1 parent e5b431c commit 05035c7

2 files changed

Lines changed: 147 additions & 175 deletions

File tree

integrations/openai/tests/integration_tests/test_memory_session.py

Lines changed: 147 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,24 @@
22
Integration tests for AsyncDatabricksSession.
33
44
These tests require:
5-
1. A Lakebase instance to be available
5+
1. A Lakebase instance to be available (provisioned or autoscaling)
66
2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_TOKEN or profile)
77
8-
Set the environment variable:
9-
LAKEBASE_INSTANCE_NAME: Name of the Lakebase instance
8+
Set at least one of these environment variables:
9+
LAKEBASE_INSTANCE_NAME: Name of the Lakebase provisioned instance
10+
LAKEBASE_PROJECT + LAKEBASE_BRANCH: Autoscaling project and branch names
11+
LAKEBASE_AUTOSCALING_ENDPOINT: Full autoscaling endpoint resource path
1012
11-
Example:
13+
Example (provisioned):
1214
LAKEBASE_INSTANCE_NAME=lakebase pytest tests/integration_tests/test_memory_session.py -v
15+
16+
Example (autoscaling — project/branch):
17+
LAKEBASE_PROJECT=my-project LAKEBASE_BRANCH=main \
18+
pytest tests/integration_tests/test_memory_session.py -v
19+
20+
Example (autoscaling — endpoint):
21+
LAKEBASE_AUTOSCALING_ENDPOINT=projects/my-project/branches/main/endpoints/primary \
22+
pytest tests/integration_tests/test_memory_session.py -v
1323
"""
1424

1525
from __future__ import annotations
@@ -20,10 +30,28 @@
2030

2131
import pytest
2232

23-
# Skip all tests if LAKEBASE_INSTANCE_NAME is not set
33+
# Skip all tests if no Lakebase env vars are set
2434
pytestmark = pytest.mark.skipif(
35+
not os.environ.get("LAKEBASE_INSTANCE_NAME")
36+
and not os.environ.get("LAKEBASE_PROJECT")
37+
and not os.environ.get("LAKEBASE_AUTOSCALING_ENDPOINT"),
38+
reason="No Lakebase env vars set "
39+
"(need LAKEBASE_INSTANCE_NAME, LAKEBASE_PROJECT, or LAKEBASE_AUTOSCALING_ENDPOINT)",
40+
)
41+
42+
_skip_no_instance = pytest.mark.skipif(
2543
not os.environ.get("LAKEBASE_INSTANCE_NAME"),
26-
reason="LAKEBASE_INSTANCE_NAME environment variable not set",
44+
reason="LAKEBASE_INSTANCE_NAME not set",
45+
)
46+
47+
_skip_no_project_branch = pytest.mark.skipif(
48+
not os.environ.get("LAKEBASE_PROJECT") or not os.environ.get("LAKEBASE_BRANCH"),
49+
reason="LAKEBASE_PROJECT and LAKEBASE_BRANCH not set",
50+
)
51+
52+
_skip_no_endpoint = pytest.mark.skipif(
53+
not os.environ.get("LAKEBASE_AUTOSCALING_ENDPOINT"),
54+
reason="LAKEBASE_AUTOSCALING_ENDPOINT not set",
2755
)
2856

2957

@@ -38,31 +66,32 @@ def get_unique_table_names() -> tuple[str, str]:
3866
return f"test_sessions_{suffix}", f"test_messages_{suffix}"
3967

4068

69+
def _drop_tables(tables_to_cleanup: list[tuple[str, str]], **client_kwargs) -> None:
70+
"""Drop test tables using LakebaseClient."""
71+
from databricks_ai_bridge.lakebase import LakebaseClient
72+
73+
with LakebaseClient(**client_kwargs) as client:
74+
for sessions_table, messages_table in tables_to_cleanup:
75+
# Drop messages first (foreign key constraint)
76+
client.execute(f"DROP TABLE IF EXISTS {messages_table}")
77+
client.execute(f"DROP TABLE IF EXISTS {sessions_table}")
78+
79+
4180
@pytest.fixture
4281
def cleanup_tables():
43-
"""Fixture to track and clean up test tables after tests."""
82+
"""Fixture to track and clean up test tables after provisioned tests."""
4483
tables_to_cleanup: list[tuple[str, str]] = []
45-
4684
yield tables_to_cleanup
47-
48-
# Cleanup after test
4985
if tables_to_cleanup:
50-
from databricks_ai_bridge.lakebase import LakebasePool
51-
52-
pool = LakebasePool(instance_name=get_instance_name())
53-
with pool.connection() as conn:
54-
for sessions_table, messages_table in tables_to_cleanup:
55-
# Drop messages first (foreign key constraint)
56-
conn.execute(f"DROP TABLE IF EXISTS {messages_table}")
57-
conn.execute(f"DROP TABLE IF EXISTS {sessions_table}")
58-
pool.close()
86+
_drop_tables(tables_to_cleanup, instance_name=get_instance_name())
5987

6088

6189
# =============================================================================
62-
# AsyncDatabricksSession Tests
90+
# AsyncDatabricksSession Tests — Provisioned
6391
# =============================================================================
6492

6593

94+
@_skip_no_instance
6695
@pytest.mark.asyncio
6796
async def test_memory_session_crud_operations(cleanup_tables):
6897
"""
@@ -133,6 +162,7 @@ async def test_memory_session_crud_operations(cleanup_tables):
133162
assert items == [], f"Expected empty after clear, got {items}"
134163

135164

165+
@_skip_no_instance
136166
@pytest.mark.asyncio
137167
async def test_memory_session_multiple_sessions_isolated(cleanup_tables):
138168
"""Test that different session_ids have isolated data."""
@@ -185,6 +215,7 @@ async def test_memory_session_multiple_sessions_isolated(cleanup_tables):
185215
await session_2.clear_session()
186216

187217

218+
@_skip_no_instance
188219
@pytest.mark.asyncio
189220
async def test_memory_session_pop_empty_returns_none(cleanup_tables):
190221
"""Test that pop_item returns None on empty session."""
@@ -205,6 +236,7 @@ async def test_memory_session_pop_empty_returns_none(cleanup_tables):
205236
assert popped is None
206237

207238

239+
@_skip_no_instance
208240
@pytest.mark.asyncio
209241
async def test_memory_session_add_empty_items_noop(cleanup_tables):
210242
"""Test that add_items with empty list is a no-op."""
@@ -228,6 +260,7 @@ async def test_memory_session_add_empty_items_noop(cleanup_tables):
228260
assert items == []
229261

230262

263+
@_skip_no_instance
231264
@pytest.mark.asyncio
232265
async def test_memory_session_complex_message_data(cleanup_tables):
233266
"""Test storing complex message data with nested structures."""
@@ -275,6 +308,7 @@ async def test_memory_session_complex_message_data(cleanup_tables):
275308
await session.clear_session()
276309

277310

311+
@_skip_no_instance
278312
@pytest.mark.asyncio
279313
async def test_memory_session_get_items_ordering(cleanup_tables):
280314
"""Test that get_items returns items in correct chronological order."""
@@ -310,3 +344,96 @@ async def test_memory_session_get_items_ordering(cleanup_tables):
310344

311345
# Cleanup
312346
await session.clear_session()
347+
348+
349+
# =============================================================================
350+
# AsyncDatabricksSession Tests — Autoscaling
351+
# =============================================================================
352+
353+
354+
async def _run_autoscaling_crud_test(conn_kwargs: dict, cleanup_tables: list):
355+
"""Test CRUD lifecycle for autoscaling: empty -> add -> get -> pop -> clear."""
356+
from databricks_openai.agents import AsyncDatabricksSession
357+
358+
sessions_table, messages_table = get_unique_table_names()
359+
cleanup_tables.append((sessions_table, messages_table))
360+
361+
session = AsyncDatabricksSession(
362+
session_id=str(uuid.uuid4()),
363+
sessions_table=sessions_table,
364+
messages_table=messages_table,
365+
**conn_kwargs,
366+
)
367+
368+
# Empty session
369+
items = cast(list[Any], await session.get_items())
370+
assert items == []
371+
372+
# Add and retrieve
373+
test_items: list[Any] = [
374+
{"role": "user", "content": "Hello from autoscaling"},
375+
{"role": "assistant", "content": "Autoscaling response"},
376+
]
377+
await session.add_items(test_items)
378+
379+
items = cast(list[Any], await session.get_items())
380+
assert len(items) == 2
381+
assert items[0]["content"] == "Hello from autoscaling"
382+
assert items[1]["content"] == "Autoscaling response"
383+
384+
# Pop last item
385+
popped = cast(Any, await session.pop_item())
386+
assert popped is not None
387+
assert popped["role"] == "assistant"
388+
389+
# Clear
390+
await session.clear_session()
391+
items = cast(list[Any], await session.get_items())
392+
assert items == []
393+
394+
395+
@pytest.fixture
396+
def cleanup_tables_project_branch():
397+
"""Track and clean up test tables on the project/branch autoscaling database."""
398+
tables_to_cleanup: list[tuple[str, str]] = []
399+
yield tables_to_cleanup
400+
if tables_to_cleanup:
401+
_drop_tables(
402+
tables_to_cleanup,
403+
project=os.environ["LAKEBASE_PROJECT"],
404+
branch=os.environ["LAKEBASE_BRANCH"],
405+
)
406+
407+
408+
@pytest.fixture
409+
def cleanup_tables_endpoint():
410+
"""Track and clean up test tables on the endpoint autoscaling database."""
411+
tables_to_cleanup: list[tuple[str, str]] = []
412+
yield tables_to_cleanup
413+
if tables_to_cleanup:
414+
_drop_tables(
415+
tables_to_cleanup,
416+
autoscaling_endpoint=os.environ["LAKEBASE_AUTOSCALING_ENDPOINT"],
417+
)
418+
419+
420+
class TestSessionAutoscaling:
421+
"""Test AsyncDatabricksSession with autoscaling modes (project/branch and endpoint)."""
422+
423+
@_skip_no_project_branch
424+
@pytest.mark.asyncio
425+
async def test_crud_project_branch(self, cleanup_tables_project_branch):
426+
"""Test autoscaling project/branch params forwarded to AsyncLakebaseSQLAlchemy."""
427+
await _run_autoscaling_crud_test(
428+
{"project": os.environ["LAKEBASE_PROJECT"], "branch": os.environ["LAKEBASE_BRANCH"]},
429+
cleanup_tables_project_branch,
430+
)
431+
432+
@_skip_no_endpoint
433+
@pytest.mark.asyncio
434+
async def test_crud_endpoint(self, cleanup_tables_endpoint):
435+
"""Test endpoint autoscaling params forwarded to AsyncLakebaseSQLAlchemy."""
436+
await _run_autoscaling_crud_test(
437+
{"autoscaling_endpoint": os.environ["LAKEBASE_AUTOSCALING_ENDPOINT"]},
438+
cleanup_tables_endpoint,
439+
)

integrations/openai/tests/integration_tests/test_memory_session_autoscaling.py

Lines changed: 0 additions & 155 deletions
This file was deleted.

0 commit comments

Comments
 (0)