Skip to content

Commit e5b431c

Browse files
dhruv0811 and claude authored
Add autoscaling integration tests for lakebase (#374)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a749d4f commit e5b431c

2 files changed

Lines changed: 377 additions & 15 deletions

File tree

integrations/langchain/tests/integration_tests/test_langchain_lakebase.py

Lines changed: 222 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
Integration tests for LangChain Lakebase wrappers (DatabricksStore, CheckpointSaver).
33
44
These tests require:
5-
1. A Lakebase instance to be available
5+
1. A Lakebase instance to be available (provisioned or autoscaling)
66
2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_CLIENT_ID/SECRET or profile)
77
8-
Set the environment variable:
9-
LAKEBASE_INSTANCE_NAME: Name of the Lakebase instance
8+
Set at least one of these environment variables:
9+
LAKEBASE_INSTANCE_NAME: Name of the Lakebase provisioned instance
10+
LAKEBASE_PROJECT + LAKEBASE_BRANCH: Autoscaling project and branch names
11+
LAKEBASE_AUTOSCALING_ENDPOINT: Full autoscaling endpoint resource path
1012
11-
Example:
13+
Example (provisioned):
1214
LAKEBASE_INSTANCE_NAME=my-lakebase pytest tests/integration_tests/test_langchain_lakebase.py -v
15+
16+
Example (autoscaling):
17+
LAKEBASE_PROJECT=my-project LAKEBASE_BRANCH=main \
18+
pytest tests/integration_tests/test_langchain_lakebase.py -v
1319
"""
1420

1521
from __future__ import annotations
@@ -32,10 +38,28 @@
3238
DatabricksStore,
3339
)
3440

35-
# Skip all tests if LAKEBASE_INSTANCE_NAME is not set
41+
# Skip all tests if no Lakebase env vars are set
3642
pytestmark = pytest.mark.skipif(
43+
not os.environ.get("LAKEBASE_INSTANCE_NAME")
44+
and not os.environ.get("LAKEBASE_PROJECT")
45+
and not os.environ.get("LAKEBASE_AUTOSCALING_ENDPOINT"),
46+
reason="No Lakebase env vars set "
47+
"(need LAKEBASE_INSTANCE_NAME, LAKEBASE_PROJECT, or LAKEBASE_AUTOSCALING_ENDPOINT)",
48+
)
49+
50+
_skip_no_instance = pytest.mark.skipif(
3751
not os.environ.get("LAKEBASE_INSTANCE_NAME"),
38-
reason="LAKEBASE_INSTANCE_NAME environment variable not set",
52+
reason="LAKEBASE_INSTANCE_NAME not set",
53+
)
54+
55+
_skip_no_project_branch = pytest.mark.skipif(
56+
not os.environ.get("LAKEBASE_PROJECT") or not os.environ.get("LAKEBASE_BRANCH"),
57+
reason="LAKEBASE_PROJECT and LAKEBASE_BRANCH not set",
58+
)
59+
60+
_skip_no_endpoint = pytest.mark.skipif(
61+
not os.environ.get("LAKEBASE_AUTOSCALING_ENDPOINT"),
62+
reason="LAKEBASE_AUTOSCALING_ENDPOINT not set",
3963
)
4064

4165

@@ -44,6 +68,18 @@ def get_instance_name() -> str:
4468
return os.environ["LAKEBASE_INSTANCE_NAME"]
4569

4670

71+
def get_project() -> str:
72+
return os.environ["LAKEBASE_PROJECT"]
73+
74+
75+
def get_branch() -> str:
76+
return os.environ["LAKEBASE_BRANCH"]
77+
78+
79+
def get_autoscaling_endpoint() -> str:
80+
return os.environ["LAKEBASE_AUTOSCALING_ENDPOINT"]
81+
82+
4783
# =============================================================================
4884
# Tables managed by LangGraph that must be cleaned up between test runs.
4985
# Includes both data tables and migration-tracking tables; PostgresStore's
@@ -58,11 +94,14 @@ def get_instance_name() -> str:
5894
"checkpoint_writes",
5995
"checkpoints",
6096
]
97+
ALL_TABLES = STORE_TABLES + CHECKPOINT_TABLES
6198

6299

63-
def _drop_tables(tables: list[str]) -> None:
64-
"""Drop the given tables from the Lakebase instance."""
65-
with LakebaseClient(instance_name=get_instance_name()) as client:
100+
def _drop_tables(tables: list[str], **client_kwargs) -> None:
101+
"""Drop the given tables from a Lakebase instance."""
102+
if not client_kwargs:
103+
client_kwargs = {"instance_name": get_instance_name()}
104+
with LakebaseClient(**client_kwargs) as client:
66105
for table in tables:
67106
client.execute(f"DROP TABLE IF EXISTS {table} CASCADE")
68107

@@ -80,7 +119,7 @@ def unique_namespace() -> tuple[str, str]:
80119

81120
@pytest.fixture(scope="module")
82121
def cleanup_store_tables():
83-
"""Drop store tables before and after all store tests.
122+
"""Drop store tables before and after all provisioned store tests.
84123
85124
scope="module" means tables are dropped once at the start of the module and
86125
once at the end — NOT before/after each individual test. This keeps tests
@@ -95,7 +134,7 @@ def cleanup_store_tables():
95134

96135
@pytest.fixture(scope="module")
97136
def cleanup_checkpoint_tables():
98-
"""Drop checkpoint tables before and after all checkpoint tests.
137+
"""Drop checkpoint tables before and after all provisioned checkpoint tests.
99138
100139
scope="module" means tables are dropped once at the start of the module and
101140
once at the end — NOT before/after each individual test.
@@ -105,6 +144,30 @@ def cleanup_checkpoint_tables():
105144
_drop_tables(CHECKPOINT_TABLES)
106145

107146

147+
@pytest.fixture(scope="module")
148+
def cleanup_all_tables_project_branch():
149+
"""Drop all LangGraph tables on the project/branch autoscaling database."""
150+
if not os.environ.get("LAKEBASE_PROJECT") or not os.environ.get("LAKEBASE_BRANCH"):
151+
yield
152+
return
153+
kwargs = {"project": get_project(), "branch": get_branch()}
154+
_drop_tables(ALL_TABLES, **kwargs)
155+
yield
156+
_drop_tables(ALL_TABLES, **kwargs)
157+
158+
159+
@pytest.fixture(scope="module")
160+
def cleanup_all_tables_endpoint():
161+
"""Drop all LangGraph tables on the endpoint autoscaling database."""
162+
if not os.environ.get("LAKEBASE_AUTOSCALING_ENDPOINT"):
163+
yield
164+
return
165+
kwargs = {"autoscaling_endpoint": get_autoscaling_endpoint()}
166+
_drop_tables(ALL_TABLES, **kwargs)
167+
yield
168+
_drop_tables(ALL_TABLES, **kwargs)
169+
170+
108171
def _make_checkpoint(ts: str = "2025-01-01T00:00:00+00:00") -> Checkpoint:
109172
"""Build a Checkpoint with a random ID and the given timestamp."""
110173
return Checkpoint(
@@ -119,10 +182,11 @@ def _make_checkpoint(ts: str = "2025-01-01T00:00:00+00:00") -> Checkpoint:
119182

120183

121184
# =============================================================================
122-
# DatabricksStore (Sync) Tests
185+
# DatabricksStore (Sync) Tests — Provisioned
123186
# =============================================================================
124187

125188

189+
@_skip_no_instance
126190
class TestDatabricksStore:
127191
"""Test synchronous DatabricksStore against a live Lakebase instance."""
128192

@@ -186,10 +250,11 @@ def test_store_vector_search(self, unique_namespace, cleanup_store_tables):
186250

187251

188252
# =============================================================================
189-
# AsyncDatabricksStore Tests
253+
# AsyncDatabricksStore Tests — Provisioned
190254
# =============================================================================
191255

192256

257+
@_skip_no_instance
193258
class TestAsyncDatabricksStore:
194259
"""Test asynchronous AsyncDatabricksStore against a live Lakebase instance."""
195260

@@ -259,10 +324,11 @@ async def test_async_store_vector_search(self, unique_namespace, cleanup_store_t
259324

260325

261326
# =============================================================================
262-
# CheckpointSaver (Sync) Tests
327+
# CheckpointSaver (Sync) Tests — Provisioned
263328
# =============================================================================
264329

265330

331+
@_skip_no_instance
266332
class TestCheckpointSaver:
267333
"""Test synchronous CheckpointSaver against a live Lakebase instance."""
268334

@@ -304,10 +370,11 @@ def test_checkpoint_list(self, cleanup_checkpoint_tables):
304370

305371

306372
# =============================================================================
307-
# AsyncCheckpointSaver Tests
373+
# AsyncCheckpointSaver Tests — Provisioned
308374
# =============================================================================
309375

310376

377+
@_skip_no_instance
311378
class TestAsyncCheckpointSaver:
312379
"""Test asynchronous AsyncCheckpointSaver against a live Lakebase instance."""
313380

@@ -348,3 +415,143 @@ async def test_async_checkpoint_list(self, cleanup_checkpoint_tables):
348415

349416
checkpoints = [c async for c in saver.alist(config)]
350417
assert len(checkpoints) == 3
418+
419+
420+
# =============================================================================
421+
# Autoscaling — Project/Branch
422+
# =============================================================================
423+
424+
425+
@_skip_no_project_branch
426+
class TestAutoscalingProjectBranch:
427+
"""Test all LangChain Lakebase wrappers with autoscaling project/branch mode."""
428+
429+
def test_store_put_and_get(self, unique_namespace, cleanup_all_tables_project_branch):
430+
"""Test DatabricksStore: autoscaling params forwarded to LakebasePool."""
431+
store = DatabricksStore(project=get_project(), branch=get_branch())
432+
store.setup()
433+
434+
ns = unique_namespace
435+
store.put(ns, "key1", {"data": "autoscaling hello"})
436+
437+
item = store.get(ns, "key1")
438+
assert item is not None
439+
assert item.value == {"data": "autoscaling hello"}
440+
assert item.key == "key1"
441+
442+
@pytest.mark.asyncio
443+
async def test_async_store_put_and_get(
444+
self, unique_namespace, cleanup_all_tables_project_branch
445+
):
446+
"""Test AsyncDatabricksStore: async pool open/close + put + get."""
447+
async with AsyncDatabricksStore(project=get_project(), branch=get_branch()) as store:
448+
await store.setup()
449+
450+
ns = unique_namespace
451+
await store.aput(ns, "async_key", {"data": "async autoscaling"})
452+
453+
item = await store.aget(ns, "async_key")
454+
assert item is not None
455+
assert item.value == {"data": "async autoscaling"}
456+
457+
assert store._lakebase.pool.closed
458+
459+
def test_checkpoint_write_and_read(self, cleanup_all_tables_project_branch):
460+
"""Test CheckpointSaver: context manager auto-setup + put + get_tuple."""
461+
thread_id = uuid.uuid4().hex
462+
463+
with CheckpointSaver(project=get_project(), branch=get_branch()) as saver:
464+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
465+
checkpoint = _make_checkpoint()
466+
saver.put(config, checkpoint, CheckpointMetadata(), {})
467+
468+
result = saver.get_tuple(config)
469+
assert result is not None
470+
assert result.checkpoint["id"] == checkpoint["id"]
471+
472+
assert saver._lakebase.pool.closed
473+
474+
@pytest.mark.asyncio
475+
async def test_async_checkpoint_write_and_read(self, cleanup_all_tables_project_branch):
476+
"""Test AsyncCheckpointSaver: async context manager auto-setup + put + get_tuple."""
477+
thread_id = uuid.uuid4().hex
478+
479+
async with AsyncCheckpointSaver(project=get_project(), branch=get_branch()) as saver:
480+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
481+
checkpoint = _make_checkpoint()
482+
await saver.aput(config, checkpoint, CheckpointMetadata(), {})
483+
484+
result = await saver.aget_tuple(config)
485+
assert result is not None
486+
assert result.checkpoint["id"] == checkpoint["id"]
487+
488+
assert saver._lakebase.pool.closed
489+
490+
491+
# =============================================================================
492+
# Autoscaling — Endpoint
493+
# =============================================================================
494+
495+
496+
@_skip_no_endpoint
497+
class TestAutoscalingEndpoint:
498+
"""Test all LangChain Lakebase wrappers with autoscaling endpoint mode."""
499+
500+
def test_store_put_and_get(self, unique_namespace, cleanup_all_tables_endpoint):
501+
"""Test DatabricksStore: endpoint params forwarded to LakebasePool."""
502+
store = DatabricksStore(autoscaling_endpoint=get_autoscaling_endpoint())
503+
store.setup()
504+
505+
ns = unique_namespace
506+
store.put(ns, "key1", {"data": "endpoint hello"})
507+
508+
item = store.get(ns, "key1")
509+
assert item is not None
510+
assert item.value == {"data": "endpoint hello"}
511+
assert item.key == "key1"
512+
513+
@pytest.mark.asyncio
514+
async def test_async_store_put_and_get(self, unique_namespace, cleanup_all_tables_endpoint):
515+
"""Test AsyncDatabricksStore: async endpoint pool open/close + put + get."""
516+
async with AsyncDatabricksStore(autoscaling_endpoint=get_autoscaling_endpoint()) as store:
517+
await store.setup()
518+
519+
ns = unique_namespace
520+
await store.aput(ns, "async_key", {"data": "async endpoint"})
521+
522+
item = await store.aget(ns, "async_key")
523+
assert item is not None
524+
assert item.value == {"data": "async endpoint"}
525+
526+
assert store._lakebase.pool.closed
527+
528+
def test_checkpoint_write_and_read(self, cleanup_all_tables_endpoint):
529+
"""Test CheckpointSaver: endpoint context manager auto-setup + put + get_tuple."""
530+
thread_id = uuid.uuid4().hex
531+
532+
with CheckpointSaver(autoscaling_endpoint=get_autoscaling_endpoint()) as saver:
533+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
534+
checkpoint = _make_checkpoint()
535+
saver.put(config, checkpoint, CheckpointMetadata(), {})
536+
537+
result = saver.get_tuple(config)
538+
assert result is not None
539+
assert result.checkpoint["id"] == checkpoint["id"]
540+
541+
assert saver._lakebase.pool.closed
542+
543+
@pytest.mark.asyncio
544+
async def test_async_checkpoint_write_and_read(self, cleanup_all_tables_endpoint):
545+
"""Test AsyncCheckpointSaver: async endpoint pool lifecycle."""
546+
thread_id = uuid.uuid4().hex
547+
548+
async with AsyncCheckpointSaver(autoscaling_endpoint=get_autoscaling_endpoint()) as saver:
549+
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
550+
checkpoint = _make_checkpoint()
551+
await saver.aput(config, checkpoint, CheckpointMetadata(), {})
552+
553+
result = await saver.aget_tuple(config)
554+
assert result is not None
555+
assert result.checkpoint["id"] == checkpoint["id"]
556+
557+
assert saver._lakebase.pool.closed

0 commit comments

Comments (0)