Skip to content

Commit 9ddaef8

Browse files
authored
Add Async Stateful Classes for Langgraph memory: AsyncLakebasePool, AsyncCheckpointSaver, AsyncDatabricksStore (#255)
This PR adds async versions of the Lakebase connection pool (AsyncLakebasePool), the LangGraph checkpoint saver (AsyncCheckpointSaver), and the LangGraph PostgresStore (AsyncDatabricksStore) so they can be used in async workflows such as async agent servers and streaming applications. Changes: **Core Library (databricks-ai-bridge)** — src/databricks_ai_bridge/lakebase.py: added an AsyncLakebasePool class that wraps [psycopg_pool.AsyncConnectionPool](https://www.psycopg.org/psycopg3/docs/api/pool.html#psycopg_pool.AsyncConnectionPool). **LangChain (databricks-langchain)** — integrations/langchain/src/databricks_langchain/checkpoint.py: added an AsyncCheckpointSaver class extending [langgraph.checkpoint.postgres.aio.AsyncPostgresSaver](https://pypi.org/project/langgraph-checkpoint-postgres/). integrations/langchain/src/databricks_langchain/store.py: added an AsyncDatabricksStore class extending [AsyncBatchedBaseStore](https://reference.langchain.com/python/langgraph/store/#langgraph.store.postgres.AsyncPostgresStore); it uses AsyncLakebasePool for connection pooling and creates short-lived AsyncPostgresStore instances for each operation. Added unit tests for the new classes. Used in agent-on-apps LangGraph examples — short-term: bbqiu/agent-on-app-proto#39; long-term: bbqiu/agent-on-app-proto#42.
1 parent e373644 commit 9ddaef8

7 files changed

Lines changed: 960 additions & 108 deletions

File tree

integrations/langchain/src/databricks_langchain/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,22 @@
1818
from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit, UnityCatalogTool
1919

2020
from databricks_langchain.chat_models import ChatDatabricks
21-
from databricks_langchain.checkpoint import CheckpointSaver
21+
from databricks_langchain.checkpoint import AsyncCheckpointSaver, CheckpointSaver
2222
from databricks_langchain.embeddings import DatabricksEmbeddings
2323
from databricks_langchain.genie import GenieAgent
2424
from databricks_langchain.multi_server_mcp_client import (
2525
DatabricksMCPServer,
2626
DatabricksMultiServerMCPClient,
2727
MCPServer,
2828
)
29-
from databricks_langchain.store import DatabricksStore
29+
from databricks_langchain.store import AsyncDatabricksStore, DatabricksStore
3030
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
3131
from databricks_langchain.vectorstores import DatabricksVectorSearch
3232

3333
# Expose all integrations to users under databricks-langchain
3434
__all__ = [
35+
"AsyncCheckpointSaver",
36+
"AsyncDatabricksStore",
3537
"ChatDatabricks",
3638
"CheckpointSaver",
3739
"DatabricksEmbeddings",

integrations/langchain/src/databricks_langchain/checkpoint.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from databricks.sdk import WorkspaceClient
44

55
try:
6-
from databricks_ai_bridge.lakebase import LakebasePool
6+
from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
77
from langgraph.checkpoint.postgres import PostgresSaver
8+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
89

910
_checkpoint_imports_available = True
1011
except ImportError:
1112
PostgresSaver = object
13+
AsyncPostgresSaver = object
1214
_checkpoint_imports_available = False
1315

1416

@@ -48,3 +50,42 @@ def __exit__(self, exc_type, exc_val, exc_tb):
4850
"""Exit context manager and close the connection pool."""
4951
self._lakebase.close()
5052
return False
53+
54+
55+
class AsyncCheckpointSaver(AsyncPostgresSaver):
    """Async LangGraph PostgresSaver backed by a Lakebase connection pool.

    Args:
        instance_name: Name of the Lakebase instance to connect to.
        workspace_client: Optional Databricks WorkspaceClient used for
            authentication; a default client is resolved when omitted.
        **pool_kwargs: Additional keyword arguments forwarded to
            AsyncLakebasePool.
    """

    def __init__(
        self,
        *,
        instance_name: str,
        workspace_client: WorkspaceClient | None = None,
        **pool_kwargs: object,
    ) -> None:
        # Fail fast with an actionable message when the optional
        # [memory] extras (langgraph + lakebase deps) are not installed.
        if not _checkpoint_imports_available:
            raise ImportError(
                "AsyncCheckpointSaver requires databricks-langchain[memory]. "
                "Please install with: pip install databricks-langchain[memory]"
            )

        # pool_kwargs is already a plain dict; pass it through without the
        # redundant dict() copy (consistent with AsyncDatabricksStore).
        self._lakebase: AsyncLakebasePool = AsyncLakebasePool(
            instance_name=instance_name,
            workspace_client=workspace_client,
            **pool_kwargs,
        )
        # Hand the underlying psycopg async pool to the LangGraph saver.
        super().__init__(self._lakebase.pool)

    async def __aenter__(self):
        """Enter async context manager and open the connection pool."""
        await self._lakebase.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager and close the connection pool."""
        await self._lakebase.close()
        # Returning False propagates any exception raised inside the block.
        return False

integrations/langchain/src/databricks_langchain/store.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,27 @@
66
from databricks.sdk import WorkspaceClient
77

88
try:
9-
from databricks_ai_bridge.lakebase import LakebasePool
9+
from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
1010
from langgraph.store.base import BaseStore, Op, Result
11-
from langgraph.store.postgres import PostgresStore
11+
from langgraph.store.base.batch import AsyncBatchedBaseStore
12+
from langgraph.store.postgres import AsyncPostgresStore, PostgresStore
13+
14+
from databricks_langchain.embeddings import DatabricksEmbeddings
1215

1316
_store_imports_available = True
1417
except ImportError:
1518
LakebasePool = object
19+
AsyncLakebasePool = object
1620
PostgresStore = object
21+
AsyncPostgresStore = object
1722
BaseStore = object
23+
AsyncBatchedBaseStore = object
1824
Item = object
1925
Op = object
2026
Result = object
27+
DatabricksEmbeddings = object
2128
_store_imports_available = False
2229

23-
from databricks_langchain.embeddings import DatabricksEmbeddings
24-
2530

2631
class DatabricksStore(BaseStore):
2732
"""Provides APIs for working with long-term memory on Databricks using Lakebase.
@@ -135,3 +140,121 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
135140
would need async-compatible connection pooling.
136141
"""
137142
return self.batch(ops)
143+
144+
145+
class AsyncDatabricksStore(AsyncBatchedBaseStore):
    """Async version of DatabricksStore for working with long-term memory on Databricks.

    Extends LangGraph AsyncBatchedBaseStore interface using Databricks Lakebase
    for async connection pooling, with semantic search capabilities via DatabricksEmbeddings.

    Operations borrow a connection from the async pool, create a short-lived AsyncPostgresStore,
    execute the operation, and return the connection to the pool.
    """

    def __init__(
        self,
        *,
        instance_name: str,
        workspace_client: Optional[WorkspaceClient] = None,
        embedding_endpoint: Optional[str] = None,
        embedding_dims: Optional[int] = None,
        embedding_fields: Optional[List[str]] = None,
        embeddings: Optional[DatabricksEmbeddings] = None,
        **pool_kwargs: Any,
    ) -> None:
        """Initialize AsyncDatabricksStore with embedding support.

        Args:
            instance_name: The name of the Lakebase instance to connect to.
            workspace_client: Optional Databricks WorkspaceClient for authentication.
            embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings
                (e.g., "databricks-gte-large-en"). If provided, enables semantic search.
            embedding_dims: Dimension of the embedding vectors (e.g., 1024 for gte-large-en,
                1536 for OpenAI-compatible models). Required if embedding_endpoint is set.
            embedding_fields: List of JSON paths to vectorize. Defaults to ["$"] which
                vectorizes the entire JSON value.
            embeddings: Optional pre-configured DatabricksEmbeddings instance. If provided,
                takes precedence over embedding_endpoint.
            **pool_kwargs: Additional keyword arguments passed to AsyncLakebasePool.

        Raises:
            ImportError: If the optional [memory] extras are not installed.
            ValueError: If semantic search is requested without embedding_dims.
        """
        if not _store_imports_available:
            raise ImportError(
                "AsyncDatabricksStore requires databricks-langchain[memory]. "
                "Install with: pip install 'databricks-langchain[memory]'"
            )

        super().__init__()

        self._lakebase: AsyncLakebasePool = AsyncLakebasePool(
            instance_name=instance_name,
            workspace_client=workspace_client,
            **pool_kwargs,
        )

        # Initialize embeddings and index configuration for semantic search
        self.embeddings: Optional[DatabricksEmbeddings] = None
        self.index_config: Optional[dict] = None

        def _make_index_config(dims: int) -> dict:
            # Single place that shapes the AsyncPostgresStore index config so
            # the two configuration paths below cannot drift apart.
            return {
                "dims": dims,
                "embed": self.embeddings,
                "fields": embedding_fields or ["$"],
            }

        if embeddings is not None:
            # Use pre-configured embeddings instance
            if embedding_endpoint is not None:
                warnings.warn(
                    "Both 'embeddings' and 'embedding_endpoint' were specified. "
                    "Using the provided 'embeddings' instance.",
                    UserWarning,
                    stacklevel=2,
                )
            self.embeddings = embeddings
            if embedding_dims is None:
                raise ValueError("embedding_dims is required when providing an embeddings instance")
            self.index_config = _make_index_config(embedding_dims)
        elif embedding_endpoint is not None:
            # Create embeddings from endpoint configuration
            if embedding_dims is None:
                raise ValueError("embedding_dims is required when embedding_endpoint is specified")
            self.embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)
            self.index_config = _make_index_config(embedding_dims)

    async def _with_store(self, fn, *args, **kwargs):
        """
        Borrow an async connection, create a short-lived AsyncPostgresStore with index config,
        call fn(store), then return the connection to the pool.
        """
        async with self._lakebase.connection() as conn:
            if self.index_config is not None:
                store = AsyncPostgresStore(conn=conn, index=self.index_config)
            else:
                store = AsyncPostgresStore(conn=conn)
            return await fn(store, *args, **kwargs)

    async def setup(self) -> None:
        """Instantiate the store, setting up necessary persistent storage."""
        return await self._with_store(lambda s: s.setup())

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute a batch of operations asynchronously.

        This is the core method required by AsyncBatchedBaseStore. All other async operations
        (aget, aput, asearch, adelete, alist_namespaces) are inherited from AsyncBatchedBaseStore
        and internally call this abatch() method.
        """
        return await self._with_store(lambda s: s.abatch(ops))

    async def __aenter__(self):
        """Enter async context manager and open the connection pool."""
        await self._lakebase.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager and close the connection pool."""
        await self._lakebase.close()
        # Returning False propagates any exception raised inside the block.
        return False

integrations/langchain/tests/unit_tests/test_checkpoint.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from databricks_ai_bridge import lakebase
1212

13-
from databricks_langchain import CheckpointSaver
13+
from databricks_langchain import AsyncCheckpointSaver, CheckpointSaver
1414

1515

1616
class TestConnectionPool:
@@ -66,3 +66,106 @@ def test_checkpoint_saver_configures_lakebase(monkeypatch):
6666

6767
with saver._lakebase.connection() as conn:
6868
assert conn == "lake-conn"
69+
70+
71+
class TestAsyncConnectionPool:
    """Mock async connection pool for testing."""

    def __init__(self, connection_value="async-conn"):
        # Value yielded by the async connection context manager.
        self.connection_value = connection_value
        self.conninfo = ""
        self._opened = False
        self._closed = False

    def __call__(self, *, conninfo, connection_class=None, **kwargs):
        # The code under test calls the patched AsyncConnectionPool name as a
        # constructor; recording conninfo and returning self lets this one
        # instance stand in for both the pool class and the pool object.
        self.conninfo = conninfo
        return self

    def connection(self):
        pool = self

        class _ConnectionContext:
            # Minimal async context manager yielding the configured value.
            async def __aenter__(self):
                return pool.connection_value

            async def __aexit__(self, exc_type, exc, tb):
                pass

        return _ConnectionContext()

    async def open(self):
        self._opened = True

    async def close(self):
        self._closed = True
108+
109+
110+
@pytest.mark.asyncio
async def test_async_checkpoint_saver_configures_lakebase(monkeypatch):
    """The saver should build a Lakebase conninfo and adopt the patched pool."""
    mock_pool = TestAsyncConnectionPool(connection_value="async-lake-conn")
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_pool)

    ws = MagicMock()
    ws.current_service_principal.me.side_effect = RuntimeError("no sp")
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")
    ws.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    ws.database.get_database_instance.return_value.read_write_dns = "db-host"

    saver = AsyncCheckpointSaver(instance_name="lakebase-instance", workspace_client=ws)

    expected_conninfo = (
        "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require"
    )
    assert mock_pool.conninfo == expected_conninfo
    assert saver._lakebase.pool == mock_pool
132+
133+
134+
@pytest.mark.asyncio
async def test_async_checkpoint_saver_context_manager(monkeypatch):
    """Entering the saver opens the pool; exiting closes it."""
    mock_pool = TestAsyncConnectionPool(connection_value="async-lake-conn")
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_pool)

    ws = MagicMock()
    ws.current_service_principal.me.side_effect = RuntimeError("no sp")
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")
    ws.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    ws.database.get_database_instance.return_value.read_write_dns = "db-host"

    saver_cm = AsyncCheckpointSaver(
        instance_name="lakebase-instance",
        workspace_client=ws,
    )
    async with saver_cm as saver:
        assert mock_pool._opened
        assert saver._lakebase.pool == mock_pool

    assert mock_pool._closed
153+
154+
155+
@pytest.mark.asyncio
async def test_async_checkpoint_saver_connection(monkeypatch):
    """Connections borrowed inside the context yield the mock pool's value."""
    mock_pool = TestAsyncConnectionPool(connection_value="async-lake-conn")
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_pool)

    ws = MagicMock()
    ws.current_service_principal.me.side_effect = RuntimeError("no sp")
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")
    ws.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    ws.database.get_database_instance.return_value.read_write_dns = "db-host"

    async with AsyncCheckpointSaver(
        instance_name="lakebase-instance",
        workspace_client=ws,
    ) as saver:
        async with saver._lakebase.connection() as conn:
            assert conn == "async-lake-conn"

0 commit comments

Comments
 (0)