Skip to content

Commit a5d3b25

Browse files
authored
Implementation of upsert, aupsert, get_by_ids, aget_by_ids (#83)
* Add support for standard get_by_ids, aget_by_ids * Add support for upsert and aupsert * Remove add_texts and aadd_texts
1 parent ee42d24 commit a5d3b25

File tree

5 files changed

+555
-104
lines changed

5 files changed

+555
-104
lines changed

langchain_postgres/vectorstores.py

Lines changed: 113 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
Callable,
1212
Dict,
1313
Generator,
14-
Iterable,
1514
List,
1615
Optional,
1716
Sequence,
@@ -27,6 +26,7 @@
2726
import sqlalchemy
2827
from langchain_core.documents import Document
2928
from langchain_core.embeddings import Embeddings
29+
from langchain_core.indexing import UpsertResponse
3030
from langchain_core.utils import get_from_dict_or_env
3131
from langchain_core.vectorstores import VectorStore
3232
from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select
@@ -714,7 +714,7 @@ async def __afrom(
714714

715715
def add_embeddings(
716716
self,
717-
texts: Iterable[str],
717+
texts: Sequence[str],
718718
embeddings: List[List[float]],
719719
metadatas: Optional[List[dict]] = None,
720720
ids: Optional[List[str]] = None,
@@ -770,7 +770,7 @@ def add_embeddings(
770770

771771
async def aadd_embeddings(
772772
self,
773-
texts: Iterable[str],
773+
texts: Sequence[str],
774774
embeddings: List[List[float]],
775775
metadatas: Optional[List[dict]] = None,
776776
ids: Optional[List[str]] = None,
@@ -824,56 +824,6 @@ async def aadd_embeddings(
824824

825825
return ids
826826

827-
def add_texts(
828-
self,
829-
texts: Iterable[str],
830-
metadatas: Optional[List[dict]] = None,
831-
ids: Optional[List[str]] = None,
832-
**kwargs: Any,
833-
) -> List[str]:
834-
"""Run more texts through the embeddings and add to the vectorstore.
835-
836-
Args:
837-
texts: Iterable of strings to add to the vectorstore.
838-
metadatas: Optional list of metadatas associated with the texts.
839-
ids: Optional list of ids for the texts.
840-
If not provided, will generate a new id for each text.
841-
kwargs: vectorstore specific parameters
842-
843-
Returns:
844-
List of ids from adding the texts into the vectorstore.
845-
"""
846-
assert not self._async_engine, "This method must be called without async_mode"
847-
embeddings = self.embedding_function.embed_documents(list(texts))
848-
return self.add_embeddings(
849-
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
850-
)
851-
852-
async def aadd_texts(
853-
self,
854-
texts: Iterable[str],
855-
metadatas: Optional[List[dict]] = None,
856-
ids: Optional[List[str]] = None,
857-
**kwargs: Any,
858-
) -> List[str]:
859-
"""Run more texts through the embeddings and add to the vectorstore.
860-
861-
Args:
862-
texts: Iterable of strings to add to the vectorstore.
863-
metadatas: Optional list of metadatas associated with the texts.
864-
ids: Optional list of ids for the texts.
865-
If not provided, will generate a new id for each text.
866-
kwargs: vectorstore specific parameters
867-
868-
Returns:
869-
List of ids from adding the texts into the vectorstore.
870-
"""
871-
await self.__apost_init__() # Lazy async init
872-
embeddings = await self.embedding_function.aembed_documents(list(texts))
873-
return await self.aadd_embeddings(
874-
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
875-
)
876-
877827
def similarity_search(
878828
self,
879829
query: str,
@@ -1014,6 +964,7 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
1014964
docs = [
1015965
(
1016966
Document(
967+
id=str(result.EmbeddingStore.id),
1017968
page_content=result.EmbeddingStore.document,
1018969
metadata=result.EmbeddingStore.cmetadata,
1019970
),
@@ -2178,3 +2129,112 @@ async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
21782129
)
21792130
async with self.session_maker() as session:
21802131
yield typing_cast(AsyncSession, session)
2132+
2133+
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
    """Upsert documents into the vectorstore.

    Each document's ``page_content`` is embedded with the store's
    ``embedding_function`` and handed to ``add_embeddings``. Documents whose
    ``id`` is ``None`` are given a freshly generated UUID4 string.

    Args:
        items: Sequence of documents to upsert.
        kwargs: vectorstore specific parameters, forwarded to
            ``add_embeddings``.

    Returns:
        UpsertResponse whose ``"succeeded"`` entry is the list of ids returned
        by ``add_embeddings`` and whose ``"failed"`` entry lists the
        caller-supplied ids that were not returned by it.

    Raises:
        AssertionError: If the store was configured with an async engine.
    """
    if self._async_engine:
        raise AssertionError("This method must be called in sync mode.")
    if not items:
        # Nothing to embed or store: skip the embedding and database calls.
        return {"succeeded": [], "failed": []}

    texts = [item.page_content for item in items]
    metadatas = [item.metadata for item in items]
    ids = [item.id if item.id is not None else str(uuid.uuid4()) for item in items]
    embeddings = self.embedding_function.embed_documents(texts)
    added_ids = self.add_embeddings(
        texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
    )

    # Set membership keeps the failure scan linear in the number of items.
    added_lookup = set(added_ids)
    return {
        "succeeded": added_ids,
        "failed": [
            item.id
            for item in items
            if item.id is not None and item.id not in added_lookup
        ],
    }
2160+
2161+
async def aupsert(
    self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
    """Upsert documents into the vectorstore (async variant).

    Each document's ``page_content`` is embedded with the store's
    ``embedding_function`` and handed to ``aadd_embeddings``. Documents whose
    ``id`` is ``None`` are given a freshly generated UUID4 string.

    Args:
        items: Sequence of documents to upsert.
        kwargs: vectorstore specific parameters, forwarded to
            ``aadd_embeddings``.

    Returns:
        UpsertResponse whose ``"succeeded"`` entry is the list of ids returned
        by ``aadd_embeddings`` and whose ``"failed"`` entry lists the
        caller-supplied ids that were not returned by it.

    Raises:
        AssertionError: If the store was not configured with an async engine.
    """
    if not self._async_engine:
        raise AssertionError("This method must be called with async_mode")
    if not items:
        # Nothing to embed or store: skip the embedding and database calls.
        return {"succeeded": [], "failed": []}

    texts = [item.page_content for item in items]
    metadatas = [item.metadata for item in items]
    ids = [item.id if item.id is not None else str(uuid.uuid4()) for item in items]
    embeddings = await self.embedding_function.aembed_documents(texts)
    added_ids = await self.aadd_embeddings(
        texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
    )

    # Set membership keeps the failure scan linear in the number of items.
    added_lookup = set(added_ids)
    return {
        "succeeded": added_ids,
        "failed": [
            item.id
            for item in items
            if item.id is not None and item.id not in added_lookup
        ],
    }
2190+
2191+
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
    """Get documents by ids.

    Selects the ``EmbeddingStore`` rows whose id is in ``ids`` and whose
    ``collection_id`` matches this store's collection. The query has no
    ORDER BY, so the result order is whatever the database returns.

    Args:
        ids: Ids of the documents to fetch.

    Returns:
        One Document per matching row, with the row id converted to ``str``
        (as ``aget_by_ids`` does).
    """
    if not ids:
        # No ids requested: avoid opening a session for an empty IN query.
        return []

    with self._make_sync_session() as session:
        collection = self.get_collection(session)
        stmt = (
            select(self.EmbeddingStore)
            .where(self.EmbeddingStore.id.in_(ids))
            .filter(self.EmbeddingStore.collection_id == collection.uuid)
        )
        documents = [
            Document(
                id=str(row.id),
                page_content=row.document,
                metadata=row.cmetadata,
            )
            for row in session.execute(stmt).scalars().all()
        ]
    return documents
2214+
2215+
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
    """Get documents by ids (async variant).

    Selects the ``EmbeddingStore`` rows whose id is in ``ids`` and whose
    ``collection_id`` matches this store's collection. The query has no
    ORDER BY, so the result order is whatever the database returns.

    Args:
        ids: Ids of the documents to fetch.

    Returns:
        One Document per matching row, with the row id converted to ``str``.
    """
    if not ids:
        # No ids requested: avoid opening a session for an empty IN query.
        return []

    async with self._make_async_session() as session:
        collection = await self.aget_collection(session)
        stmt = (
            select(self.EmbeddingStore)
            .where(self.EmbeddingStore.id.in_(ids))
            .filter(self.EmbeddingStore.collection_id == collection.uuid)
        )
        results: Sequence[Any] = (await session.execute(stmt)).scalars().all()
        documents = [
            Document(
                id=str(row.id),
                page_content=row.document,
                metadata=row.cmetadata,
            )
            for row in results
        ]
    return documents

0 commit comments

Comments
 (0)