Skip to content

Commit 9aabb44

Browse files
ppradosbaskaryanpiotrm0zc277584121eyurtsev
authored
community[minor]: Add SQL storage implementation (#22207)
Hello @eyurtsev - package: langchain-comminity - **Description**: Add SQL implementation for docstore. A new implementation, in line with my other PR ([async PGVector](langchain-ai/langchain-postgres#32), [SQLChatMessageMemory](#22065)) - Twitter handler: pprados --------- Signed-off-by: ChengZi <[email protected]> Co-authored-by: Bagatur <[email protected]> Co-authored-by: Piotr Mardziel <[email protected]> Co-authored-by: ChengZi <[email protected]> Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent f2f0e0e commit 9aabb44

File tree

6 files changed

+548
-1
lines changed

6 files changed

+548
-1
lines changed

docs/docs/integrations/vectorstores/milvus.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,4 +390,4 @@
390390
},
391391
"nbformat": 4,
392392
"nbformat_minor": 5
393-
}
393+
}

libs/community/langchain_community/storage/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from langchain_community.storage.redis import (
3232
RedisStore,
3333
)
34+
from langchain_community.storage.sql import (
35+
SQLStore,
36+
)
3437
from langchain_community.storage.upstash_redis import (
3538
UpstashRedisByteStore,
3639
UpstashRedisStore,
@@ -42,6 +45,7 @@
4245
"CassandraByteStore",
4346
"MongoDBStore",
4447
"RedisStore",
48+
"SQLStore",
4549
"UpstashRedisByteStore",
4650
"UpstashRedisStore",
4751
]
@@ -52,6 +56,7 @@
5256
"CassandraByteStore": "langchain_community.storage.cassandra",
5357
"MongoDBStore": "langchain_community.storage.mongodb",
5458
"RedisStore": "langchain_community.storage.redis",
59+
"SQLStore": "langchain_community.storage.sql",
5560
"UpstashRedisByteStore": "langchain_community.storage.upstash_redis",
5661
"UpstashRedisStore": "langchain_community.storage.upstash_redis",
5762
}
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
import contextlib
2+
from pathlib import Path
3+
from typing import (
4+
Any,
5+
AsyncGenerator,
6+
AsyncIterator,
7+
Dict,
8+
Generator,
9+
Iterator,
10+
List,
11+
Optional,
12+
Sequence,
13+
Tuple,
14+
Union,
15+
cast,
16+
)
17+
18+
from langchain_core.stores import BaseStore
19+
from sqlalchemy import (
20+
Engine,
21+
LargeBinary,
22+
and_,
23+
create_engine,
24+
delete,
25+
select,
26+
)
27+
from sqlalchemy.ext.asyncio import (
28+
AsyncEngine,
29+
AsyncSession,
30+
async_sessionmaker,
31+
create_async_engine,
32+
)
33+
from sqlalchemy.orm import (
34+
Mapped,
35+
Session,
36+
declarative_base,
37+
mapped_column,
38+
sessionmaker,
39+
)
40+
41+
Base = declarative_base()
42+
43+
44+
def items_equal(x: Any, y: Any) -> bool:
45+
return x == y
46+
47+
48+
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
49+
"""Table used to save values."""
50+
51+
# ATTENTION:
52+
# Prior to modifying this table, please determine whether
53+
# we should create migrations for this table to make sure
54+
# users do not experience data loss.
55+
__tablename__ = "langchain_key_value_stores"
56+
57+
namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
58+
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
59+
value = mapped_column(LargeBinary, index=False, nullable=False)
60+
61+
62+
# This is a fix of original SQLStore.
63+
# This can will be removed when a PR will be merged.
64+
class SQLStore(BaseStore[str, bytes]):
65+
"""BaseStore interface that works on an SQL database.
66+
67+
Examples:
68+
Create a SQLStore instance and perform operations on it:
69+
70+
.. code-block:: python
71+
72+
from langchain_rag.storage import SQLStore
73+
74+
# Instantiate the SQLStore with the root path
75+
sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:")
76+
77+
# Set values for keys
78+
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
79+
80+
# Get values for keys
81+
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
82+
83+
# Delete keys
84+
sql_store.mdelete(["key1"])
85+
86+
# Iterate over keys
87+
for key in sql_store.yield_keys():
88+
print(key)
89+
90+
"""
91+
92+
def __init__(
93+
self,
94+
*,
95+
namespace: str,
96+
db_url: Optional[Union[str, Path]] = None,
97+
engine: Optional[Union[Engine, AsyncEngine]] = None,
98+
engine_kwargs: Optional[Dict[str, Any]] = None,
99+
async_mode: Optional[bool] = None,
100+
):
101+
if db_url is None and engine is None:
102+
raise ValueError("Must specify either db_url or engine")
103+
104+
if db_url is not None and engine is not None:
105+
raise ValueError("Must specify either db_url or engine, not both")
106+
107+
_engine: Union[Engine, AsyncEngine]
108+
if db_url:
109+
if async_mode is None:
110+
async_mode = False
111+
if async_mode:
112+
_engine = create_async_engine(
113+
url=str(db_url),
114+
**(engine_kwargs or {}),
115+
)
116+
else:
117+
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
118+
elif engine:
119+
_engine = engine
120+
121+
else:
122+
raise AssertionError("Something went wrong with configuration of engine.")
123+
124+
_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
125+
if isinstance(_engine, AsyncEngine):
126+
self.async_mode = True
127+
_session_maker = async_sessionmaker(bind=_engine)
128+
else:
129+
self.async_mode = False
130+
_session_maker = sessionmaker(bind=_engine)
131+
132+
self.engine = _engine
133+
self.dialect = _engine.dialect.name
134+
self.session_maker = _session_maker
135+
self.namespace = namespace
136+
137+
def create_schema(self) -> None:
138+
Base.metadata.create_all(self.engine)
139+
140+
async def acreate_schema(self) -> None:
141+
assert isinstance(self.engine, AsyncEngine)
142+
async with self.engine.begin() as session:
143+
await session.run_sync(Base.metadata.create_all)
144+
145+
def drop(self) -> None:
146+
Base.metadata.drop_all(bind=self.engine.connect())
147+
148+
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
149+
assert isinstance(self.engine, AsyncEngine)
150+
result: Dict[str, bytes] = {}
151+
async with self._make_async_session() as session:
152+
stmt = select(LangchainKeyValueStores).filter(
153+
and_(
154+
LangchainKeyValueStores.key.in_(keys),
155+
LangchainKeyValueStores.namespace == self.namespace,
156+
)
157+
)
158+
for v in await session.scalars(stmt):
159+
result[v.key] = v.value
160+
return [result.get(key) for key in keys]
161+
162+
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
163+
result = {}
164+
165+
with self._make_sync_session() as session:
166+
stmt = select(LangchainKeyValueStores).filter(
167+
and_(
168+
LangchainKeyValueStores.key.in_(keys),
169+
LangchainKeyValueStores.namespace == self.namespace,
170+
)
171+
)
172+
for v in session.scalars(stmt):
173+
result[v.key] = v.value
174+
return [result.get(key) for key in keys]
175+
176+
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
177+
async with self._make_async_session() as session:
178+
await self._amdelete([key for key, _ in key_value_pairs], session)
179+
session.add_all(
180+
[
181+
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
182+
for k, v in key_value_pairs
183+
]
184+
)
185+
await session.commit()
186+
187+
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
188+
values: Dict[str, bytes] = dict(key_value_pairs)
189+
with self._make_sync_session() as session:
190+
self._mdelete(list(values.keys()), session)
191+
session.add_all(
192+
[
193+
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
194+
for k, v in values.items()
195+
]
196+
)
197+
session.commit()
198+
199+
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
200+
stmt = delete(LangchainKeyValueStores).filter(
201+
and_(
202+
LangchainKeyValueStores.key.in_(keys),
203+
LangchainKeyValueStores.namespace == self.namespace,
204+
)
205+
)
206+
session.execute(stmt)
207+
208+
async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
209+
stmt = delete(LangchainKeyValueStores).filter(
210+
and_(
211+
LangchainKeyValueStores.key.in_(keys),
212+
LangchainKeyValueStores.namespace == self.namespace,
213+
)
214+
)
215+
await session.execute(stmt)
216+
217+
def mdelete(self, keys: Sequence[str]) -> None:
218+
with self._make_sync_session() as session:
219+
self._mdelete(keys, session)
220+
session.commit()
221+
222+
async def amdelete(self, keys: Sequence[str]) -> None:
223+
async with self._make_async_session() as session:
224+
await self._amdelete(keys, session)
225+
await session.commit()
226+
227+
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
228+
with self._make_sync_session() as session:
229+
for v in session.query(LangchainKeyValueStores).filter( # type: ignore
230+
LangchainKeyValueStores.namespace == self.namespace
231+
):
232+
if str(v.key).startswith(prefix or ""):
233+
yield str(v.key)
234+
session.close()
235+
236+
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
237+
async with self._make_async_session() as session:
238+
stmt = select(LangchainKeyValueStores).filter(
239+
LangchainKeyValueStores.namespace == self.namespace
240+
)
241+
for v in await session.scalars(stmt):
242+
if str(v.key).startswith(prefix or ""):
243+
yield str(v.key)
244+
await session.close()
245+
246+
@contextlib.contextmanager
247+
def _make_sync_session(self) -> Generator[Session, None, None]:
248+
"""Make an async session."""
249+
if self.async_mode:
250+
raise ValueError(
251+
"Attempting to use a sync method in when async mode is turned on. "
252+
"Please use the corresponding async method instead."
253+
)
254+
with cast(Session, self.session_maker()) as session:
255+
yield cast(Session, session)
256+
257+
@contextlib.asynccontextmanager
258+
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
259+
"""Make an async session."""
260+
if not self.async_mode:
261+
raise ValueError(
262+
"Attempting to use an async method in when sync mode is turned on. "
263+
"Please use the corresponding async method instead."
264+
)
265+
async with cast(AsyncSession, self.session_maker()) as session:
266+
yield cast(AsyncSession, session)

0 commit comments

Comments
 (0)