From 188c83dedd6ac605c6f9f2f4e1441a664563b0b9 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Fri, 12 Apr 2024 13:29:59 +0200 Subject: [PATCH 01/18] Pre integration PPR --- langchain_postgres/vectorstores_ppr.py | 1996 ++++++++++++++++++++++++ tests/unit_tests/test_vectorstore.py | 2 +- 2 files changed, 1997 insertions(+), 1 deletion(-) create mode 100644 langchain_postgres/vectorstores_ppr.py diff --git a/langchain_postgres/vectorstores_ppr.py b/langchain_postgres/vectorstores_ppr.py new file mode 100644 index 00000000..1989a83a --- /dev/null +++ b/langchain_postgres/vectorstores_ppr.py @@ -0,0 +1,1996 @@ +from __future__ import annotations + +import asyncio +import contextlib +import enum +import json +import logging +import uuid +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Type, Union, +) + +import numpy as np +import sqlalchemy +from langchain_core._api import warn_deprecated +from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.orm import Session, relationship, sessionmaker +# TODO: accepter l'absence de l'option async lors des imports +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.runnables.config import run_in_executor +from langchain_core.utils import get_from_dict_or_env +from langchain_core.vectorstores import VectorStore + +from langchain_postgres._utils import maximal_marginal_relevance + + +class DistanceStrategy(str, enum.Enum): + """Enumerator of the Distance strategies.""" + + EUCLIDEAN = "l2" + COSINE = "cosine" + MAX_INNER_PRODUCT = "inner" + + +DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE + +Base = declarative_base() # type: Any + +_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + + +# class BaseModel(Base): +# """Base model for the SQL stores.""" +# +# __abstract__ = True +# uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) +# + +_classes: Any = None + +COMPARISONS_TO_NATIVE = { + "$eq": "==", + "$ne": "!=", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +SPECIAL_CASED_OPERATORS = { + "$in", + "$nin", + "$between", +} + +TEXT_OPERATORS = { + "$like", + "$ilike", +} + +LOGICAL_OPERATORS = {"$and", "$or"} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(TEXT_OPERATORS) + .union(LOGICAL_OPERATORS) + .union(SPECIAL_CASED_OPERATORS) +) + + +def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: + global _classes + if _classes is not None: + return _classes + + from pgvector.sqlalchemy import Vector # type: ignore + + class CollectionStore(Base): + """Collection store.""" + + __tablename__ = "langchain_pg_collection" + + uuid = sqlalchemy.Column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False, unique=True) + cmetadata = sqlalchemy.Column(JSON) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + # FIXME return session.query(cls).filter(cls.name == name).first() # type: ignore + 
return session.execute( + select(cls).filter(cls.name == name)).scalars().first() + + # @classmethod + # async def aget_by_name( + # cls, session: AsyncSession, name: str + # ) -> Optional["CollectionStore"]: + # stmt = select(cls).filter(cls.name == name) + # # return await session.execute(stmt) # FIXME + # return (await session.execute(stmt)).scalars().first() # FIXME + # + # # stmt = select(cls).filter(cls.name == name) + # # result = await session.execute(stmt) + # # x = result.scalars() + # # return session.query(cls).filter(cls.name == name).first() + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """Get or create a collection. + Returns: + Where the bool is True if the collection was created. + """ # noqa: E501 + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() # FIXME PPR semble utile + created = True + return collection, created + + # @classmethod + # async def aget_or_create( + # cls, + # session: AsyncSession, + # name: str, + # cmetadata: Optional[dict] = None, + # ) -> Tuple["CollectionStore", bool]: + # """ + # Get or create a collection. + # Returns [Collection, bool] where the bool is True if the collection was created. + # """ # noqa: E501 + # created = False + # collection = await cls.aget_by_name(session, name) + # if collection: + # return collection, created + # + # collection = cls(name=name, cmetadata=cmetadata) + # session.add(collection) + # await session.commit() + # created = True + # return collection, created + + class EmbeddingStore(Base): + """Embedding store.""" + + __tablename__ = "langchain_pg_embedding" + + id = sqlalchemy.Column( + sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True + ) + + collection_id = sqlalchemy.Column( + UUID(as_uuid=True), + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship(CollectionStore, back_populates="embeddings") + + embedding: Vector = sqlalchemy.Column(Vector(vector_dimension)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata = sqlalchemy.Column(JSONB, nullable=True) + + __table_args__ = ( + sqlalchemy.Index( + "ix_cmetadata_gin", + "cmetadata", + postgresql_using="gin", + postgresql_ops={"cmetadata": "jsonb_path_ops"}, + ), + ) + + _classes = (EmbeddingStore, CollectionStore) + + return _classes + + +def _results_to_docs(docs_and_scores: Any) -> List[Document]: + """Return docs from docs and scores.""" + return [doc for doc, _ in docs_and_scores] + + +Connection = Union[sqlalchemy.engine.Engine, str] + + +class PGVector(VectorStore): + """Vectorstore implementation using Postgres as the backend. + + Currently, there is no mechanism for supporting data migration. + + So breaking changes in the vectorstore schema will require the user to recreate + the tables and re-add the documents. + + If this is a concern, please use a different vectorstore. If + not, this implementation should be fine for your use case. + + To use this vectorstore you need to have the `vector` extension installed. + The `vector` extension is a Postgres extension that provides vector + similarity search capabilities. + + ```sh + docker run --name pgvector-container -e POSTGRES_PASSWORD=... + -d pgvector/pgvector:pg16 + ``` + + Example: + .. 
code-block:: python + + from langchain_postgres.vectorstores import PGVector + from langchain_openai.embeddings import OpenAIEmbeddings + + connection_string = "postgresql+psycopg://..." + collection_name = "state_of_the_union_test" + embeddings = OpenAIEmbeddings() + vectorstore = PGVector.from_documents( + embedding=embeddings, + documents=docs, + connection=connection_string, + collection_name=collection_name, + use_jsonb=True, + ) + + + This code has been ported over from langchain_community with minimal changes + to allow users to easily transition from langchain_community to langchain_postgres. + + Some changes had to be made to address issues with the community implementation: + * langchain_postgres now works with psycopg3. Please update your + connection strings from `postgresql+psycopg2://...` to + `postgresql+psycopg://langchain:langchain@...` + (yes, the driver name is `psycopg` not `psycopg3`) + * The schema of the embedding store and collection have been changed to make + add_documents work correctly with user specified ids, specifically + when overwriting existing documents. + You will need to recreate the tables if you are using an existing database. + * A Connection object has to be provided explicitly. Connections will not be + picked up automatically based on env variables. + """ + + def __init__( + self, + embeddings: Embeddings, + *, + connection: Optional[Connection] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + ) -> None: + """Initialize the PGVector store. + + Args: + connection: Postgres connection string. + embeddings: Any embedding function implementing + `langchain.embeddings.base.Embeddings` interface. + embedding_length: The length of the embedding vector. (default: None) + NOTE: This is not mandatory. Defining it will prevent vectors of + any other size to be added to the embeddings table but, without it, + the embeddings can't be indexed. + collection_name: The name of the collection to use. (default: langchain) + NOTE: This is not the name of the table, but the name of the collection. + The tables will be created when initializing the store (if not exists) + So, make sure the user has the right permissions to create tables. + distance_strategy: The distance strategy to use. (default: COSINE) + pre_delete_collection: If True, will delete the collection if it exists. + (default: False). Useful for testing. + engine_args: SQLAlchemy's create engine arguments. + use_jsonb: Use JSONB instead of JSON for metadata. (default: True) + Strongly discouraged from using JSON as it's not as efficient + for querying. + It's provided here for backwards compatibility with older versions, + and will be removed in the future. + create_extension: If True, will create the vector extension if it + doesn't exist. disabling creation is useful when using ReadOnly + Databases. 
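+
+        Example:
+            A minimal, illustrative construction (the connection string,
+            embedding model and collection name below are placeholders):
+
+            .. code-block:: python
+
+                from langchain_openai.embeddings import OpenAIEmbeddings
+
+                store = PGVector(
+                    embeddings=OpenAIEmbeddings(),
+                    connection="postgresql+psycopg://user:password@localhost:5432/db",
+                    collection_name="my_docs",
+                )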
+        """
+        self.embedding_function = embeddings
+        self._embedding_length = embedding_length
+        self.collection_name = collection_name
+        self.collection_metadata = collection_metadata
+        self._distance_strategy = distance_strategy
+        self.pre_delete_collection = pre_delete_collection
+        self.logger = logger or logging.getLogger(__name__)
+        self.override_relevance_score_fn = relevance_score_fn
+
+        if isinstance(connection, str):
+            self._engine = sqlalchemy.create_engine(
+                url=connection, **(engine_args or {})
+            )
+        elif isinstance(connection, sqlalchemy.engine.Engine):
+            self._engine = connection
+        else:
+            raise ValueError(
+                "connection should be a connection string or an instance of "
+                "sqlalchemy.engine.Engine"
+            )
+
+        self._session_maker = sessionmaker(bind=self._engine)
+
+        self.use_jsonb = use_jsonb
+        self.create_extension = create_extension
+
+        if not use_jsonb:
+            # Replace with a deprecation warning.
+            raise NotImplementedError("use_jsonb=False is no longer supported.")
+        self.__post_init__()
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """Initialize the store."""
+        if self.create_extension:
+            self.create_vector_extension()
+
+        EmbeddingStore, CollectionStore = _get_embedding_collection_store(
+            self._embedding_length
+        )
+        self.CollectionStore = CollectionStore
+        self.EmbeddingStore = EmbeddingStore
+        self.create_tables_if_not_exists()
+        self.create_collection()
+
+    # async def __apost_init__(
+    #     self,
+    # ) -> None:
+    #
+    #     if self.async_mode:
+    #         self._session_maker = self._build_async_sessionmaker()
+    #     """Initialize the store."""
+    #     EmbeddingStore, CollectionStore = _get_embedding_collection_store(
+    #         self._embedding_length, use_jsonb=self.use_jsonb
+    #     )
+    #     self.CollectionStore = CollectionStore
+    #     self.EmbeddingStore = EmbeddingStore
+    #
+    #     if self.create_extension:
+    #         await self.acreate_vector_extension()
+    #
+    #     await self.acreate_tables_if_not_exists()
+    #     await self.acreate_collection()
+
+    def __del__(self) -> None:
+        if isinstance(self._bind, sqlalchemy.engine.Connection):
+            if self.async_mode:
+                asyncio.run(self._bind.close())
+            else:
+                self._bind.close()
+
+    @property
+    def embeddings(self) -> Embeddings:
+        return self.embedding_function
+
+    def _create_engine(self, async_mode: bool = False) -> sqlalchemy.engine.Engine:
+        if async_mode:
+            from sqlalchemy.ext.asyncio import create_async_engine
+            # FIXME: handle an async call made against a sync engine
+            from sqlalchemy.exc import InvalidRequestError
+            try:
+                return create_async_engine(
+                    url=self.connection_string,
+                    isolation_level="REPEATABLE READ",  # FIXME: merge with the arguments below?
+                    echo=True,  # FIXME: remove
+                    **self.engine_args
+                )
+            except InvalidRequestError:
+                pass  # Ignore and return the synchronous version
+        logging.warning("Using a synchronous SQL engine!")
+        return sqlalchemy.create_engine(url=self.connection_string,
+                                        echo=True,  # FIXME: remove
+                                        **self.engine_args)
+
+    def create_vector_extension(self) -> None:
+        try:
+            with self._session_maker() as session:  # type: ignore[arg-type]
+                # The advisory lock fixes issues arising from concurrent
+                # creation of the vector extension.
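+                # Note: pg_advisory_xact_lock takes a transaction-level advisory
+                # lock, so it is released automatically when the surrounding
+                # transaction commits or rolls back.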
+                # https://github.com/langchain-ai/langchain/issues/12933
+                # For more information see:
+                # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+                statement = sqlalchemy.text(
+                    "BEGIN;"
+                    "SELECT pg_advisory_xact_lock(1573678846307946496);"
+                    "CREATE EXTENSION IF NOT EXISTS vector;"
+                    "COMMIT;"
+                )
+                session.execute(statement)
+                session.commit()
+        except Exception as e:
+            raise Exception(f"Failed to create vector extension: {e}") from e
+
+    # async def acreate_vector_extension(self) -> None:
+    #     try:
+    #         async with self._amake_session() as session:
+    #             # The advisory lock fixes issues arising from concurrent
+    #             # creation of the vector extension.
+    #             # https://github.com/langchain-ai/langchain/issues/12933
+    #             # For more information see:
+    #             # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+    #             await session.execute(
+    #                 sqlalchemy.text(
+    #                     "SELECT pg_advisory_xact_lock(1573678846307946496)"))
+    #             await session.execute(
+    #                 sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector"))
+    #     except Exception as e:
+    #         raise Exception(f"Failed to create vector extension: {e}") from e
+
+    def create_tables_if_not_exists(self) -> None:
+        with self._session_maker() as session:
+            Base.metadata.create_all(session.get_bind())
+
+    async def acreate_tables_if_not_exists(self) -> None:
+        if isinstance(self._bind, sqlalchemy.ext.asyncio.engine.AsyncConnection):
+            await self._bind.run_sync(Base.metadata.create_all)
+        else:
+            async with self._bind.begin() as conn:  # FIXME: session.run_sync exists
+                await conn.run_sync(Base.metadata.create_all)
+        # async with self._amake_session() as session:
+        #     await session.run_sync(Base.metadata.create_all)
+
+    def drop_tables(self) -> None:
+        with self._session_maker() as session:
+            Base.metadata.drop_all(session.get_bind())
+
+    # async def adrop_tables(self) -> None:
+    #     async with self._amake_session() as session:
+    #         await session.run_sync(Base.metadata.drop_all)
+
+    def create_collection(self) -> None:
+        if self.pre_delete_collection:
+            self.delete_collection()
+        with self._session_maker() as session:
+            self.CollectionStore.get_or_create(
+                session, self.collection_name, cmetadata=self.collection_metadata
+            )
+
+    # async def acreate_collection(self) -> None:
+    #     async with self._amake_session() as session:
+    #         if self.pre_delete_collection:
+    #             await self._adelete_collection(session)
+    #         await self.CollectionStore.aget_or_create(
+    #             session, self.collection_name, cmetadata=self.collection_metadata
+    #         )
+
+    def _delete_collection(self, session: Session) -> None:
+        self.logger.debug("Trying to delete collection")
+        collection = self.get_collection(session)
+        if not collection:
+            self.logger.warning("Collection not found")
+            return
+        session.delete(collection)
+
+    # async def _adelete_collection(self, session: AsyncSession) -> None:
+    #     self.logger.debug("Trying to delete collection")
+    #     collection = await self.aget_collection(session)
+    #     if not collection:
+    #         self.logger.warning("Collection not found")
+    #         return
+    #     await session.delete(collection)
+
+    # def delete_collection(self) -> None:
+    #     with self._session_maker() as session:
+    #         self._delete_collection(session)
+    def delete_collection(self) -> None:
+        self.logger.debug("Trying to delete collection")
+        with self._session_maker() as session:  # type: ignore[arg-type]
+            collection = self.get_collection(session)
+            if not collection:
+                self.logger.warning("Collection not found")
+                return
+            session.delete(collection)
+            session.commit()
+
+    def delete(
+        self,
+        ids: 
Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, + ) -> None: + """Delete vectors by ids or uuids. + + Args: + ids: List of ids to delete. + collection_only: Only delete ids in the collection. + """ + with self._session_maker() as session: + if ids is not None: + self.logger.debug( + "Trying to delete vectors by ids (represented by the model " + "using the custom ids field)" + ) + + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + + stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) + session.execute(stmt) + session.commit() + + # async def adelete( + # self, + # ids: Optional[List[str]] = None, + # collection_only: bool = False, + # **kwargs: Any, + # ) -> None: + # """Delete vectors by ids or uuids. + # + # Args: + # ids: List of ids to delete. + # collection_only: Only delete ids in the collection. + # """ + # async with self._amake_session() as session: + # if ids is not None: + # self.logger.debug( + # "Trying to delete vectors by ids (represented by the model " + # "using the custom ids field)" + # ) + # + # stmt = delete(self.EmbeddingStore) + # + # if collection_only: + # collection = await self.aget_collection(session) + # if not collection: + # self.logger.warning("Collection not found") + # return + # + # stmt = stmt.where( + # self.EmbeddingStore.collection_id == collection.uuid + # ) + # + # stmt = stmt.where(self.EmbeddingStore.custom_id.in_(ids)) + # await session.execute(stmt) + + def get_collection(self, session: Session) -> Any: + return self.CollectionStore.get_by_name(session, self.collection_name) + + # async def aget_collection(self, session: AsyncSession) -> Any: + # return await self.CollectionStore.aget_by_name(session, self.collection_name) + + @classmethod + def __from( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + store = cls( + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + **kwargs, + ) + + store.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + + # @classmethod + # async def __afrom( + # cls, + # texts: List[str], + # embeddings: List[List[float]], + # embedding: Embeddings, + # metadatas: Optional[List[dict]] = None, + # ids: Optional[List[str]] = None, + # collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + # distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + # connection_string: Optional[str] = None, + # pre_delete_collection: bool = False, + # *, + # use_jsonb: bool = False, + # **kwargs: Any, + # ) -> PGVector: + # if ids is None: + # ids = [str(uuid.uuid1()) for _ in texts] + # + # if not metadatas: + # metadatas = [{} for _ in texts] + # if connection_string is None: + # 
connection_string = cls.get_connection_string(kwargs) + # + # store = cls( + # connection_string=connection_string, + # collection_name=collection_name, + # embedding_function=embedding, + # distance_strategy=distance_strategy, + # pre_delete_collection=pre_delete_collection, + # use_jsonb=use_jsonb, + # async_mode=True, # FIXME + # **kwargs, + # ) + # # Second phase to create + # await store.__apost_init__() + # + # await store.aadd_embeddings( + # texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + # ) + # + # return store + + def add_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + with self._session_maker() as session: # type: ignore[arg-type] + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + data = [ + { + "id": id, + "collection_id": collection.uuid, + "embedding": embedding, + "document": text, + "cmetadata": metadata or {}, + } + for text, metadata, embedding, id in zip( + texts, metadatas, embeddings, ids + ) + ] + stmt = insert(self.EmbeddingStore).values(data) + on_conflict_stmt = stmt.on_conflict_do_update( + index_elements=["id"], + # Conflict detection based on these columns + set_={ + "embedding": stmt.excluded.embedding, + "document": stmt.excluded.document, + "cmetadata": stmt.excluded.cmetadata, + }, + ) + session.execute(on_conflict_stmt) + session.commit() + + return ids + + # async def aadd_embeddings( REFAIRE + # self, + # texts: Iterable[str], + # embeddings: List[List[float]], + # metadatas: Optional[List[dict]] = None, + # ids: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> List[str]: + # """Add embeddings to the vectorstore. + # + # Args: + # texts: Iterable of strings to add to the vectorstore. + # embeddings: List of list of embedding vectors. + # metadatas: List of metadatas associated with the texts. + # kwargs: vectorstore specific parameters + # """ + # if ids is None: + # ids = [str(uuid.uuid1()) for _ in texts] + # + # if not metadatas: + # metadatas = [{} for _ in texts] + # + # async with self._amake_session() as session: + # collection = await self.aget_collection(session) + # if not collection: + # raise ValueError("Collection not found") + # documents = [] + # for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + # embedding_store = self.EmbeddingStore( + # embedding=embedding, + # document=text, + # cmetadata=metadata, + # custom_id=id, + # collection_id=collection.uuid, + # ) + # documents.append(embedding_store) + # await session.run_sync( + # lambda sync_session: sync_session.bulk_save_objects(documents)) + # + # return ids + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. 
+ kwargs: vectorstore specific parameters + + Returns: + List of ids from adding the texts into the vectorstore. + """ + embeddings = self.embedding_function.embed_documents(list(texts)) + return self.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + # async def aadd_texts( + # self, + # texts: Iterable[str], + # metadatas: Optional[List[dict]] = None, + # ids: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> List[str]: + # """Run more texts through the embeddings and add to the vectorstore. + # + # Args: + # texts: Iterable of strings to add to the vectorstore. + # metadatas: Optional list of metadatas associated with the texts. + # kwargs: vectorstore specific parameters + # + # Returns: + # List of ids from adding the texts into the vectorstore. + # """ + # embeddings = await self.embedding_function.aembed_documents(list(texts)) + # return await self.aadd_embeddings( + # texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + # ) + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search with PGVector with distance. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding_function.embed_query(text=query) + return self.similarity_search_by_vector( + embedding=embedding, + k=k, + filter=filter, + ) + + # async def asimilarity_search( + # self, + # query: str, + # k: int = 4, + # filter: Optional[dict] = None, + # **kwargs: Any, + # ) -> List[Document]: + # """Run similarity search with PGVector with distance. + # + # Args: + # query (str): Query text to search for. + # k (int): Number of results to return. Defaults to 4. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + # + # Returns: + # List of Documents most similar to the query. + # """ + # embedding = self.embedding_function.embed_query(text=query) + # return await self.asimilarity_search_by_vector( + # embedding=embedding, + # k=k, + # filter=filter, + # ) + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query and score for each. + """ + embedding = self.embedding_function.embed_query(query) + docs = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return docs + + # async def asimilarity_search_with_score( + # self, + # query: str, + # k: int = 4, + # filter: Optional[dict] = None, + # ) -> List[Tuple[Document, float]]: + # """Return docs most similar to query. + # + # Args: + # query: Text to look up documents similar to. + # k: Number of Documents to return. Defaults to 4. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + # + # Returns: + # List of Documents most similar to the query and score for each. 
+ # """ + # embedding = self.embedding_function.embed_query(query) + # docs = await self.asimilarity_search_with_score_by_vector( + # embedding=embedding, k=k, filter=filter + # ) + # return docs + + @property + def distance_strategy(self) -> Any: + if self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self.EmbeddingStore.embedding.l2_distance + elif self._distance_strategy == DistanceStrategy.COSINE: + return self.EmbeddingStore.embedding.cosine_distance + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self.EmbeddingStore.embedding.max_inner_product + else: + raise ValueError( + f"Got unexpected value for distance: {self._distance_strategy}. " + f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." + ) + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + results = self.__query_collection(embedding=embedding, k=k, filter=filter) + + return self._results_to_docs_and_scores(results) + + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: + """Return docs and scores from results.""" + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result.distance if self.embedding_function is not None else None, + ) + for result in results + ] + return docs + + def _handle_field_filter( + self, + field: str, + value: Any, + ) -> SQLColumnExpression: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + + Returns: + sqlalchemy expression + """ + if not isinstance(field, str): + raise ValueError( + f"field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError( + f"Invalid field name: {field}. Expected a valid identifier." + ) + + if isinstance(value, dict): + # This is a filter specification + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. 
" + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # Then we assume an equality operator + operator = "$eq" + filter_value = value + + if operator in COMPARISONS_TO_NATIVE: + # Then we implement an equality filter + # native is trusted input + native = COMPARISONS_TO_NATIVE[operator] + return func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + cast(f"$.{field} {native} $value", JSONPATH), + cast({"value": filter_value}, JSONB), + ) + elif operator == "$between": + # Use AND with two comparisons + low, high = filter_value + + lower_bound = func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + cast(f"$.{field} >= $value", JSONPATH), + cast({"value": low}, JSONB), + ) + upper_bound = func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + cast(f"$.{field} <= $value", JSONPATH), + cast({"value": high}, JSONB), + ) + return sqlalchemy.and_(lower_bound, upper_bound) + elif operator in {"$in", "$nin", "$like", "$ilike"}: + # We'll do force coercion to text + if operator in {"$in", "$nin"}: + for val in filter_value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + + queried_field = self.EmbeddingStore.cmetadata[field].astext + + if operator in {"$in"}: + return queried_field.in_([str(val) for val in filter_value]) + elif operator in {"$nin"}: + return queried_field.nin_([str(val) for val in filter_value]) + elif operator in {"$like"}: + return queried_field.like(filter_value) + elif operator in {"$ilike"}: + return queried_field.ilike(filter_value) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] + """Deprecated functionality. + + This is for backwards compatibility with the JSON based schema for metadata. + It uses incorrect operator syntax (operators are not prefixed with $). + + This implementation is not efficient, and has bugs associated with + the way that it handles numeric filter clauses. 
+ """ + IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne" + EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and" + + value_case_insensitive = {k.lower(): v for k, v in value.items()} + if IN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_( + value_case_insensitive[IN] + ) + elif NIN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in( + value_case_insensitive[NIN] + ) + elif BETWEEN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between( + str(value_case_insensitive[BETWEEN][0]), + str(value_case_insensitive[BETWEEN][1]), + ) + elif GT in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str( + value_case_insensitive[GT] + ) + elif LT in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str( + value_case_insensitive[LT] + ) + elif NE in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str( + value_case_insensitive[NE] + ) + elif EQ in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( + value_case_insensitive[EQ] + ) + elif LIKE in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like( + value_case_insensitive[LIKE] + ) + elif CONTAINS in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains( + value_case_insensitive[CONTAINS] + ) + elif OR in map(str.lower, value): + or_clauses = [ + self._create_filter_clause(key, sub_value) + for sub_value in value_case_insensitive[OR] + ] + filter_by_metadata = sqlalchemy.or_(*or_clauses) + elif AND in map(str.lower, value): + and_clauses = [ + self._create_filter_clause(key, sub_value) + for sub_value in value_case_insensitive[AND] + ] + filter_by_metadata = sqlalchemy.and_(*and_clauses) + + else: + filter_by_metadata = None + + return filter_by_metadata + + def _create_filter_clause_json_deprecated( + self, filter: Any + ) -> List[SQLColumnExpression]: + """Convert filters from IR to SQL clauses. + + **DEPRECATED** This functionality will be deprecated in the future. + + It implements translation of filters for a schema that uses JSON + for metadata rather than the JSONB field which is more efficient + for querying. + """ + filter_clauses = [] + for key, value in filter.items(): + if isinstance(value, dict): + filter_by_metadata = self._create_filter_clause_deprecated(key, value) + + if filter_by_metadata is not None: + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( + value + ) + filter_clauses.append(filter_by_metadata) + return filter_clauses + + def _create_filter_clause(self, filters: Any) -> Any: + """Convert LangChain IR filter representation to matching SQLAlchemy clauses. + + At the top level, we still don't know if we're working with a field + or an operator for the keys. After we've determined that we can + call the appropriate logic to handle filter creation. + + Args: + filters: Dictionary of filters to apply to the query. + + Returns: + SQLAlchemy clause to apply to the query. 
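+
+        Example:
+            Illustrative filters built from the ``$``-prefixed operators in
+            ``SUPPORTED_OPERATORS`` (field names and values are placeholders):
+
+            .. code-block:: python
+
+                {"topic": {"$eq": "animals"}}
+                {"$and": [{"year": {"$gte": 2020}},
+                          {"topic": {"$in": ["animals", "plants"]}}]}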
+ """ + if isinstance(filters, dict): + if len(filters) == 1: + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filters.items())[0] + if key.startswith("$"): + # Then it's an operator + if key.lower() not in ["$and", "$or"]: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " + f"but got: {key}" + ) + else: + # Then it's a field + return self._handle_field_filter(key, filters[key]) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == "$and": + and_ = [self._create_filter_clause(el) for el in value] + if len(and_) > 1: + return sqlalchemy.and_(*and_) + elif len(and_) == 1: + return and_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + elif key.lower() == "$or": + or_ = [self._create_filter_clause(el) for el in value] + if len(or_) > 1: + return sqlalchemy.or_(*or_) + elif len(or_) == 1: + return or_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " + f"but got: {key}" + ) + elif len(filters) > 1: + # Then all keys have to be fields (they cannot be operators) + for key in filters.keys(): + if key.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got: {key}" + ) + # These should all be fields and combined using an $and operator + and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] + if len(and_) > 1: + return sqlalchemy.and_(*and_) + elif len(and_) == 1: + return and_[0] + else: + raise ValueError( + "Invalid filter condition. 
Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError("Got an empty dictionary for filters.") + else: + raise ValueError( + f"Invalid type: Expected a dictionary but got type: {type(filters)}" + ) + + def __query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + with self._session_maker() as session: # type: ignore[arg-type] + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) + + _type = self.EmbeddingStore + + results: List[Any] = ( + session.query( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) + .filter(*filter_by) + .order_by(sqlalchemy.asc("distance")) + .join( + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + .all() + ) + + return results + + # async def __aquery_collection( # FIXME + # self, + # session: AsyncSession, + # embedding: List[float], + # k: int = 4, + # filter: Optional[Dict[str, str]] = None, + # ) -> List[Any]: + # """Query the collection.""" + # collection = await self.aget_collection(session) + # if not collection: + # raise ValueError("Collection not found") + # + # filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + # if filter: + # if self.use_jsonb: + # filter_clauses = self._create_filter_clause(filter) + # if filter_clauses is not None: + # filter_by.append(filter_clauses) + # else: + # # Old way of doing things + # filter_clauses = self._create_filter_clause_json_deprecated(filter) + # filter_by.extend(filter_clauses) + # + # _type = self.EmbeddingStore + # stmt = (select(self.EmbeddingStore, + # self.distance_strategy(embedding).label("distance"), + # ) + # .filter(*filter_by) + # .order_by(sqlalchemy.asc("distance")) + # .join( + # self.CollectionStore, + # self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + # ) + # .limit(k)) + # results: List[Any] = ( + # (await session.execute(stmt)).all() + # ) + # + # return results + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query vector. + """ + docs_and_scores = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return _results_to_docs(docs_and_scores) + + # async def asimilarity_search_by_vector( + # self, + # embedding: List[float], + # k: int = 4, + # filter: Optional[dict] = None, + # **kwargs: Any, + # ) -> List[Document]: + # """Return docs most similar to embedding vector. + # + # Args: + # embedding: Embedding to look up documents similar to. + # k: Number of Documents to return. Defaults to 4. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. 
+ # + # Returns: + # List of Documents most similar to the query vector. + # """ + # docs_and_scores = await self.asimilarity_search_with_score_by_vector( + # embedding=embedding, k=k, filter=filter + # ) + # return _results_to_docs(docs_and_scores) + + @classmethod + def from_texts( + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """Return VectorStore initialized from documents and embeddings.""" + embeddings = embedding.embed_documents(list(texts)) + + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + **kwargs, + ) + + # @classmethod + # async def afrom_texts( + # cls: Type[PGVector], + # texts: List[str], + # embedding: Embeddings, + # metadatas: Optional[List[dict]] = None, + # collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + # distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + # ids: Optional[List[str]] = None, + # pre_delete_collection: bool = False, + # *, + # use_jsonb: bool = False, + # **kwargs: Any, + # ) -> PGVector: + # """ + # Return VectorStore initialized from texts and embeddings. + # Postgres connection string is required + # "Either pass it as a parameter + # or set the PGVECTOR_CONNECTION_STRING environment variable. + # """ + # embeddings = embedding.embed_documents(list(texts)) + # + # return await cls.__afrom( + # texts, + # embeddings, + # embedding, + # metadatas=metadatas, + # ids=ids, + # collection_name=collection_name, + # distance_strategy=distance_strategy, + # pre_delete_collection=pre_delete_collection, + # use_jsonb=use_jsonb, + # **kwargs, + # ) + + @classmethod + def from_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + *, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and embeddings. + + Args: + text_embeddings: List of tuples of text and embeddings. + embedding: Embeddings object. + metadatas: Optional list of metadatas associated with the texts. + collection_name: Name of the collection. + distance_strategy: Distance strategy to use. + ids: Optional list of ids for the documents. + pre_delete_collection: If True, will delete the collection if it exists. + **Attention**: This will delete all the documents in the existing + collection. + kwargs: Additional arguments. + + Returns: + PGVector: PGVector instance. + + Example: + .. 
code-block:: python + + from langchain_postgres.vectorstores import PGVector + from langchain_openai.embeddings import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + vectorstore = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + @classmethod + async def afrom_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and pre- + generated embeddings. + + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + + Example: + .. code-block:: python + + from langchain_community.vectorstores import PGVector + from langchain_community.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return await cls.__afrom( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + @classmethod + def from_existing_index( + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, + ) -> PGVector: + """ + Get instance of an existing PGVector store.This method will + return the instance of the store without inserting any new + embeddings + """ + store = cls( + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + return store + + @classmethod + def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: + connection_string: str = get_from_dict_or_env( + data=kwargs, + key="connection_string", + env_key="PGVECTOR_CONNECTION_STRING", + ) + + if not connection_string: + raise ValueError( + "Postgres connection string is required" + "Either pass it as a parameter" + "or set the PGVECTOR_CONNECTION_STRING environment variable." 
+ ) + + return connection_string + + @classmethod + def from_documents( + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + *, + connection: Optional[Connection] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """Return VectorStore initialized from documents and embeddings.""" + + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + return cls.from_texts( + texts=texts, + pre_delete_collection=pre_delete_collection, + embedding=embedding, + distance_strategy=distance_strategy, + metadatas=metadatas, + connection=connection, + ids=ids, + collection_name=collection_name, + use_jsonb=use_jsonb, + **kwargs, + ) + + @classmethod + def connection_string_from_db_params( + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, + ) -> str: + """Return connection string from database parameters.""" + if driver != "psycopg": + raise NotImplementedError("Only psycopg3 driver is supported") + return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}" + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + """ + if self.override_relevance_score_fn is not None: + return self.override_relevance_score_fn + + # Default strategy is to rely on distance strategy provided + # in vectorstore constructor + if self._distance_strategy == DistanceStrategy.COSINE: + return self._cosine_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self._euclidean_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self._max_inner_product_relevance_score_fn + else: + raise ValueError( + "No supported normalization function" + f" for distance_strategy of {self._distance_strategy}." + "Consider providing relevance_score_fn to PGVector constructor." + ) + + def max_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. 
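+
+        Example:
+            An illustrative call (the query text is a placeholder):
+
+            .. code-block:: python
+
+                embedding = store.embedding_function.embed_query("some query")
+                results = store.max_marginal_relevance_search_with_score_by_vector(
+                    embedding, k=4, fetch_k=20, lambda_mult=0.5
+                )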
+ """ + results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + + embedding_list = [result.EmbeddingStore.embedding for result in results] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embedding_list, + k=k, + lambda_mult=lambda_mult, + ) + + candidates = self._results_to_docs_and_scores(results) + + return [r for i, r in enumerate(candidates) if i in mmr_selected] + + # async def amax_marginal_relevance_search_with_score_by_vector( + # self, + # embedding: List[float], + # k: int = 4, + # fetch_k: int = 20, + # lambda_mult: float = 0.5, + # filter: Optional[Dict[str, str]] = None, + # **kwargs: Any, + # ) -> List[Tuple[Document, float]]: + # """Return docs selected using the maximal marginal relevance with score + # to embedding vector. + # + # Maximal marginal relevance optimizes for similarity to query AND diversity + # among selected documents. + # + # Args: + # embedding: Embedding to look up documents similar to. + # k (int): Number of Documents to return. Defaults to 4. + # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + # Defaults to 20. + # lambda_mult (float): Number between 0 and 1 that determines the degree + # of diversity among the results with 0 corresponding + # to maximum diversity and 1 to minimum diversity. + # Defaults to 0.5. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + # + # Returns: + # List[Tuple[Document, float]]: List of Documents selected by maximal marginal + # relevance to the query and score for each. + # """ + # with self._session_maker() as session: + # results = await self.__aquery_collection(session=session, + # embedding=embedding, k=fetch_k, + # filter=filter) + # + # embedding_list = [result.EmbeddingStore.embedding for result in results] + # + # mmr_selected = maximal_marginal_relevance( + # np.array(embedding, dtype=np.float32), + # embedding_list, + # k=k, + # lambda_mult=lambda_mult, + # ) + # + # candidates = self._results_to_docs_and_scores(results) + # + # return [r for i, r in enumerate(candidates) if i in mmr_selected] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. 
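+
+        Example:
+            An illustrative call (query and parameter values are placeholders):
+
+            .. code-block:: python
+
+                docs = store.max_marginal_relevance_search(
+                    "query text", k=4, fetch_k=20, lambda_mult=0.5
+                )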
+ """ + embedding = self.embedding_function.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + # async def amax_marginal_relevance_search( + # self, + # query: str, + # k: int = 4, + # fetch_k: int = 20, + # lambda_mult: float = 0.5, + # filter: Optional[Dict[str, str]] = None, + # **kwargs: Any, + # ) -> List[Document]: + # """Return docs selected using the maximal marginal relevance. + # + # Maximal marginal relevance optimizes for similarity to query AND diversity + # among selected documents. + # + # Args: + # query (str): Text to look up documents similar to. + # k (int): Number of Documents to return. Defaults to 4. + # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + # Defaults to 20. + # lambda_mult (float): Number between 0 and 1 that determines the degree + # of diversity among the results with 0 corresponding + # to maximum diversity and 1 to minimum diversity. + # Defaults to 0.5. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + # + # Returns: + # List[Document]: List of Documents selected by maximal marginal relevance. + # """ + # embedding = self.embedding_function.embed_query(query) + # return await self.amax_marginal_relevance_search_by_vector( + # embedding, + # k=k, + # fetch_k=fetch_k, + # lambda_mult=lambda_mult, + # filter=filter, + # **kwargs, + # ) + + def max_marginal_relevance_search_with_score( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + embedding = self.embedding_function.embed_query(query) + docs = self.max_marginal_relevance_search_with_score_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + return docs + + # async def amax_marginal_relevance_search_with_score( + # self, + # query: str, + # k: int = 4, + # fetch_k: int = 20, + # lambda_mult: float = 0.5, + # filter: Optional[dict] = None, + # **kwargs: Any, + # ) -> List[Tuple[Document, float]]: + # """Return docs selected using the maximal marginal relevance with score. + # + # Maximal marginal relevance optimizes for similarity to query AND diversity + # among selected documents. + # + # Args: + # query (str): Text to look up documents similar to. + # k (int): Number of Documents to return. Defaults to 4. + # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + # Defaults to 20. 
+ # lambda_mult (float): Number between 0 and 1 that determines the degree + # of diversity among the results with 0 corresponding + # to maximum diversity and 1 to minimum diversity. + # Defaults to 0.5. + # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + # + # Returns: + # List[Tuple[Document, float]]: List of Documents selected by maximal marginal + # relevance to the query and score for each. + # """ + # embedding = self.embedding_function.embed_query(query) + # docs = await self.amax_marginal_relevance_search_with_score_by_vector( + # embedding=embedding, + # k=k, + # fetch_k=fetch_k, + # lambda_mult=lambda_mult, + # filter=filter, + # **kwargs, + # ) + # return docs + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. + """ + docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + return _results_to_docs(docs_and_scores) + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + return await run_in_executor( + None, + self.max_marginal_relevance_search_by_vector, + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + + # async def aadd_documents( # FIXME: remove + # self, documents: List[Document], **kwargs: Any + # ) -> List[str]: + # """Run more documents through the embeddings and add to the vectorstore. + # + # Args: + # documents (List[Document]: Documents to add to the vectorstore. + # + # Returns: + # List[str]: List of IDs of the added texts. 
+ # """ + # texts = [doc.page_content for doc in documents] + # metadatas = [doc.metadata for doc in documents] + # return await self.aadd_texts(texts, metadatas, **kwargs) + + # async def adelete_collection(self) -> None: + # async with self._amake_session() as session: + # await self._adelete_collection(session) + + # def _build_sessionmaker(self) -> sessionmaker: + # return sessionmaker(bind=self._bind) + # + # def _build_async_sessionmaker(self) -> sessionmaker: + # return async_sessionmaker( + # bind=self._bind, + # sync_session_class=self._session_maker) + # + # @contextlib.contextmanager + # def _session_maker(self) -> Generator[Session, None, None]: + # """Create a context manager for the session, bind to _conn string.""" + # session = self._session_maker() + # session.begin() # FIXME: sans ? + # yield session + # session.commit() # FIXME + # + # @contextlib.asynccontextmanager + # async def _amake_session(self) -> Generator[AsyncSession, None, None]: + # """Create a context manager for the session, bind to _conn string.""" + # async_session: AsyncSession = self._session_maker() + # await async_session.begin() + # yield async_session + # await async_session.commit() # FIXME + # diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 75169688..13b3e735 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -5,7 +5,7 @@ import pytest from langchain_core.documents import Document -from langchain_postgres.vectorstores import ( +from langchain_postgres.vectorstores_ppr import ( SUPPORTED_OPERATORS, PGVector, ) From 69c22a3ce926f4b29d360d0b50724a0f89ed198f Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 15 Apr 2024 08:40:31 +0200 Subject: [PATCH 02/18] Add async mode --- langchain_postgres/vectorstores.py | 1104 ++++++++++--- langchain_postgres/vectorstores_ppr.py | 1996 ------------------------ tests/unit_tests/test_vectorstore.py | 186 ++- 3 files changed, 1093 insertions(+), 2193 deletions(-) delete mode 100644 langchain_postgres/vectorstores_ppr.py diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 89b78fba..d5aa6dba 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import enum import logging import uuid @@ -11,14 +12,15 @@ List, Optional, Tuple, - Type, - Union, + Type, Union, ) import numpy as np import sqlalchemy -from sqlalchemy import SQLColumnExpression, cast, delete, func +from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +# TODO: accepter l'absence de l'option async lors des imports +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import Session, relationship, sessionmaker try: @@ -28,7 +30,6 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.runnables.config import run_in_executor from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore @@ -47,10 +48,8 @@ class DistanceStrategy(str, enum.Enum): Base = declarative_base() # type: Any - _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" - _classes: Any = None COMPARISONS_TO_NATIVE = { @@ -110,16 +109,28 @@ class CollectionStore(Base): @classmethod def get_by_name( - cls, session: Session, name: str + cls, session: Session, name: str ) -> 
Optional["CollectionStore"]: return session.query(cls).filter(cls.name == name).first() # type: ignore + @classmethod + async def aget_by_name( + cls, session: AsyncSession, name: str + ) -> Optional["CollectionStore"]: + stmt = select(cls).filter(cls.name == name) + # return await session.execute(stmt) # FIXME + return (await session.execute(stmt)).scalars().first() # FIXME + # stmt = select(cls).filter(cls.name == name) + # result = await session.execute(stmt) + # x = result.scalars() + # return session.query(cls).filter(cls.name == name).first() + @classmethod def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """Get or create a collection. Returns: @@ -132,7 +143,29 @@ def get_or_create( collection = cls(name=name, cmetadata=cmetadata) session.add(collection) - session.commit() + session.commit() # FIXME PPR semble utile + created = True + return collection, created + + @classmethod + async def aget_or_create( + cls, + session: AsyncSession, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. + """ # noqa: E501 + created = False + collection = await cls.aget_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + await session.commit() created = True return collection, created @@ -235,20 +268,21 @@ class PGVector(VectorStore): """ def __init__( - self, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, + self, + embeddings: Embeddings, + *, + connection: Union[None, Connection, str] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + async_mode: bool = False, # FIXME: a virer. Gaff aux imports sans async ) -> None: """Initialize the PGVector store. @@ -277,6 +311,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. 
""" + self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -287,9 +322,8 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn if isinstance(connection, str): - self._engine = sqlalchemy.create_engine( - url=connection, **(engine_args or {}) - ) + self._engine = self._create_engine( + connection, engine_args, async_mode) elif isinstance(connection, sqlalchemy.engine.Engine): self._engine = connection else: @@ -297,8 +331,13 @@ def __init__( "connection should be a connection string or an instance of " "sqlalchemy.engine.Engine" ) - - self._session_maker = sessionmaker(bind=self._engine) + # If the driver accept only the synchrone calls, update the async_mode + self.async_mode = not isinstance(self._engine, Engine) + self._session_maker: Union[sessionmaker, async_sessionmaker] + if self.async_mode: + self._session_maker = async_sessionmaker(bind=self._engine) + else: + self._session_maker = sessionmaker(bind=self._engine) self.use_jsonb = use_jsonb self.create_extension = create_extension @@ -306,10 +345,11 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - self.__post_init__() + if not async_mode: + self.__post_init__() def __post_init__( - self, + self, ) -> None: """Initialize the store.""" if self.create_extension: @@ -323,9 +363,85 @@ def __post_init__( self.create_tables_if_not_exists() self.create_collection() + async def __apost_init__( + self, + ) -> None: + + """Initialize the store.""" + if self.create_extension: + await self.acreate_vector_extension() + + EmbeddingStore, CollectionStore = _get_embedding_collection_store( + self._embedding_length + ) + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore + await self.acreate_tables_if_not_exists() + await self.acreate_collection() + + @classmethod + async def create(cls, + embeddings: Embeddings, + *, + connection: Optional[Connection] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + async_mode: bool = True, + ) -> PGVector: + self = cls( + embeddings=embeddings, + connection=connection, + embedding_length=embedding_length, + collection_name=collection_name, + collection_metadata=collection_metadata, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + logger=logger, + relevance_score_fn=relevance_score_fn, + engine_args=engine_args, + use_jsonb=use_jsonb, + create_extension=create_extension, + async_mode=async_mode, + ) + if async_mode: + await self.__apost_init__() + return self + + def _create_engine(self, + connection: str, + engine_args: Optional[dict[str, Any]] = None, + async_mode: bool = False) -> sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine: + if async_mode: + from sqlalchemy.ext.asyncio import create_async_engine + # FIXME: gérer appel async sur un sync + from sqlalchemy.exc import InvalidRequestError + try: + return create_async_engine( + url=connection, + isolation_level="REPEATABLE READ", # FIXME: merge avec la 
suite ? + echo=True, # FIXME: to be removed + **(engine_args or {}) + ) + except InvalidRequestError: + pass # Ignore and fall back to the synchronous version + logging.warning("Using a synchronous SQL engine!") + return sqlalchemy.create_engine(url=connection, + **(engine_args or {})) + def __del__(self) -> None: if isinstance(self._engine, sqlalchemy.engine.Connection): - self._engine.close() + if self.async_mode: + asyncio.run(self._engine.close()) + else: + self._engine.close() @property def embeddings(self) -> Embeddings: @@ -350,22 +466,79 @@ def create_vector_extension(self) -> None: except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e + async def acreate_vector_extension(self) -> None: + try: + async with self._session_maker() as session: + # The advisor lock fixes issue arising from concurrent + # creation of the vector extension. + # https://github.com/langchain-ai/langchain/issues/12933 + # For more information see: + # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS + await session.execute( + sqlalchemy.text( + "SELECT pg_advisory_xact_lock(1573678846307946496)")) + await session.execute( + sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")) + except Exception as e: + raise Exception(f"Failed to create vector extension: {e}") from e + def create_tables_if_not_exists(self) -> None: with self._session_maker() as session: Base.metadata.create_all(session.get_bind()) + async def acreate_tables_if_not_exists(self) -> None: + if isinstance(self._engine, sqlalchemy.ext.asyncio.engine.AsyncConnection): + await self._engine.run_sync(Base.metadata.create_all) + else: + async with self._engine.begin() as conn: # FIXME: session.run_sync exists + await conn.run_sync(Base.metadata.create_all) + # async with self._amake_session() as session: + # await session.run_sync(Base.metadata.create_all) + def drop_tables(self) -> None: with self._session_maker() as session: Base.metadata.drop_all(session.get_bind()) + async def adrop_tables(self) -> None: + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with self._session_maker() as session: # type: ignore[arg-type] + with self._session_maker() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + async def acreate_collection(self) -> None: + async with self._session_maker() as session: + if self.pre_delete_collection: + await self._adelete_collection(session) + await self.CollectionStore.aget_or_create( + session, self.collection_name, cmetadata=self.collection_metadata + ) + + def _delete_collection(self, session: Session) -> None: + self.logger.debug("Trying to delete collection") + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + session.delete(collection) + + # FIXME: is the _adelete variant necessary?
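The async variants added above (`acreate_vector_extension`, `acreate_tables_if_not_exists`, `acreate_collection`) are only exercised when the store is built through the awaitable `create()` classmethod rather than `__init__`. A minimal usage sketch of that path, assuming a reachable Postgres with the `vector` extension available, valid OpenAI credentials, and a placeholder `postgresql+psycopg://` URL and collection name:

```python
import asyncio

from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector

# Placeholder connection URL; any async-capable SQLAlchemy psycopg URL should work.
connection = "postgresql+psycopg://langchain:langchain@localhost:5432/langchain"


async def main() -> None:
    # create() runs the async initialisation (__apost_init__): the vector
    # extension, tables and collection are set up before the store is returned.
    store = await PGVector.create(
        embeddings=OpenAIEmbeddings(),
        connection=connection,
        collection_name="async_demo",
    )
    await store.aadd_texts(["hello postgres", "hello pgvector"])
    docs = await store.asimilarity_search("hello", k=2)
    print([d.page_content for d in docs])


asyncio.run(main())
```

The scored and deletion counterparts (`asimilarity_search_with_score`, `adelete`) follow the same pattern.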
+ async def _adelete_collection(self, session: AsyncSession) -> None: + self.logger.debug("Trying to delete collection") + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + await session.delete(collection) + + # def delete_collection(self) -> None: + # with self._session_maker() as session: + # self._delete_collection(session) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") with self._session_maker() as session: # type: ignore[arg-type] @@ -376,11 +549,21 @@ def delete_collection(self) -> None: session.delete(collection) session.commit() + async def adelete_collection(self) -> None: + self.logger.debug("Trying to delete collection") + with self._session_maker() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + await session.adelete(collection) + await session.commit() + def delete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: """Delete vectors by ids or uuids. @@ -411,24 +594,62 @@ def delete( session.execute(stmt) session.commit() + async def adelete( + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, + ) -> None: + """Delete vectors by ids or uuids. + + Args: + ids: List of ids to delete. + collection_only: Only delete ids in the collection. + """ + async with self._session_maker() as session: + if ids is not None: + self.logger.debug( + "Trying to delete vectors by ids (represented by the model " + "using the custom ids field)" + ) + + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + + stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) + await session.execute(stmt) + await session.commit() + def get_collection(self, session: Session) -> Any: return self.CollectionStore.get_by_name(session, self.collection_name) + async def aget_collection(self, session: AsyncSession) -> Any: + return await self.CollectionStore.aget_by_name(session, self.collection_name) + @classmethod def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid4()) for _ in texts] @@ -452,13 +673,54 @@ def __from( return store + @classmethod + async def __afrom( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: 
Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + store = cls( + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + async_mode=True, # FIXME + **kwargs, + ) + # Second phase to create + await store.__apost_init__() + + await store.aadd_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Add embeddings to the vectorstore. @@ -505,12 +767,65 @@ def add_embeddings( return ids + async def aadd_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + async with self._session_maker() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") + data = [ + { + "id": id, + "collection_id": collection.uuid, + "embedding": embedding, + "document": text, + "cmetadata": metadata or {}, + } + for text, metadata, embedding, id in zip( + texts, metadatas, embeddings, ids + ) + ] + stmt = insert(self.EmbeddingStore).values(data) + on_conflict_stmt = stmt.on_conflict_do_update( + index_elements=["id"], + # Conflict detection based on these columns + set_={ + "embedding": stmt.excluded.embedding, + "document": stmt.excluded.document, + "cmetadata": stmt.excluded.cmetadata, + }, + ) + await session.execute(on_conflict_stmt) + await session.commit() + + return ids + def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. @@ -527,12 +842,34 @@ def add_texts( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. 
+ metadatas: Optional list of metadatas associated with the texts. + kwargs: vectorstore specific parameters + + Returns: + List of ids from adding the texts into the vectorstore. + """ + embeddings = await self.embedding_function.aembed_documents(list(texts)) + return await self.aadd_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -551,11 +888,35 @@ def similarity_search( filter=filter, ) + async def asimilarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search with PGVector with distance. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding_function.embed_query(text=query) + return await self.asimilarity_search_by_vector( + embedding=embedding, + k=k, + filter=filter, + ) + def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -573,6 +934,28 @@ def similarity_search_with_score( ) return docs + async def asimilarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query and score for each. 
+ """ + embedding = self.embedding_function.embed_query(query) + docs = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return docs + @property def distance_strategy(self) -> Any: if self._distance_strategy == DistanceStrategy.EUCLIDEAN: @@ -588,15 +971,29 @@ def distance_strategy(self) -> Any: ) def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) + async def asimilarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + async with self._session_maker() as session: # type: ignore[arg-type] + results = await self.__aquery_collection( + session=session, + embedding=embedding, k=k, + filter=filter) + + return self._results_to_docs_and_scores(results) + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: """Return docs and scores from results.""" docs = [ @@ -612,9 +1009,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa return docs def _handle_field_filter( - self, - field: str, - value: Any, + self, + field: str, + value: Any, ) -> SQLColumnExpression: """Create a filter for a specific field. @@ -724,7 +1121,8 @@ def _handle_field_filter( else: raise NotImplementedError() - def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] + def _create_filter_clause_deprecated(self, key, + value): # type: ignore[no-untyped-def] """Deprecated functionality. This is for backwards compatibility with the JSON based schema for metadata. @@ -793,7 +1191,7 @@ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyp return filter_by_metadata def _create_filter_clause_json_deprecated( - self, filter: Any + self, filter: Any ) -> List[SQLColumnExpression]: """Convert filters from IR to SQL clauses. 
@@ -904,10 +1302,10 @@ def _create_filter_clause(self, filters: Any) -> Any: ) def __query_collection( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" with self._session_maker() as session: # type: ignore[arg-type] @@ -945,12 +1343,53 @@ def __query_collection( return results + async def __aquery_collection( + self, + session: AsyncSession, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") + + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) + + _type = self.EmbeddingStore + + stmt = (select( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) + .filter(*filter_by) + .order_by(sqlalchemy.asc("distance")) + .join( + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k)) + + results: List[Any] = (await session.execute(stmt)).all() + + return results + def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -967,19 +1406,41 @@ def similarity_search_by_vector( ) return _results_to_docs(docs_and_scores) + async def asimilarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query vector. 
+ """ + docs_and_scores = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return _results_to_docs(docs_and_scores) + @classmethod def from_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -997,18 +1458,47 @@ def from_texts( **kwargs, ) + @classmethod + async def afrom_texts( + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """Return VectorStore initialized from documents and embeddings.""" + embeddings = embedding.embed_documents(list(texts)) + return await cls.__afrom( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + **kwargs, + ) + @classmethod def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - *, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + *, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and embeddings. @@ -1053,16 +1543,61 @@ def from_embeddings( **kwargs, ) + @classmethod + async def afrom_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and pre- + generated embeddings. + + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + + Example: + .. 
code-block:: python + + from langchain_community.vectorstores import PGVector + from langchain_community.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return await cls.__afrom( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + @classmethod def from_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will @@ -1080,11 +1615,39 @@ def from_existing_index( return store + @classmethod + async def afrom_existing_index( + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, + ) -> PGVector: + """ + Get instance of an existing PGVector store.This method will + return the instance of the store without inserting any new + embeddings + """ + store = cls( # FIXME: créate + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + async_mode=True, + **kwargs, + ) + + return store + @classmethod def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: connection_string: str = get_from_dict_or_env( data=kwargs, - key="connection_string", + key="connection", env_key="PGVECTOR_CONNECTION_STRING", ) @@ -1099,17 +1662,17 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: @classmethod def from_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - *, - connection: Optional[Connection] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + *, + connection: Optional[Connection] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" @@ -1129,15 +1692,53 @@ def from_documents( **kwargs, ) + @classmethod + async def afrom_documents( + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + collection_name: str = 
_LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """ + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + """ + + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + connection_string = cls.get_connection_string(kwargs) + + kwargs["connection"] = connection_string + + return await cls.afrom_texts( + texts=texts, + pre_delete_collection=pre_delete_collection, + embedding=embedding, + distance_strategy=distance_strategy, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + use_jsonb=use_jsonb, + **kwargs, + ) + @classmethod def connection_string_from_db_params( - cls, - driver: str, - host: str, - port: int, - database: str, - user: str, - password: str, + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, ) -> str: """Return connection string from database parameters.""" if driver != "psycopg": @@ -1172,13 +1773,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: ) def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1216,14 +1817,62 @@ def max_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] + async def amax_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. 
+ """ + async with self._session_maker() as session: + results = await self.__aquery_collection(session=session, + embedding=embedding, k=fetch_k, + filter=filter) + + embedding_list = [result.EmbeddingStore.embedding for result in results] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embedding_list, + k=k, + lambda_mult=lambda_mult, + ) + + candidates = self._results_to_docs_and_scores(results) + + return [r for i, r in enumerate(candidates) if i in mmr_selected] + def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1254,14 +1903,52 @@ def max_marginal_relevance_search( **kwargs, ) + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. + """ + embedding = self.embedding_function.embed_query(query) + return await self.amax_marginal_relevance_search_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -1294,14 +1981,54 @@ def max_marginal_relevance_search_with_score( ) return docs + async def amax_marginal_relevance_search_with_score( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. 
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + embedding = self.embedding_function.embed_query(query) + docs = await self.amax_marginal_relevance_search_with_score_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + return docs + def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -1335,22 +2062,35 @@ def max_marginal_relevance_search_by_vector( return _results_to_docs(docs_and_scores) async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, + """Return docs selected using the maximal marginal relevance + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. 
+ """ + docs_and_scores = await self.amax_marginal_relevance_search_with_score_by_vector( embedding, k=k, fetch_k=fetch_k, @@ -1358,3 +2098,5 @@ async def amax_marginal_relevance_search_by_vector( filter=filter, **kwargs, ) + + return _results_to_docs(docs_and_scores) diff --git a/langchain_postgres/vectorstores_ppr.py b/langchain_postgres/vectorstores_ppr.py deleted file mode 100644 index 1989a83a..00000000 --- a/langchain_postgres/vectorstores_ppr.py +++ /dev/null @@ -1,1996 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import enum -import json -import logging -import uuid -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Tuple, - Type, Union, -) - -import numpy as np -import sqlalchemy -from langchain_core._api import warn_deprecated -from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine -from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.orm import Session, relationship, sessionmaker -# TODO: accepter l'absence de l'option async lors des imports -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base - -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.runnables.config import run_in_executor -from langchain_core.utils import get_from_dict_or_env -from langchain_core.vectorstores import VectorStore - -from langchain_postgres._utils import maximal_marginal_relevance - - -class DistanceStrategy(str, enum.Enum): - """Enumerator of the Distance strategies.""" - - EUCLIDEAN = "l2" - COSINE = "cosine" - MAX_INNER_PRODUCT = "inner" - - -DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE - -Base = declarative_base() # type: Any - -_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" - - -# class BaseModel(Base): -# """Base model for the SQL stores.""" -# -# __abstract__ = True -# uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) -# - -_classes: Any = None - -COMPARISONS_TO_NATIVE = { - "$eq": "==", - "$ne": "!=", - "$lt": "<", - "$lte": "<=", - "$gt": ">", - "$gte": ">=", -} - -SPECIAL_CASED_OPERATORS = { - "$in", - "$nin", - "$between", -} - -TEXT_OPERATORS = { - "$like", - "$ilike", -} - -LOGICAL_OPERATORS = {"$and", "$or"} - -SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(TEXT_OPERATORS) - .union(LOGICAL_OPERATORS) - .union(SPECIAL_CASED_OPERATORS) -) - - -def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: - global _classes - if _classes is not None: - return _classes - - from pgvector.sqlalchemy import Vector # type: ignore - - class CollectionStore(Base): - """Collection store.""" - - __tablename__ = "langchain_pg_collection" - - uuid = sqlalchemy.Column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) - name = sqlalchemy.Column(sqlalchemy.String, nullable=False, unique=True) - cmetadata = sqlalchemy.Column(JSON) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name( - cls, session: Session, name: str - ) -> Optional["CollectionStore"]: - # FIXME return session.query(cls).filter(cls.name == name).first() # type: ignore - return session.execute( - select(cls).filter(cls.name == name)).scalars().first() - - # @classmethod - # async def aget_by_name( - # 
cls, session: AsyncSession, name: str - # ) -> Optional["CollectionStore"]: - # stmt = select(cls).filter(cls.name == name) - # # return await session.execute(stmt) # FIXME - # return (await session.execute(stmt)).scalars().first() # FIXME - # - # # stmt = select(cls).filter(cls.name == name) - # # result = await session.execute(stmt) - # # x = result.scalars() - # # return session.query(cls).filter(cls.name == name).first() - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """Get or create a collection. - Returns: - Where the bool is True if the collection was created. - """ # noqa: E501 - created = False - collection = cls.get_by_name(session, name) - if collection: - return collection, created - - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() # FIXME PPR semble utile - created = True - return collection, created - - # @classmethod - # async def aget_or_create( - # cls, - # session: AsyncSession, - # name: str, - # cmetadata: Optional[dict] = None, - # ) -> Tuple["CollectionStore", bool]: - # """ - # Get or create a collection. - # Returns [Collection, bool] where the bool is True if the collection was created. - # """ # noqa: E501 - # created = False - # collection = await cls.aget_by_name(session, name) - # if collection: - # return collection, created - # - # collection = cls(name=name, cmetadata=cmetadata) - # session.add(collection) - # await session.commit() - # created = True - # return collection, created - - class EmbeddingStore(Base): - """Embedding store.""" - - __tablename__ = "langchain_pg_embedding" - - id = sqlalchemy.Column( - sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True - ) - - collection_id = sqlalchemy.Column( - UUID(as_uuid=True), - sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", - ondelete="CASCADE", - ), - ) - collection = relationship(CollectionStore, back_populates="embeddings") - - embedding: Vector = sqlalchemy.Column(Vector(vector_dimension)) - document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata = sqlalchemy.Column(JSONB, nullable=True) - - __table_args__ = ( - sqlalchemy.Index( - "ix_cmetadata_gin", - "cmetadata", - postgresql_using="gin", - postgresql_ops={"cmetadata": "jsonb_path_ops"}, - ), - ) - - _classes = (EmbeddingStore, CollectionStore) - - return _classes - - -def _results_to_docs(docs_and_scores: Any) -> List[Document]: - """Return docs from docs and scores.""" - return [doc for doc, _ in docs_and_scores] - - -Connection = Union[sqlalchemy.engine.Engine, str] - - -class PGVector(VectorStore): - """Vectorstore implementation using Postgres as the backend. - - Currently, there is no mechanism for supporting data migration. - - So breaking changes in the vectorstore schema will require the user to recreate - the tables and re-add the documents. - - If this is a concern, please use a different vectorstore. If - not, this implementation should be fine for your use case. - - To use this vectorstore you need to have the `vector` extension installed. - The `vector` extension is a Postgres extension that provides vector - similarity search capabilities. - - ```sh - docker run --name pgvector-container -e POSTGRES_PASSWORD=... - -d pgvector/pgvector:pg16 - ``` - - Example: - .. 
code-block:: python - - from langchain_postgres.vectorstores import PGVector - from langchain_openai.embeddings import OpenAIEmbeddings - - connection_string = "postgresql+psycopg://..." - collection_name = "state_of_the_union_test" - embeddings = OpenAIEmbeddings() - vectorstore = PGVector.from_documents( - embedding=embeddings, - documents=docs, - connection=connection_string, - collection_name=collection_name, - use_jsonb=True, - ) - - - This code has been ported over from langchain_community with minimal changes - to allow users to easily transition from langchain_community to langchain_postgres. - - Some changes had to be made to address issues with the community implementation: - * langchain_postgres now works with psycopg3. Please update your - connection strings from `postgresql+psycopg2://...` to - `postgresql+psycopg://langchain:langchain@...` - (yes, the driver name is `psycopg` not `psycopg3`) - * The schema of the embedding store and collection have been changed to make - add_documents work correctly with user specified ids, specifically - when overwriting existing documents. - You will need to recreate the tables if you are using an existing database. - * A Connection object has to be provided explicitly. Connections will not be - picked up automatically based on env variables. - """ - - def __init__( - self, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - ) -> None: - """Initialize the PGVector store. - - Args: - connection: Postgres connection string. - embeddings: Any embedding function implementing - `langchain.embeddings.base.Embeddings` interface. - embedding_length: The length of the embedding vector. (default: None) - NOTE: This is not mandatory. Defining it will prevent vectors of - any other size to be added to the embeddings table but, without it, - the embeddings can't be indexed. - collection_name: The name of the collection to use. (default: langchain) - NOTE: This is not the name of the table, but the name of the collection. - The tables will be created when initializing the store (if not exists) - So, make sure the user has the right permissions to create tables. - distance_strategy: The distance strategy to use. (default: COSINE) - pre_delete_collection: If True, will delete the collection if it exists. - (default: False). Useful for testing. - engine_args: SQLAlchemy's create engine arguments. - use_jsonb: Use JSONB instead of JSON for metadata. (default: True) - Strongly discouraged from using JSON as it's not as efficient - for querying. - It's provided here for backwards compatibility with older versions, - and will be removed in the future. - create_extension: If True, will create the vector extension if it - doesn't exist. disabling creation is useful when using ReadOnly - Databases. 
- """ - self.embedding_function = embeddings - self._embedding_length = embedding_length - self.collection_name = collection_name - self.collection_metadata = collection_metadata - self._distance_strategy = distance_strategy - self.pre_delete_collection = pre_delete_collection - self.logger = logger or logging.getLogger(__name__) - self.override_relevance_score_fn = relevance_score_fn - - if isinstance(connection, str): - self._engine = sqlalchemy.create_engine( - url=connection, **(engine_args or {}) - ) - elif isinstance(connection, sqlalchemy.engine.Engine): - self._engine = connection - else: - raise ValueError( - "connection should be a connection string or an instance of " - "sqlalchemy.engine.Engine" - ) - - self._session_maker = sessionmaker(bind=self._engine) - - self.use_jsonb = use_jsonb - self.create_extension = create_extension - - if not use_jsonb: - # Replace with a deprecation warning. - raise NotImplementedError("use_jsonb=False is no longer supported.") - self.__post_init__() - - def __post_init__( - self, - ) -> None: - """Initialize the store.""" - if self.create_extension: - self.create_vector_extension() - - EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length - ) - self.CollectionStore = CollectionStore - self.EmbeddingStore = EmbeddingStore - self.create_tables_if_not_exists() - self.create_collection() - - # async def __apost_init__( - # self, - # ) -> None: - # - # if self.async_mode: - # self._session_maker = self._build_async_sessionmaker() - # """Initialize the store.""" - # EmbeddingStore, CollectionStore = _get_embedding_collection_store( - # self._embedding_length, use_jsonb=self.use_jsonb - # ) - # self.CollectionStore = CollectionStore - # self.EmbeddingStore = EmbeddingStore - # - # if self.create_extension: - # await self.acreate_vector_extension() - # - # await self.acreate_tables_if_not_exists() - # await self.acreate_collection() - - def __del__(self) -> None: - if isinstance(self._bind, sqlalchemy.engine.Connection): - if self.async_mode: - asyncio.run(self._bind.close()) - else: - self._bind.close() - - @property - def embeddings(self) -> Embeddings: - return self.embedding_function - - def _create_engine(self, async_mode: bool = False) -> sqlalchemy.engine.Engine: - if async_mode: - from sqlalchemy.ext.asyncio import create_async_engine - # FIXME: gérer appel async sur un sync - from sqlalchemy.exc import InvalidRequestError - try: - return create_async_engine( - url=self.connection_string, - isolation_level="REPEATABLE READ", # FIXME: merge avec la suite ? - echo=True, # FIXME: a virer - **self.engine_args - ) - except InvalidRequestError: - pass # Ignore and return the synchrone version - logging.warning("Use a synchrone SQL engine !") - return sqlalchemy.create_engine(url=self.connection_string, - echo=True, # FIXME: a virer - **self.engine_args) - - def create_vector_extension(self) -> None: - try: - with self._session_maker() as session: # type: ignore[arg-type] - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. 
- # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - statement = sqlalchemy.text( - "BEGIN;" - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" - "COMMIT;" - ) - session.execute(statement) - session.commit() - except Exception as e: - raise Exception(f"Failed to create vector extension: {e}") from e - - # async def acreate_vector_extension(self) -> None: - # try: - # async with self._amake_session() as session: - # # The advisor lock fixes issue arising from concurrent - # # creation of the vector extension. - # # https://github.com/langchain-ai/langchain/issues/12933 - # # For more information see: - # # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - # await session.execute( - # sqlalchemy.text( - # "SELECT pg_advisory_xact_lock(1573678846307946496)")) - # await session.execute( - # sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")) - # except Exception as e: - raise Exception(f"Failed to create vector extension: {e}") from e - - def create_tables_if_not_exists(self) -> None: - with self._session_maker() as session: - Base.metadata.create_all(session.get_bind()) - - async def acreate_tables_if_not_exists(self) -> None: - if isinstance(self._bind, sqlalchemy.ext.asyncio.engine.AsyncConnection): - await self._bind.run_sync(Base.metadata.create_all) - else: - async with self._bind.begin() as conn: # FIXME: session.run_sync existe - await conn.run_sync(Base.metadata.create_all) - # async with self._amake_session() as session: - # await session.run_sync(Base.metadata.create_all) - - def drop_tables(self) -> None: - with self._session_maker() as session: - Base.metadata.drop_all(session.get_bind()) - - # async def adrop_tables(self) -> None: - # async with self._amake_session() as session: - # await session.run_sync(Base.metadata.drop_all) - - def create_collection(self) -> None: - if self.pre_delete_collection: - self.delete_collection() - with self._session_maker() as session: - self.CollectionStore.get_or_create( - session, self.collection_name, cmetadata=self.collection_metadata - ) - - # async def acreate_collection(self) -> None: - # async with self._amake_session() as session: - # if self.pre_delete_collection: - # await self._adelete_collection(session) - # await self.CollectionStore.aget_or_create( - # session, self.collection_name, cmetadata=self.collection_metadata - # ) - - def _delete_collection(self,session: Session) -> None: - self.logger.debug("Trying to delete collection") - collection = self.get_collection(session) - if not collection: - self.logger.warning("Collection not found") - return - session.delete(collection) - - # async def _adelete_collection(self, session: AsyncSession) -> None: - # self.logger.debug("Trying to delete collection") - # collection = await self.aget_collection(session) - # if not collection: - # self.logger.warning("Collection not found") - # return - # await session.delete(collection) - - # def delete_collection(self) -> None: - # with self._session_maker() as session: - # self._delete_collection(session) - def delete_collection(self) -> None: - self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] - collection = self.get_collection(session) - if not collection: - self.logger.warning("Collection not found") - return - session.delete(collection) - session.commit() - - def delete( - self, - ids: 
Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, - ) -> None: - """Delete vectors by ids or uuids. - - Args: - ids: List of ids to delete. - collection_only: Only delete ids in the collection. - """ - with self._session_maker() as session: - if ids is not None: - self.logger.debug( - "Trying to delete vectors by ids (represented by the model " - "using the custom ids field)" - ) - - stmt = delete(self.EmbeddingStore) - - if collection_only: - collection = self.get_collection(session) - if not collection: - self.logger.warning("Collection not found") - return - - stmt = stmt.where( - self.EmbeddingStore.collection_id == collection.uuid - ) - - stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) - session.execute(stmt) - session.commit() - - # async def adelete( - # self, - # ids: Optional[List[str]] = None, - # collection_only: bool = False, - # **kwargs: Any, - # ) -> None: - # """Delete vectors by ids or uuids. - # - # Args: - # ids: List of ids to delete. - # collection_only: Only delete ids in the collection. - # """ - # async with self._amake_session() as session: - # if ids is not None: - # self.logger.debug( - # "Trying to delete vectors by ids (represented by the model " - # "using the custom ids field)" - # ) - # - # stmt = delete(self.EmbeddingStore) - # - # if collection_only: - # collection = await self.aget_collection(session) - # if not collection: - # self.logger.warning("Collection not found") - # return - # - # stmt = stmt.where( - # self.EmbeddingStore.collection_id == collection.uuid - # ) - # - # stmt = stmt.where(self.EmbeddingStore.custom_id.in_(ids)) - # await session.execute(stmt) - - def get_collection(self, session: Session) -> Any: - return self.CollectionStore.get_by_name(session, self.collection_name) - - # async def aget_collection(self, session: AsyncSession) -> Any: - # return await self.CollectionStore.aget_by_name(session, self.collection_name) - - @classmethod - def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, - ) -> PGVector: - if ids is None: - ids = [str(uuid.uuid1()) for _ in texts] - - if not metadatas: - metadatas = [{} for _ in texts] - - store = cls( - connection=connection, - collection_name=collection_name, - embeddings=embedding, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - use_jsonb=use_jsonb, - **kwargs, - ) - - store.add_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) - - return store - - # @classmethod - # async def __afrom( - # cls, - # texts: List[str], - # embeddings: List[List[float]], - # embedding: Embeddings, - # metadatas: Optional[List[dict]] = None, - # ids: Optional[List[str]] = None, - # collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - # distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - # connection_string: Optional[str] = None, - # pre_delete_collection: bool = False, - # *, - # use_jsonb: bool = False, - # **kwargs: Any, - # ) -> PGVector: - # if ids is None: - # ids = [str(uuid.uuid1()) for _ in texts] - # - # if not metadatas: - # metadatas = [{} for _ in texts] - # if connection_string is None: - # 
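# Example (editor's sketch): deleting vectors by id with the `delete` method
# above. The ids are placeholders and `store` is a PGVector instance as in the
# earlier sketches; `collection_only=True` restricts the delete to rows that
# belong to the store's collection.
store.delete(ids=["id-1", "id-2"], collection_only=True)

# The collection itself, and all tables, can also be removed:
store.delete_collection()
store.drop_tables()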
connection_string = cls.get_connection_string(kwargs) - # - # store = cls( - # connection_string=connection_string, - # collection_name=collection_name, - # embedding_function=embedding, - # distance_strategy=distance_strategy, - # pre_delete_collection=pre_delete_collection, - # use_jsonb=use_jsonb, - # async_mode=True, # FIXME - # **kwargs, - # ) - # # Second phase to create - # await store.__apost_init__() - # - # await store.aadd_embeddings( - # texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - # ) - # - # return store - - def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Add embeddings to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - embeddings: List of list of embedding vectors. - metadatas: List of metadatas associated with the texts. - kwargs: vectorstore specific parameters - """ - if ids is None: - ids = [str(uuid.uuid1()) for _ in texts] - - if not metadatas: - metadatas = [{} for _ in texts] - - with self._session_maker() as session: # type: ignore[arg-type] - collection = self.get_collection(session) - if not collection: - raise ValueError("Collection not found") - data = [ - { - "id": id, - "collection_id": collection.uuid, - "embedding": embedding, - "document": text, - "cmetadata": metadata or {}, - } - for text, metadata, embedding, id in zip( - texts, metadatas, embeddings, ids - ) - ] - stmt = insert(self.EmbeddingStore).values(data) - on_conflict_stmt = stmt.on_conflict_do_update( - index_elements=["id"], - # Conflict detection based on these columns - set_={ - "embedding": stmt.excluded.embedding, - "document": stmt.excluded.document, - "cmetadata": stmt.excluded.cmetadata, - }, - ) - session.execute(on_conflict_stmt) - session.commit() - - return ids - - # async def aadd_embeddings( REFAIRE - # self, - # texts: Iterable[str], - # embeddings: List[List[float]], - # metadatas: Optional[List[dict]] = None, - # ids: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> List[str]: - # """Add embeddings to the vectorstore. - # - # Args: - # texts: Iterable of strings to add to the vectorstore. - # embeddings: List of list of embedding vectors. - # metadatas: List of metadatas associated with the texts. - # kwargs: vectorstore specific parameters - # """ - # if ids is None: - # ids = [str(uuid.uuid1()) for _ in texts] - # - # if not metadatas: - # metadatas = [{} for _ in texts] - # - # async with self._amake_session() as session: - # collection = await self.aget_collection(session) - # if not collection: - # raise ValueError("Collection not found") - # documents = [] - # for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): - # embedding_store = self.EmbeddingStore( - # embedding=embedding, - # document=text, - # cmetadata=metadata, - # custom_id=id, - # collection_id=collection.uuid, - # ) - # documents.append(embedding_store) - # await session.run_sync( - # lambda sync_session: sync_session.bulk_save_objects(documents)) - # - # return ids - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. 
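# Example (editor's sketch): because `add_embeddings` upserts on the `id`
# column (ON CONFLICT ... DO UPDATE), re-adding a text whose id already exists
# overwrites the stored document, embedding and metadata instead of inserting
# a duplicate. Texts, metadata and ids below are placeholders; `store` is a
# PGVector instance as in the earlier sketches.
store.add_texts(
    texts=["foo", "bar"],
    metadatas=[{"topic": "a"}, {"topic": "b"}],
    ids=["doc-1", "doc-2"],
)
store.add_texts(
    texts=["foo, revised"],
    metadatas=[{"topic": "a", "version": 2}],
    ids=["doc-1"],  # same id: the existing row is overwritten
)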
- kwargs: vectorstore specific parameters - - Returns: - List of ids from adding the texts into the vectorstore. - """ - embeddings = self.embedding_function.embed_documents(list(texts)) - return self.add_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) - - # async def aadd_texts( - # self, - # texts: Iterable[str], - # metadatas: Optional[List[dict]] = None, - # ids: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> List[str]: - # """Run more texts through the embeddings and add to the vectorstore. - # - # Args: - # texts: Iterable of strings to add to the vectorstore. - # metadatas: Optional list of metadatas associated with the texts. - # kwargs: vectorstore specific parameters - # - # Returns: - # List of ids from adding the texts into the vectorstore. - # """ - # embeddings = await self.embedding_function.aembed_documents(list(texts)) - # return await self.aadd_embeddings( - # texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - # ) - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> List[Document]: - """Run similarity search with PGVector with distance. - - Args: - query (str): Query text to search for. - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query. - """ - embedding = self.embedding_function.embed_query(text=query) - return self.similarity_search_by_vector( - embedding=embedding, - k=k, - filter=filter, - ) - - # async def asimilarity_search( - # self, - # query: str, - # k: int = 4, - # filter: Optional[dict] = None, - # **kwargs: Any, - # ) -> List[Document]: - # """Run similarity search with PGVector with distance. - # - # Args: - # query (str): Query text to search for. - # k (int): Number of results to return. Defaults to 4. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - # - # Returns: - # List of Documents most similar to the query. - # """ - # embedding = self.embedding_function.embed_query(text=query) - # return await self.asimilarity_search_by_vector( - # embedding=embedding, - # k=k, - # filter=filter, - # ) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query and score for each. - """ - embedding = self.embedding_function.embed_query(query) - docs = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) - return docs - - # async def asimilarity_search_with_score( - # self, - # query: str, - # k: int = 4, - # filter: Optional[dict] = None, - # ) -> List[Tuple[Document, float]]: - # """Return docs most similar to query. - # - # Args: - # query: Text to look up documents similar to. - # k: Number of Documents to return. Defaults to 4. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - # - # Returns: - # List of Documents most similar to the query and score for each. 
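# Example (editor's sketch): querying with the `similarity_search` and
# `similarity_search_with_score` methods defined above. The query text, k and
# metadata filter values are placeholders.
docs = store.similarity_search("what did the president say?", k=4)

docs_with_scores = store.similarity_search_with_score(
    "what did the president say?",
    k=4,
    filter={"topic": "a"},  # shorthand for {"topic": {"$eq": "a"}}
)
for doc, score in docs_with_scores:
    print(score, doc.page_content)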
- # """ - # embedding = self.embedding_function.embed_query(query) - # docs = await self.asimilarity_search_with_score_by_vector( - # embedding=embedding, k=k, filter=filter - # ) - # return docs - - @property - def distance_strategy(self) -> Any: - if self._distance_strategy == DistanceStrategy.EUCLIDEAN: - return self.EmbeddingStore.embedding.l2_distance - elif self._distance_strategy == DistanceStrategy.COSINE: - return self.EmbeddingStore.embedding.cosine_distance - elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - return self.EmbeddingStore.embedding.max_inner_product - else: - raise ValueError( - f"Got unexpected value for distance: {self._distance_strategy}. " - f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." - ) - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - ) -> List[Tuple[Document, float]]: - results = self.__query_collection(embedding=embedding, k=k, filter=filter) - - return self._results_to_docs_and_scores(results) - - def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: - """Return docs and scores from results.""" - docs = [ - ( - Document( - page_content=result.EmbeddingStore.document, - metadata=result.EmbeddingStore.cmetadata, - ), - result.distance if self.embedding_function is not None else None, - ) - for result in results - ] - return docs - - def _handle_field_filter( - self, - field: str, - value: Any, - ) -> SQLColumnExpression: - """Create a filter for a specific field. - - Args: - field: name of field - value: value to filter - If provided as is then this will be an equality filter - If provided as a dictionary then this will be a filter, the key - will be the operator and the value will be the value to filter by - - Returns: - sqlalchemy expression - """ - if not isinstance(field, str): - raise ValueError( - f"field should be a string but got: {type(field)} with value: {field}" - ) - - if field.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got an operator: " - f"{field}" - ) - - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError( - f"Invalid field name: {field}. Expected a valid identifier." - ) - - if isinstance(value, dict): - # This is a filter specification - if len(value) != 1: - raise ValueError( - "Invalid filter condition. Expected a value which " - "is a dictionary with a single key that corresponds to an operator " - f"but got a dictionary with {len(value)} keys. The first few " - f"keys are: {list(value.keys())[:3]}" - ) - operator, filter_value = list(value.items())[0] - # Verify that that operator is an operator - if operator not in SUPPORTED_OPERATORS: - raise ValueError( - f"Invalid operator: {operator}. 
" - f"Expected one of {SUPPORTED_OPERATORS}" - ) - else: # Then we assume an equality operator - operator = "$eq" - filter_value = value - - if operator in COMPARISONS_TO_NATIVE: - # Then we implement an equality filter - # native is trusted input - native = COMPARISONS_TO_NATIVE[operator] - return func.jsonb_path_match( - self.EmbeddingStore.cmetadata, - cast(f"$.{field} {native} $value", JSONPATH), - cast({"value": filter_value}, JSONB), - ) - elif operator == "$between": - # Use AND with two comparisons - low, high = filter_value - - lower_bound = func.jsonb_path_match( - self.EmbeddingStore.cmetadata, - cast(f"$.{field} >= $value", JSONPATH), - cast({"value": low}, JSONB), - ) - upper_bound = func.jsonb_path_match( - self.EmbeddingStore.cmetadata, - cast(f"$.{field} <= $value", JSONPATH), - cast({"value": high}, JSONB), - ) - return sqlalchemy.and_(lower_bound, upper_bound) - elif operator in {"$in", "$nin", "$like", "$ilike"}: - # We'll do force coercion to text - if operator in {"$in", "$nin"}: - for val in filter_value: - if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) - - queried_field = self.EmbeddingStore.cmetadata[field].astext - - if operator in {"$in"}: - return queried_field.in_([str(val) for val in filter_value]) - elif operator in {"$nin"}: - return queried_field.nin_([str(val) for val in filter_value]) - elif operator in {"$like"}: - return queried_field.like(filter_value) - elif operator in {"$ilike"}: - return queried_field.ilike(filter_value) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - - def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] - """Deprecated functionality. - - This is for backwards compatibility with the JSON based schema for metadata. - It uses incorrect operator syntax (operators are not prefixed with $). - - This implementation is not efficient, and has bugs associated with - the way that it handles numeric filter clauses. 
- """ - IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne" - EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and" - - value_case_insensitive = {k.lower(): v for k, v in value.items()} - if IN in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_( - value_case_insensitive[IN] - ) - elif NIN in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in( - value_case_insensitive[NIN] - ) - elif BETWEEN in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between( - str(value_case_insensitive[BETWEEN][0]), - str(value_case_insensitive[BETWEEN][1]), - ) - elif GT in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str( - value_case_insensitive[GT] - ) - elif LT in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str( - value_case_insensitive[LT] - ) - elif NE in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str( - value_case_insensitive[NE] - ) - elif EQ in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( - value_case_insensitive[EQ] - ) - elif LIKE in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like( - value_case_insensitive[LIKE] - ) - elif CONTAINS in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains( - value_case_insensitive[CONTAINS] - ) - elif OR in map(str.lower, value): - or_clauses = [ - self._create_filter_clause(key, sub_value) - for sub_value in value_case_insensitive[OR] - ] - filter_by_metadata = sqlalchemy.or_(*or_clauses) - elif AND in map(str.lower, value): - and_clauses = [ - self._create_filter_clause(key, sub_value) - for sub_value in value_case_insensitive[AND] - ] - filter_by_metadata = sqlalchemy.and_(*and_clauses) - - else: - filter_by_metadata = None - - return filter_by_metadata - - def _create_filter_clause_json_deprecated( - self, filter: Any - ) -> List[SQLColumnExpression]: - """Convert filters from IR to SQL clauses. - - **DEPRECATED** This functionality will be deprecated in the future. - - It implements translation of filters for a schema that uses JSON - for metadata rather than the JSONB field which is more efficient - for querying. - """ - filter_clauses = [] - for key, value in filter.items(): - if isinstance(value, dict): - filter_by_metadata = self._create_filter_clause_deprecated(key, value) - - if filter_by_metadata is not None: - filter_clauses.append(filter_by_metadata) - else: - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( - value - ) - filter_clauses.append(filter_by_metadata) - return filter_clauses - - def _create_filter_clause(self, filters: Any) -> Any: - """Convert LangChain IR filter representation to matching SQLAlchemy clauses. - - At the top level, we still don't know if we're working with a field - or an operator for the keys. After we've determined that we can - call the appropriate logic to handle filter creation. - - Args: - filters: Dictionary of filters to apply to the query. - - Returns: - SQLAlchemy clause to apply to the query. 
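# Example (editor's sketch): the metadata filter forms accepted by the JSONB
# filtering code above and below. A bare value is an equality test, a
# single-key dict selects one of the supported operators ($eq, $ne, $lt, $lte,
# $gt, $gte, $in, $nin, $between, $like, $ilike), several top-level fields are
# combined with an implicit AND, and $and / $or take a list of sub-filters.
# Field names and values are placeholders; `store` is as in the earlier sketches.
store.similarity_search("query", k=4, filter={"page": "0"})                    # equality
store.similarity_search("query", k=4, filter={"year": {"$gte": 2020}})         # comparison
store.similarity_search("query", k=4, filter={"year": {"$between": [2019, 2021]}})
store.similarity_search("query", k=4, filter={"topic": {"$in": ["a", "b"]}})
store.similarity_search("query", k=4, filter={"title": {"$ilike": "%union%"}})
store.similarity_search("query", k=4, filter={"topic": "a", "year": {"$gte": 2020}})  # implicit AND
store.similarity_search("query", k=4, filter={"$or": [{"topic": "a"}, {"year": {"$lt": 2000}}]})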
- """ - if isinstance(filters, dict): - if len(filters) == 1: - # The only operators allowed at the top level are $AND and $OR - # First check if an operator or a field - key, value = list(filters.items())[0] - if key.startswith("$"): - # Then it's an operator - if key.lower() not in ["$and", "$or"]: - raise ValueError( - f"Invalid filter condition. Expected $and or $or " - f"but got: {key}" - ) - else: - # Then it's a field - return self._handle_field_filter(key, filters[key]) - - # Here we handle the $and and $or operators - if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) - if key.lower() == "$and": - and_ = [self._create_filter_clause(el) for el in value] - if len(and_) > 1: - return sqlalchemy.and_(*and_) - elif len(and_) == 1: - return and_[0] - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - elif key.lower() == "$or": - or_ = [self._create_filter_clause(el) for el in value] - if len(or_) > 1: - return sqlalchemy.or_(*or_) - elif len(or_) == 1: - return or_[0] - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - else: - raise ValueError( - f"Invalid filter condition. Expected $and or $or " - f"but got: {key}" - ) - elif len(filters) > 1: - # Then all keys have to be fields (they cannot be operators) - for key in filters.keys(): - if key.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got: {key}" - ) - # These should all be fields and combined using an $and operator - and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] - if len(and_) > 1: - return sqlalchemy.and_(*and_) - elif len(and_) == 1: - return and_[0] - else: - raise ValueError( - "Invalid filter condition. 
Expected a dictionary " - "but got an empty dictionary" - ) - else: - raise ValueError("Got an empty dictionary for filters.") - else: - raise ValueError( - f"Invalid type: Expected a dictionary but got type: {type(filters)}" - ) - - def __query_collection( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: - """Query the collection.""" - with self._session_maker() as session: # type: ignore[arg-type] - collection = self.get_collection(session) - if not collection: - raise ValueError("Collection not found") - - filter_by = [self.EmbeddingStore.collection_id == collection.uuid] - if filter: - if self.use_jsonb: - filter_clauses = self._create_filter_clause(filter) - if filter_clauses is not None: - filter_by.append(filter_clauses) - else: - # Old way of doing things - filter_clauses = self._create_filter_clause_json_deprecated(filter) - filter_by.extend(filter_clauses) - - _type = self.EmbeddingStore - - results: List[Any] = ( - session.query( - self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore - ) - .filter(*filter_by) - .order_by(sqlalchemy.asc("distance")) - .join( - self.CollectionStore, - self.EmbeddingStore.collection_id == self.CollectionStore.uuid, - ) - .limit(k) - .all() - ) - - return results - - # async def __aquery_collection( # FIXME - # self, - # session: AsyncSession, - # embedding: List[float], - # k: int = 4, - # filter: Optional[Dict[str, str]] = None, - # ) -> List[Any]: - # """Query the collection.""" - # collection = await self.aget_collection(session) - # if not collection: - # raise ValueError("Collection not found") - # - # filter_by = [self.EmbeddingStore.collection_id == collection.uuid] - # if filter: - # if self.use_jsonb: - # filter_clauses = self._create_filter_clause(filter) - # if filter_clauses is not None: - # filter_by.append(filter_clauses) - # else: - # # Old way of doing things - # filter_clauses = self._create_filter_clause_json_deprecated(filter) - # filter_by.extend(filter_clauses) - # - # _type = self.EmbeddingStore - # stmt = (select(self.EmbeddingStore, - # self.distance_strategy(embedding).label("distance"), - # ) - # .filter(*filter_by) - # .order_by(sqlalchemy.asc("distance")) - # .join( - # self.CollectionStore, - # self.EmbeddingStore.collection_id == self.CollectionStore.uuid, - # ) - # .limit(k)) - # results: List[Any] = ( - # (await session.execute(stmt)).all() - # ) - # - # return results - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query vector. - """ - docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) - return _results_to_docs(docs_and_scores) - - # async def asimilarity_search_by_vector( - # self, - # embedding: List[float], - # k: int = 4, - # filter: Optional[dict] = None, - # **kwargs: Any, - # ) -> List[Document]: - # """Return docs most similar to embedding vector. - # - # Args: - # embedding: Embedding to look up documents similar to. - # k: Number of Documents to return. Defaults to 4. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. 
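# Example (editor's sketch): querying with a pre-computed embedding through
# `similarity_search_by_vector`, which wraps the collection query above. The
# query text and filter are placeholders; the store's own `embeddings`
# property is used to embed the query.
query_vector = store.embeddings.embed_query("what did the president say?")
docs = store.similarity_search_by_vector(query_vector, k=4, filter={"topic": "a"})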
- # - # Returns: - # List of Documents most similar to the query vector. - # """ - # docs_and_scores = await self.asimilarity_search_with_score_by_vector( - # embedding=embedding, k=k, filter=filter - # ) - # return _results_to_docs(docs_and_scores) - - @classmethod - def from_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, - ) -> PGVector: - """Return VectorStore initialized from documents and embeddings.""" - embeddings = embedding.embed_documents(list(texts)) - - return cls.__from( - texts, - embeddings, - embedding, - metadatas=metadatas, - ids=ids, - collection_name=collection_name, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - use_jsonb=use_jsonb, - **kwargs, - ) - - # @classmethod - # async def afrom_texts( - # cls: Type[PGVector], - # texts: List[str], - # embedding: Embeddings, - # metadatas: Optional[List[dict]] = None, - # collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - # distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - # ids: Optional[List[str]] = None, - # pre_delete_collection: bool = False, - # *, - # use_jsonb: bool = False, - # **kwargs: Any, - # ) -> PGVector: - # """ - # Return VectorStore initialized from texts and embeddings. - # Postgres connection string is required - # "Either pass it as a parameter - # or set the PGVECTOR_CONNECTION_STRING environment variable. - # """ - # embeddings = embedding.embed_documents(list(texts)) - # - # return await cls.__afrom( - # texts, - # embeddings, - # embedding, - # metadatas=metadatas, - # ids=ids, - # collection_name=collection_name, - # distance_strategy=distance_strategy, - # pre_delete_collection=pre_delete_collection, - # use_jsonb=use_jsonb, - # **kwargs, - # ) - - @classmethod - def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - *, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, - ) -> PGVector: - """Construct PGVector wrapper from raw documents and embeddings. - - Args: - text_embeddings: List of tuples of text and embeddings. - embedding: Embeddings object. - metadatas: Optional list of metadatas associated with the texts. - collection_name: Name of the collection. - distance_strategy: Distance strategy to use. - ids: Optional list of ids for the documents. - pre_delete_collection: If True, will delete the collection if it exists. - **Attention**: This will delete all the documents in the existing - collection. - kwargs: Additional arguments. - - Returns: - PGVector: PGVector instance. - - Example: - .. 
code-block:: python - - from langchain_postgres.vectorstores import PGVector - from langchain_openai.embeddings import OpenAIEmbeddings - - embeddings = OpenAIEmbeddings() - text_embeddings = embeddings.embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - vectorstore = PGVector.from_embeddings(text_embedding_pairs, embeddings) - """ - texts = [t[0] for t in text_embeddings] - embeddings = [t[1] for t in text_embeddings] - - return cls.__from( - texts, - embeddings, - embedding, - metadatas=metadatas, - ids=ids, - collection_name=collection_name, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - **kwargs, - ) - - @classmethod - async def afrom_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, - ) -> PGVector: - """Construct PGVector wrapper from raw documents and pre- - generated embeddings. - - Return VectorStore initialized from documents and embeddings. - Postgres connection string is required - "Either pass it as a parameter - or set the PGVECTOR_CONNECTION_STRING environment variable. - - Example: - .. code-block:: python - - from langchain_community.vectorstores import PGVector - from langchain_community.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - text_embeddings = embeddings.embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings) - """ - texts = [t[0] for t in text_embeddings] - embeddings = [t[1] for t in text_embeddings] - - return await cls.__afrom( - texts, - embeddings, - embedding, - metadatas=metadatas, - ids=ids, - collection_name=collection_name, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - **kwargs, - ) - - @classmethod - def from_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, - ) -> PGVector: - """ - Get instance of an existing PGVector store.This method will - return the instance of the store without inserting any new - embeddings - """ - store = cls( - connection=connection, - collection_name=collection_name, - embeddings=embedding, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - **kwargs, - ) - - return store - - @classmethod - def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: - connection_string: str = get_from_dict_or_env( - data=kwargs, - key="connection_string", - env_key="PGVECTOR_CONNECTION_STRING", - ) - - if not connection_string: - raise ValueError( - "Postgres connection string is required" - "Either pass it as a parameter" - "or set the PGVECTOR_CONNECTION_STRING environment variable." 
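# Example (editor's sketch): re-attaching to a collection that already exists,
# without inserting any new embeddings, via `from_existing_index`. The
# connection string and collection name are placeholders.
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector

store = PGVector.from_existing_index(
    embedding=OpenAIEmbeddings(),
    collection_name="my_docs",
    connection="postgresql+psycopg://langchain:langchain@localhost:5432/langchain",
)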
- ) - - return connection_string - - @classmethod - def from_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - *, - connection: Optional[Connection] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, - ) -> PGVector: - """Return VectorStore initialized from documents and embeddings.""" - - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - - return cls.from_texts( - texts=texts, - pre_delete_collection=pre_delete_collection, - embedding=embedding, - distance_strategy=distance_strategy, - metadatas=metadatas, - connection=connection, - ids=ids, - collection_name=collection_name, - use_jsonb=use_jsonb, - **kwargs, - ) - - @classmethod - def connection_string_from_db_params( - cls, - driver: str, - host: str, - port: int, - database: str, - user: str, - password: str, - ) -> str: - """Return connection string from database parameters.""" - if driver != "psycopg": - raise NotImplementedError("Only psycopg3 driver is supported") - return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}" - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - if self.override_relevance_score_fn is not None: - return self.override_relevance_score_fn - - # Default strategy is to rely on distance strategy provided - # in vectorstore constructor - if self._distance_strategy == DistanceStrategy.COSINE: - return self._cosine_relevance_score_fn - elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: - return self._euclidean_relevance_score_fn - elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - return self._max_inner_product_relevance_score_fn - else: - raise ValueError( - "No supported normalization function" - f" for distance_strategy of {self._distance_strategy}." - "Consider providing relevance_score_fn to PGVector constructor." - ) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance with score - to embedding vector. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Tuple[Document, float]]: List of Documents selected by maximal marginal - relevance to the query and score for each. 
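# Example (editor's sketch): building a connection string from individual
# database parameters with `connection_string_from_db_params`; only the
# `psycopg` (psycopg 3) driver is accepted. Host, port, credentials and
# database name are placeholders.
conn_str = PGVector.connection_string_from_db_params(
    driver="psycopg",
    host="localhost",
    port=5432,
    database="langchain",
    user="langchain",
    password="langchain",
)
# -> "postgresql+psycopg://langchain:langchain@localhost:5432/langchain"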
- """ - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) - - embedding_list = [result.EmbeddingStore.embedding for result in results] - - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embedding_list, - k=k, - lambda_mult=lambda_mult, - ) - - candidates = self._results_to_docs_and_scores(results) - - return [r for i, r in enumerate(candidates) if i in mmr_selected] - - # async def amax_marginal_relevance_search_with_score_by_vector( - # self, - # embedding: List[float], - # k: int = 4, - # fetch_k: int = 20, - # lambda_mult: float = 0.5, - # filter: Optional[Dict[str, str]] = None, - # **kwargs: Any, - # ) -> List[Tuple[Document, float]]: - # """Return docs selected using the maximal marginal relevance with score - # to embedding vector. - # - # Maximal marginal relevance optimizes for similarity to query AND diversity - # among selected documents. - # - # Args: - # embedding: Embedding to look up documents similar to. - # k (int): Number of Documents to return. Defaults to 4. - # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - # Defaults to 20. - # lambda_mult (float): Number between 0 and 1 that determines the degree - # of diversity among the results with 0 corresponding - # to maximum diversity and 1 to minimum diversity. - # Defaults to 0.5. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - # - # Returns: - # List[Tuple[Document, float]]: List of Documents selected by maximal marginal - # relevance to the query and score for each. - # """ - # with self._session_maker() as session: - # results = await self.__aquery_collection(session=session, - # embedding=embedding, k=fetch_k, - # filter=filter) - # - # embedding_list = [result.EmbeddingStore.embedding for result in results] - # - # mmr_selected = maximal_marginal_relevance( - # np.array(embedding, dtype=np.float32), - # embedding_list, - # k=k, - # lambda_mult=lambda_mult, - # ) - # - # candidates = self._results_to_docs_and_scores(results) - # - # return [r for i, r in enumerate(candidates) if i in mmr_selected] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Document]: List of Documents selected by maximal marginal relevance. 
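# Example (editor's sketch): maximal marginal relevance search, which first
# fetches `fetch_k` candidates by similarity and then re-ranks them for
# diversity controlled by `lambda_mult` (0 = maximum diversity, 1 = minimum),
# as described in the docstrings above. Query text and parameter values are
# placeholders; `store` is as in the earlier sketches.
docs = store.max_marginal_relevance_search(
    "what did the president say?",
    k=4,            # documents returned
    fetch_k=20,     # candidates passed to the MMR algorithm
    lambda_mult=0.5,
    filter={"topic": "a"},
)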
- """ - embedding = self.embedding_function.embed_query(query) - return self.max_marginal_relevance_search_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - - # async def amax_marginal_relevance_search( - # self, - # query: str, - # k: int = 4, - # fetch_k: int = 20, - # lambda_mult: float = 0.5, - # filter: Optional[Dict[str, str]] = None, - # **kwargs: Any, - # ) -> List[Document]: - # """Return docs selected using the maximal marginal relevance. - # - # Maximal marginal relevance optimizes for similarity to query AND diversity - # among selected documents. - # - # Args: - # query (str): Text to look up documents similar to. - # k (int): Number of Documents to return. Defaults to 4. - # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - # Defaults to 20. - # lambda_mult (float): Number between 0 and 1 that determines the degree - # of diversity among the results with 0 corresponding - # to maximum diversity and 1 to minimum diversity. - # Defaults to 0.5. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - # - # Returns: - # List[Document]: List of Documents selected by maximal marginal relevance. - # """ - # embedding = self.embedding_function.embed_query(query) - # return await self.amax_marginal_relevance_search_by_vector( - # embedding, - # k=k, - # fetch_k=fetch_k, - # lambda_mult=lambda_mult, - # filter=filter, - # **kwargs, - # ) - - def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance with score. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Tuple[Document, float]]: List of Documents selected by maximal marginal - relevance to the query and score for each. - """ - embedding = self.embedding_function.embed_query(query) - docs = self.max_marginal_relevance_search_with_score_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - return docs - - # async def amax_marginal_relevance_search_with_score( - # self, - # query: str, - # k: int = 4, - # fetch_k: int = 20, - # lambda_mult: float = 0.5, - # filter: Optional[dict] = None, - # **kwargs: Any, - # ) -> List[Tuple[Document, float]]: - # """Return docs selected using the maximal marginal relevance with score. - # - # Maximal marginal relevance optimizes for similarity to query AND diversity - # among selected documents. - # - # Args: - # query (str): Text to look up documents similar to. - # k (int): Number of Documents to return. Defaults to 4. - # fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - # Defaults to 20. 
- # lambda_mult (float): Number between 0 and 1 that determines the degree - # of diversity among the results with 0 corresponding - # to maximum diversity and 1 to minimum diversity. - # Defaults to 0.5. - # filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - # - # Returns: - # List[Tuple[Document, float]]: List of Documents selected by maximal marginal - # relevance to the query and score for each. - # """ - # embedding = self.embedding_function.embed_query(query) - # docs = await self.amax_marginal_relevance_search_with_score_by_vector( - # embedding=embedding, - # k=k, - # fetch_k=fetch_k, - # lambda_mult=lambda_mult, - # filter=filter, - # **kwargs, - # ) - # return docs - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance - to embedding vector. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Document]: List of Documents selected by maximal marginal relevance. - """ - docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - - return _results_to_docs(docs_and_scores) - - async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - - - # async def aadd_documents( # FIXME: remove - # self, documents: List[Document], **kwargs: Any - # ) -> List[str]: - # """Run more documents through the embeddings and add to the vectorstore. - # - # Args: - # documents (List[Document]: Documents to add to the vectorstore. - # - # Returns: - # List[str]: List of IDs of the added texts. 
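# Example (editor's sketch): the async MMR-by-vector variant above simply
# delegates to the sync implementation through `run_in_executor`, so it can be
# awaited directly. `store` and the query text are placeholders from the
# earlier sketches.
import asyncio

async def _demo() -> None:
    query_vector = store.embeddings.embed_query("what did the president say?")
    docs = await store.amax_marginal_relevance_search_by_vector(query_vector, k=4)
    print([d.page_content for d in docs])

asyncio.run(_demo())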
- # """ - # texts = [doc.page_content for doc in documents] - # metadatas = [doc.metadata for doc in documents] - # return await self.aadd_texts(texts, metadatas, **kwargs) - - # async def adelete_collection(self) -> None: - # async with self._amake_session() as session: - # await self._adelete_collection(session) - - # def _build_sessionmaker(self) -> sessionmaker: - # return sessionmaker(bind=self._bind) - # - # def _build_async_sessionmaker(self) -> sessionmaker: - # return async_sessionmaker( - # bind=self._bind, - # sync_session_class=self._session_maker) - # - # @contextlib.contextmanager - # def _session_maker(self) -> Generator[Session, None, None]: - # """Create a context manager for the session, bind to _conn string.""" - # session = self._session_maker() - # session.begin() # FIXME: sans ? - # yield session - # session.commit() # FIXME - # - # @contextlib.asynccontextmanager - # async def _amake_session(self) -> Generator[AsyncSession, None, None]: - # """Create a context manager for the session, bind to _conn string.""" - # async_session: AsyncSession = self._session_maker() - # await async_session.begin() - # yield async_session - # await async_session.commit() # FIXME - # diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 13b3e735..72280d05 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -5,10 +5,12 @@ import pytest from langchain_core.documents import Document -from langchain_postgres.vectorstores_ppr import ( +from langchain_postgres.vectorstores import ( SUPPORTED_OPERATORS, PGVector, ) +from sqlalchemy import select + from tests.unit_tests.fake_embeddings import FakeEmbeddings from tests.unit_tests.fixtures.filtering_test_cases import ( DOCUMENTS, @@ -38,7 +40,7 @@ def embed_query(self, text: str) -> List[float]: return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] -def test_pgvector(pgvector: PGVector) -> None: +def test_pgvector() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = PGVector.from_texts( @@ -51,6 +53,21 @@ def test_pgvector(pgvector: PGVector) -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] +# @pytest.mark.requires("xxx") # FIXME +@pytest.mark.asyncio +async def test_async_pgvector() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + def test_pgvector_embeddings() -> None: """Test end to end construction with embeddings and search.""" @@ -68,6 +85,23 @@ def test_pgvector_embeddings() -> None: assert output == [Document(page_content="foo")] +@pytest.mark.asyncio +async def test_async_pgvector_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = await PGVector.afrom_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await 
docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_pgvector_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -84,6 +118,23 @@ def test_pgvector_with_metadatas() -> None: assert output == [Document(page_content="foo", metadata={"page": "0"})] +@pytest.mark.asyncio +async def test_async_pgvector_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + def test_pgvector_with_metadatas_with_scores() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -100,6 +151,23 @@ def test_pgvector_with_metadatas_with_scores() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +@pytest.mark.asyncio +async def test_async_pgvector_with_metadatas_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score("foo", k=1) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + def test_pgvector_with_filter_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -116,6 +184,23 @@ def test_pgvector_with_filter_match() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "0"}) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + def test_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -133,6 +218,24 @@ def test_pgvector_with_filter_distant_match() -> None: (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_distant_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await 
docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "2"}) + assert output == [ + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) + ] + def test_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" @@ -149,6 +252,22 @@ def test_pgvector_with_filter_no_match() -> None: output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "5"}) + assert output == [] + def test_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" @@ -167,6 +286,26 @@ def test_pgvector_collection_with_metadata() -> None: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} +@pytest.mark.asyncio +async def test_async_pgvector_collection_with_metadata() -> None: + """Test end to end collection construction""" + pgvector = await PGVector.create( + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + async with pgvector._session_maker() as session: + collection = await pgvector.aget_collection(session) + if collection is None: + assert False, "Expected a CollectionStore object but received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} + + + def test_pgvector_delete_docs() -> None: """Add and delete documents.""" @@ -196,6 +335,35 @@ def test_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == [] # type: ignore +@pytest.mark.asyncio +async def test_async_pgvector_delete_docs() -> None: + """Add and delete documents.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + vectorstore = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + await vectorstore.adelete(["1", "2"]) + async with vectorstore._session_maker() as session: + records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.id for record in records) == ["3"] # type: ignore + + await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids + async with vectorstore._session_maker() as session: + records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.id for record in records) == [] # type: ignore + + def test_pgvector_index_documents() -> None: """Test adding duplicate documents results in overwrites.""" documents = [ @@ -485,17 +653,6 @@ def test_pgvector_with_with_metadata_filters_5( assert 
[doc.metadata["id"] for doc in docs] == expected_ids, test_filter -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_6_FILTERING_TEST_CASES) -def test_pgvector_with_with_metadata_filters_6( - pgvector: PGVector, - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - """Test end to end construction and search.""" - docs = pgvector.similarity_search("meow", k=5, filter=test_filter) - assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter - - @pytest.mark.parametrize( "invalid_filter", [ @@ -508,8 +665,6 @@ def test_pgvector_with_with_metadata_filters_6( {"$and": {}}, {"$between": {}}, {"$eq": {}}, - {"$exists": {}}, - {"$exists": 1}, ], ) def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None: @@ -524,7 +679,6 @@ def test_validate_operators() -> None: "$and", "$between", "$eq", - "$exists", "$gt", "$gte", "$ilike", From 0abd730671f35dd6ba148e7ee4ba1b2617780da2 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 15 Apr 2024 10:49:54 +0200 Subject: [PATCH 03/18] Fix lint --- langchain_postgres/vectorstores.py | 856 ++++++++++++++------------- tests/unit_tests/test_vectorstore.py | 349 ++++++++++- 2 files changed, 778 insertions(+), 427 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index d5aa6dba..ffce9ff6 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import enum import logging import uuid @@ -11,27 +10,38 @@ Iterable, List, Optional, + Sequence, Tuple, - Type, Union, + Type, + Union, ) import numpy as np import sqlalchemy -from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine -from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -# TODO: accepter l'absence de l'option async lors des imports -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from sqlalchemy.orm import Session, relationship, sessionmaker - -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base +# try: +# from sqlalchemy.orm import declarative_base +# except ImportError: +# from sqlalchemy.ext.declarative import declarative_base from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore +from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Session, + declarative_base, + relationship, + sessionmaker, +) from langchain_postgres._utils import maximal_marginal_relevance @@ -109,28 +119,26 @@ class CollectionStore(Base): @classmethod def get_by_name( - cls, session: Session, name: str + cls, session: Session, name: str ) -> Optional["CollectionStore"]: return session.query(cls).filter(cls.name == name).first() # type: ignore @classmethod async def aget_by_name( - cls, session: AsyncSession, name: str + cls, session: AsyncSession, name: str ) -> Optional["CollectionStore"]: - stmt = select(cls).filter(cls.name == name) - # return await session.execute(stmt) # FIXME - return (await session.execute(stmt)).scalars().first() # FIXME - # 
stmt = select(cls).filter(cls.name == name) - # result = await session.execute(stmt) - # x = result.scalars() - # return session.query(cls).filter(cls.name == name).first() + return ( + (await session.execute(select(CollectionStore).where(cls.name == name))) + .scalars() + .first() + ) @classmethod def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """Get or create a collection. Returns: @@ -143,16 +151,16 @@ def get_or_create( collection = cls(name=name, cmetadata=cmetadata) session.add(collection) - session.commit() # FIXME PPR semble utile + session.commit() created = True return collection, created @classmethod async def aget_or_create( - cls, - session: AsyncSession, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: AsyncSession, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """ Get or create a collection. @@ -268,23 +276,24 @@ class PGVector(VectorStore): """ def __init__( - self, - embeddings: Embeddings, - *, - connection: Union[None, Connection, str] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - async_mode: bool = False, # FIXME: a virer. Gaff aux imports sans async + self, + embeddings: Embeddings, + *, + connection: Union[None, Connection, str] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + _async_mode: bool = False, # Tag to force the async mode ) -> None: """Initialize the PGVector store. + For an async version, use `PGVector.acreate()` instead. Args: connection: Postgres connection string. @@ -311,7 +320,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. 
""" - self.async_mode = async_mode + self._async_mode = _async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -320,22 +329,29 @@ def __init__( self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn + self._engine: Optional[Engine] = None + self._async_engine: Optional[AsyncEngine] = None if isinstance(connection, str): - self._engine = self._create_engine( - connection, engine_args, async_mode) - elif isinstance(connection, sqlalchemy.engine.Engine): + if _async_mode: + self._async_engine = create_async_engine( + connection, **(engine_args or {}) + ) + else: + self._engine = create_engine(url=connection, **(engine_args or {})) + elif isinstance(connection, Engine): self._engine = connection + elif isinstance(connection, AsyncEngine): + self._async_mode = True + self._async_engine = connection else: raise ValueError( "connection should be a connection string or an instance of " - "sqlalchemy.engine.Engine" + "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) - # If the driver accept only the synchrone calls, update the async_mode - self.async_mode = not isinstance(self._engine, Engine) self._session_maker: Union[sessionmaker, async_sessionmaker] - if self.async_mode: - self._session_maker = async_sessionmaker(bind=self._engine) + if self._async_mode: + self._session_maker = async_sessionmaker(bind=self._async_engine) else: self._session_maker = sessionmaker(bind=self._engine) @@ -345,11 +361,18 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - if not async_mode: + if not _async_mode: self.__post_init__() + else: + import inspect + + assert inspect.stack()[1].function in [ + "acreate", + "__afrom", + ], "Call await PGVector.acreate() instead of PGVector(...))" def __post_init__( - self, + self, ) -> None: """Initialize the store.""" if self.create_extension: @@ -364,10 +387,9 @@ def __post_init__( self.create_collection() async def __apost_init__( - self, + self, ) -> None: - - """Initialize the store.""" + """Async initialize the store.""" if self.create_extension: await self.acreate_vector_extension() @@ -380,22 +402,49 @@ async def __apost_init__( await self.acreate_collection() @classmethod - async def create(cls, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - async_mode: bool = True, - ) -> PGVector: + async def acreate( + cls, + embeddings: Embeddings, + *, + connection: Optional[Connection] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + 
use_jsonb: bool = True, + create_extension: bool = True, + ) -> PGVector: + """Async create instance + + Args: + connection: Postgres connection string. + embeddings: Any embedding function implementing + `langchain.embeddings.base.Embeddings` interface. + embedding_length: The length of the embedding vector. (default: None) + NOTE: This is not mandatory. Defining it will prevent vectors of + any other size to be added to the embeddings table but, without it, + the embeddings can't be indexed. + collection_name: The name of the collection to use. (default: langchain) + NOTE: This is not the name of the table, but the name of the collection. + The tables will be created when initializing the store (if not exists) + So, make sure the user has the right permissions to create tables. + distance_strategy: The distance strategy to use. (default: COSINE) + pre_delete_collection: If True, will delete the collection if it exists. + (default: False). Useful for testing. + engine_args: SQLAlchemy's create engine arguments. + use_jsonb: Use JSONB instead of JSON for metadata. (default: True) + Strongly discouraged from using JSON as it's not as efficient + for querying. + It's provided here for backwards compatibility with older versions, + and will be removed in the future. + create_extension: If True, will create the vector extension if it + doesn't exist. disabling creation is useful when using ReadOnly + Databases. + """ self = cls( embeddings=embeddings, connection=connection, @@ -409,40 +458,11 @@ async def create(cls, engine_args=engine_args, use_jsonb=use_jsonb, create_extension=create_extension, - async_mode=async_mode, + _async_mode=True, ) - if async_mode: - await self.__apost_init__() + await self.__apost_init__() return self - def _create_engine(self, - connection: str, - engine_args: Optional[dict[str, Any]] = None, - async_mode: bool = False) -> sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine: - if async_mode: - from sqlalchemy.ext.asyncio import create_async_engine - # FIXME: gérer appel async sur un sync - from sqlalchemy.exc import InvalidRequestError - try: - return create_async_engine( - url=connection, - isolation_level="REPEATABLE READ", # FIXME: merge avec la suite ? 
- echo=True, # FIXME: a virer - **(engine_args or {}) - ) - except InvalidRequestError: - pass # Ignore and return the synchrone version - logging.warning("Use a synchrone SQL engine !") - return sqlalchemy.create_engine(url=connection, - **(engine_args or {})) - - def __del__(self) -> None: - if isinstance(self._engine, sqlalchemy.engine.Connection): - if self.async_mode: - asyncio.run(self._engine.close()) - else: - self._engine.close() - @property def embeddings(self) -> Embeddings: return self.embedding_function @@ -475,10 +495,11 @@ async def acreate_vector_extension(self) -> None: # For more information see: # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS await session.execute( - sqlalchemy.text( - "SELECT pg_advisory_xact_lock(1573678846307946496)")) + sqlalchemy.text("SELECT pg_advisory_xact_lock(1573678846307946496)") + ) await session.execute( - sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")) + sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") + ) except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e @@ -487,20 +508,17 @@ def create_tables_if_not_exists(self) -> None: Base.metadata.create_all(session.get_bind()) async def acreate_tables_if_not_exists(self) -> None: - if isinstance(self._engine, sqlalchemy.ext.asyncio.engine.AsyncConnection): - await self._engine.run_sync(Base.metadata.create_all) - else: - async with self._engine.begin() as conn: # FIXME: session.run_sync existe - await conn.run_sync(Base.metadata.create_all) - # async with self._amake_session() as session: - # await session.run_sync(Base.metadata.create_all) + assert self._async_engine, "Use with async mode" + async with self._async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: with self._session_maker() as session: Base.metadata.drop_all(session.get_bind()) async def adrop_tables(self) -> None: - async with self._engine.begin() as conn: + assert self._async_engine, "Use with async mode" + async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: @@ -512,6 +530,7 @@ def create_collection(self) -> None: ) async def acreate_collection(self) -> None: + assert self._async_engine, "Use with async mode" async with self._session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) @@ -527,7 +546,6 @@ def _delete_collection(self, session: Session) -> None: return session.delete(collection) - # FIXME: necessaire le _adelete ? 
async def _adelete_collection(self, session: AsyncSession) -> None: self.logger.debug("Trying to delete collection") collection = await self.aget_collection(session) @@ -536,9 +554,6 @@ async def _adelete_collection(self, session: AsyncSession) -> None: return await session.delete(collection) - # def delete_collection(self) -> None: - # with self._session_maker() as session: - # self._delete_collection(session) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") with self._session_maker() as session: # type: ignore[arg-type] @@ -551,7 +566,7 @@ def delete_collection(self) -> None: async def adelete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] + async with self._session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -560,10 +575,10 @@ async def adelete_collection(self) -> None: await session.commit() def delete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: """Delete vectors by ids or uuids. @@ -595,12 +610,12 @@ def delete( session.commit() async def adelete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: - """Delete vectors by ids or uuids. + """Async delete vectors by ids or uuids. Args: ids: List of ids to delete. @@ -637,19 +652,19 @@ async def aget_collection(self, session: AsyncSession) -> Any: @classmethod def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid4()) for _ in texts] @@ -675,19 +690,19 @@ def __from( @classmethod async def __afrom( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool 
= True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -702,7 +717,7 @@ async def __afrom( distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, use_jsonb=use_jsonb, - async_mode=True, # FIXME + _async_mode=True, **kwargs, ) # Second phase to create @@ -715,12 +730,12 @@ async def __afrom( return store def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Add embeddings to the vectorstore. @@ -728,6 +743,8 @@ def add_embeddings( texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ if ids is None: @@ -768,19 +785,21 @@ def add_embeddings( return ids async def aadd_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: - """Add embeddings to the vectorstore. + """Async add embeddings to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ if ids is None: @@ -821,17 +840,19 @@ async def aadd_embeddings( return ids def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters Returns: @@ -843,17 +864,19 @@ def add_texts( ) async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. 
kwargs: vectorstore specific parameters Returns: @@ -865,11 +888,11 @@ async def aadd_texts( ) def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -889,11 +912,11 @@ def similarity_search( ) async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -913,10 +936,10 @@ async def asimilarity_search( ) def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -935,10 +958,10 @@ def similarity_search_with_score( return docs async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -971,26 +994,25 @@ def distance_strategy(self) -> Any: ) def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) async def asimilarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: async with self._session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( - session=session, - embedding=embedding, k=k, - filter=filter) + session=session, embedding=embedding, k=k, filter=filter + ) return self._results_to_docs_and_scores(results) @@ -1009,9 +1031,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa return docs def _handle_field_filter( - self, - field: str, - value: Any, + self, + field: str, + value: Any, ) -> SQLColumnExpression: """Create a filter for a specific field. @@ -1121,8 +1143,7 @@ def _handle_field_filter( else: raise NotImplementedError() - def _create_filter_clause_deprecated(self, key, - value): # type: ignore[no-untyped-def] + def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] """Deprecated functionality. This is for backwards compatibility with the JSON based schema for metadata. @@ -1191,7 +1212,7 @@ def _create_filter_clause_deprecated(self, key, return filter_by_metadata def _create_filter_clause_json_deprecated( - self, filter: Any + self, filter: Any ) -> List[SQLColumnExpression]: """Convert filters from IR to SQL clauses. 
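For reference, the dict-based filter IR that `_create_filter_clause` (and the deprecated JSON variant above) translate into SQL uses the `$`-prefixed operators enumerated in `SUPPORTED_OPERATORS`. A minimal sketch of how callers pass it, not part of this patch: `embeddings` (any `langchain_core` `Embeddings` implementation) and `conn` (a `postgresql+psycopg://...` connection string) are assumed to exist.

```python
from langchain_postgres.vectorstores import PGVector

# Assumed: `embeddings` is an Embeddings implementation and `conn` is a
# psycopg connection string, e.g. "postgresql+psycopg://user:pass@localhost:5432/db".
store = PGVector.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=embeddings,
    collection_name="filter_demo",
    metadatas=[{"page": "0"}, {"page": "1"}, {"page": "2"}],
    connection=conn,
)

# A bare value is shorthand for equality: {"page": "0"} == {"page": {"$eq": "0"}}.
store.similarity_search("foo", k=2, filter={"page": "0"})

# Explicit operators and logical composition use the keys in SUPPORTED_OPERATORS.
store.similarity_search(
    "foo",
    k=2,
    filter={"$and": [{"page": {"$in": ["0", "1"]}}, {"page": {"$ne": "2"}}]},
)
```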
@@ -1302,11 +1323,11 @@ def _create_filter_clause(self, filters: Any) -> Any: ) def __query_collection( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> Sequence[Any]: """Query the collection.""" with self._session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) @@ -1344,52 +1365,55 @@ def __query_collection( return results async def __aquery_collection( - self, - session: AsyncSession, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: + self, + session: AsyncSession, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> Sequence[Any]: """Query the collection.""" - collection = await self.aget_collection(session) - if not collection: - raise ValueError("Collection not found") - - filter_by = [self.EmbeddingStore.collection_id == collection.uuid] - if filter: - if self.use_jsonb: - filter_clauses = self._create_filter_clause(filter) - if filter_clauses is not None: - filter_by.append(filter_clauses) - else: - # Old way of doing things - filter_clauses = self._create_filter_clause_json_deprecated(filter) - filter_by.extend(filter_clauses) + async with self._session_maker() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") - _type = self.EmbeddingStore + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) - stmt = (select( - self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore - ) + _type = self.EmbeddingStore + + stmt = ( + select( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) .join( - self.CollectionStore, - self.EmbeddingStore.collection_id == self.CollectionStore.uuid, - ) - .limit(k)) + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + ) - results: List[Any] = (await session.execute(stmt)).all() + results: Sequence[Any] = (await session.execute(stmt)).all() - return results + return results def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -1407,11 +1431,11 @@ def similarity_search_by_vector( return _results_to_docs(docs_and_scores) async def asimilarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. 
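The `*_by_vector` variants reformatted here skip the query-embedding step entirely and accept a pre-computed vector. A short sketch, again assuming an existing `store` and `embeddings` object outside this patch:

```python
# Embed the query once, then reuse the vector for several lookups.
query_vector = embeddings.embed_query("foo")

docs = store.similarity_search_by_vector(query_vector, k=4)
docs_filtered = store.similarity_search_by_vector(
    query_vector, k=4, filter={"page": "0"}
)

# The async twin introduced in this series mirrors the signature:
# docs = await store.asimilarity_search_by_vector(query_vector, k=4)
```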
@@ -1430,17 +1454,17 @@ async def asimilarity_search_by_vector( @classmethod def from_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1460,17 +1484,17 @@ def from_texts( @classmethod async def afrom_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1489,16 +1513,16 @@ async def afrom_texts( @classmethod def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - *, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + *, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and embeddings. @@ -1509,6 +1533,7 @@ def from_embeddings( collection_name: Name of the collection. distance_strategy: Distance strategy to use. ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. pre_delete_collection: If True, will delete the collection if it exists. **Attention**: This will delete all the documents in the existing collection. 
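`from_embeddings` (and `afrom_embeddings` in the next hunk) take pre-computed `(text, vector)` pairs instead of raw texts, which is useful when embeddings are generated in a separate batch job. A hedged sketch of the calling convention, with `embeddings` and `conn` assumed as before:

```python
texts = ["foo", "bar", "baz"]
vectors = embeddings.embed_documents(texts)

store = PGVector.from_embeddings(
    text_embeddings=list(zip(texts, vectors)),
    embedding=embeddings,  # still used to embed queries at search time
    metadatas=[{"page": str(i)} for i in range(len(texts))],
    collection_name="precomputed_demo",
    connection=conn,
    pre_delete_collection=True,
)
```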
@@ -1545,15 +1570,15 @@ def from_embeddings( @classmethod async def afrom_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and pre- generated embeddings. @@ -1590,14 +1615,14 @@ async def afrom_embeddings( @classmethod def from_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will @@ -1617,27 +1642,26 @@ def from_existing_index( @classmethod async def afrom_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will return the instance of the store without inserting any new embeddings """ - store = cls( # FIXME: créate + store = await PGVector.acreate( connection=connection, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, - async_mode=True, **kwargs, ) @@ -1662,17 +1686,17 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: @classmethod def from_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - *, - connection: Optional[Connection] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + *, + connection: Optional[Connection] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return 
VectorStore initialized from documents and embeddings.""" @@ -1694,16 +1718,16 @@ def from_documents( @classmethod async def afrom_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """ Return VectorStore initialized from documents and embeddings. @@ -1732,13 +1756,13 @@ async def afrom_documents( @classmethod def connection_string_from_db_params( - cls, - driver: str, - host: str, - port: int, - database: str, - user: str, - password: str, + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, ) -> str: """Return connection string from database parameters.""" if driver != "psycopg": @@ -1773,13 +1797,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: ) def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1818,13 +1842,13 @@ def max_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1848,9 +1872,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ async with self._session_maker() as session: - results = await self.__aquery_collection(session=session, - embedding=embedding, k=fetch_k, - filter=filter) + results = await self.__aquery_collection( + session=session, embedding=embedding, k=fetch_k, filter=filter + ) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1866,13 +1890,13 @@ async def amax_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. 
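The MMR entry points touched below trade relevance against diversity: `fetch_k` candidates are pulled from Postgres, then `k` of them are kept by maximal marginal relevance. A small usage sketch, with `store` assumed to be an already-initialized `PGVector` instance:

```python
# lambda_mult=1.0 means pure relevance, 0.0 means maximum diversity.
docs = store.max_marginal_relevance_search(
    "query text", k=4, fetch_k=20, lambda_mult=0.5, filter={"page": "0"}
)

# Async variant added in this patch series:
# docs = await store.amax_marginal_relevance_search("query text", k=4, fetch_k=20)
```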
@@ -1904,13 +1928,13 @@ def max_marginal_relevance_search( ) async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1942,13 +1966,13 @@ async def amax_marginal_relevance_search( ) def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -1982,13 +2006,13 @@ def max_marginal_relevance_search_with_score( return docs async def amax_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -2022,13 +2046,13 @@ async def amax_marginal_relevance_search_with_score( return docs def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2062,13 +2086,13 @@ def max_marginal_relevance_search_by_vector( return _results_to_docs(docs_and_scores) async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2090,13 +2114,15 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. 
""" - docs_and_scores = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, + docs_and_scores = ( + await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) ) return _results_to_docs(docs_and_scores) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 72280d05..e0f05e18 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1,16 +1,15 @@ """Test PGVector functionality.""" import contextlib -from typing import Any, Dict, Generator, List +from typing import Any, AsyncGenerator, Dict, Generator, List import pytest from langchain_core.documents import Document +from sqlalchemy import select from langchain_postgres.vectorstores import ( SUPPORTED_OPERATORS, PGVector, ) -from sqlalchemy import select - from tests.unit_tests.fake_embeddings import FakeEmbeddings from tests.unit_tests.fixtures.filtering_test_cases import ( DOCUMENTS, @@ -53,7 +52,7 @@ def test_pgvector() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] -# @pytest.mark.requires("xxx") # FIXME + @pytest.mark.asyncio async def test_async_pgvector() -> None: """Test end to end construction and search.""" @@ -197,7 +196,9 @@ async def test_async_pgvector_with_filter_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "0"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "0"} + ) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] @@ -218,6 +219,7 @@ def test_pgvector_with_filter_distant_match() -> None: (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] + @pytest.mark.asyncio async def test_async_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" @@ -231,7 +233,9 @@ async def test_async_pgvector_with_filter_distant_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "2"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "2"} + ) assert output == [ (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] @@ -252,6 +256,7 @@ def test_pgvector_with_filter_no_match() -> None: output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] + @pytest.mark.asyncio async def test_async_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" @@ -265,7 +270,9 @@ async def test_async_pgvector_with_filter_no_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "5"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "5"} + ) assert output == [] @@ -286,10 +293,11 @@ def test_pgvector_collection_with_metadata() -> None: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} + @pytest.mark.asyncio async def test_async_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" - pgvector = await PGVector.create( 
+ pgvector = await PGVector.acreate( collection_name="test_collection", collection_metadata={"foo": "bar"}, embeddings=FakeEmbeddingsWithAdaDimension(), @@ -305,8 +313,6 @@ async def test_async_pgvector_collection_with_metadata() -> None: assert collection.cmetadata == {"foo": "bar"} - - def test_pgvector_delete_docs() -> None: """Add and delete documents.""" texts = ["foo", "bar", "baz"] @@ -335,6 +341,22 @@ def test_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == [] # type: ignore +def test_pgvector_delete_collection() -> None: + """Add and delete documents.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + vectorstore = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + vectorstore.delete(collection_only=True) + + @pytest.mark.asyncio async def test_async_pgvector_delete_docs() -> None: """Add and delete documents.""" @@ -351,14 +373,18 @@ async def test_async_pgvector_delete_docs() -> None: ) await vectorstore.adelete(["1", "2"]) async with vectorstore._session_maker() as session: - records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids async with vectorstore._session_maker() as session: - records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [] # type: ignore @@ -439,6 +465,88 @@ def test_pgvector_index_documents() -> None: } +@pytest.mark.asyncio +async def test_async_pgvector_index_documents() -> None: + """Test adding duplicate documents results in overwrites.""" + documents = [ + Document( + page_content="there are cats in the pond", + metadata={"id": 1, "location": "pond", "topic": "animals"}, + ), + Document( + page_content="ducks are also found in the pond", + metadata={"id": 2, "location": "pond", "topic": "animals"}, + ), + Document( + page_content="fresh apples are available at the market", + metadata={"id": 3, "location": "market", "topic": "food"}, + ), + Document( + page_content="the market also sells fresh oranges", + metadata={"id": 4, "location": "market", "topic": "food"}, + ), + Document( + page_content="the new art exhibit is fascinating", + metadata={"id": 5, "location": "museum", "topic": "art"}, + ), + ] + + vectorstore = await PGVector.afrom_documents( + documents=documents, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + ids=[doc.metadata["id"] for doc in documents], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + async with vectorstore._session_maker() as session: + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.id for record in records) == [ + "1", + "2", + "3", + "4", + "5", + ] + + # 
Try to overwrite the first document + documents = [ + Document( + page_content="new content in the zoo", + metadata={"id": 1, "location": "zoo", "topic": "zoo"}, + ), + ] + + await vectorstore.aadd_documents( + documents, ids=[doc.metadata["id"] for doc in documents] + ) + + async with vectorstore._session_maker() as session: + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) + ordered_records = sorted(records, key=lambda x: x.id) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert [record.id for record in ordered_records] == [ + "1", + "2", + "3", + "4", + "5", + ] + + assert ordered_records[0].cmetadata == { + "id": 1, + "location": "zoo", + "topic": "zoo", + } + + def test_pgvector_relevance_score() -> None: """Test to make sure the relevance score is scaled to 0-1.""" texts = ["foo", "bar", "baz"] @@ -460,6 +568,28 @@ def test_pgvector_relevance_score() -> None: ] +@pytest.mark.asyncio +async def test_async_pgvector_relevance_score() -> None: + """Test to make sure the relevance score is scaled to 0-1.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + + output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065), + (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621), + ] + + def test_pgvector_retriever_search_threshold() -> None: """Test using retriever for searching with threshold.""" texts = ["foo", "bar", "baz"] @@ -484,6 +614,31 @@ def test_pgvector_retriever_search_threshold() -> None: ] +@pytest.mark.asyncio +async def test_async_pgvector_retriever_search_threshold() -> None: + """Test using retriever for searching with threshold.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 3, "score_threshold": 0.999}, + ) + output = await retriever.aget_relevant_documents("summer") + assert output == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + ] + + def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: """Test searching with threshold and custom normalization function""" texts = ["foo", "bar", "baz"] @@ -506,6 +661,31 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: assert output == [] +@pytest.mark.asyncio +async def test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> ( + None +): + """Test searching with threshold and custom normalization function""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + 
connection=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 3, "score_threshold": 0.5}, + ) + output = await retriever.aget_relevant_documents("foo") + assert output == [] + + def test_pgvector_max_marginal_relevance_search() -> None: """Test max marginal relevance search.""" texts = ["foo", "bar", "baz"] @@ -520,6 +700,21 @@ def test_pgvector_max_marginal_relevance_search() -> None: assert output == [Document(page_content="foo")] +@pytest.mark.asyncio +async def test_async_pgvector_max_marginal_relevance_search() -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.amax_marginal_relevance_search("foo", k=1, fetch_k=3) + assert output == [Document(page_content="foo")] + + def test_pgvector_max_marginal_relevance_search_with_score() -> None: """Test max marginal relevance search with relevance scores.""" texts = ["foo", "bar", "baz"] @@ -534,6 +729,23 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None: assert output == [(Document(page_content="foo"), 0.0)] +@pytest.mark.asyncio +async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None: + """Test max marginal relevance search with relevance scores.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.amax_marginal_relevance_search_with_score( + "foo", k=1, fetch_k=3 + ) + assert output == [(Document(page_content="foo"), 0.0)] + + def test_pgvector_with_custom_connection() -> None: """Test construction using a custom connection.""" texts = ["foo", "bar", "baz"] @@ -548,6 +760,21 @@ def test_pgvector_with_custom_connection() -> None: assert output == [Document(page_content="foo")] +@pytest.mark.asyncio +async def test_async_pgvector_with_custom_connection() -> None: + """Test construction using a custom connection.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_pgvector_with_custom_engine_args() -> None: """Test construction using custom engine arguments.""" texts = ["foo", "bar", "baz"] @@ -580,6 +807,26 @@ def pgvector() -> Generator[PGVector, None, None]: yield vector_store +@pytest.mark.asyncio +@pytest.fixture +async def async_pgvector() -> AsyncGenerator[PGVector, None]: + """Create an async PGVector instance.""" + store = await PGVector.afrom_documents( + documents=DOCUMENTS, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + use_jsonb=True, + ) + try: + yield store + # Do clean up + finally: + await store.adrop_tables() + + @contextlib.contextmanager def get_vectorstore() -> Generator[PGVector, None, None]: """Get a 
pre-populated-vectorstore""" @@ -598,6 +845,24 @@ def get_vectorstore() -> Generator[PGVector, None, None]: store.drop_tables() +@contextlib.asynccontextmanager +async def aget_vectorstore() -> AsyncGenerator[PGVector, None]: + """Get a pre-populated-vectorstore""" + store = await PGVector.afrom_documents( + documents=DOCUMENTS, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + use_jsonb=True, + ) + try: + yield store + finally: + await store.adrop_tables() + + @pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) def test_pgvector_with_with_metadata_filters_1( test_filter: Dict[str, Any], @@ -609,6 +874,18 @@ def test_pgvector_with_with_metadata_filters_1( assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter +@pytest.mark.asyncio +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) +async def test_async_pgvector_with_with_metadata_filters_1( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + """Test end to end construction and search.""" + async with aget_vectorstore() as pgvector: + docs = await pgvector.asimilarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES) def test_pgvector_with_with_metadata_filters_2( pgvector: PGVector, @@ -620,6 +897,18 @@ def test_pgvector_with_with_metadata_filters_2( assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter +@pytest.mark.asyncio +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES) +async def test_async_pgvector_with_with_metadata_filters_2( + async_pgvector: PGVector, + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + """Test end to end construction and search.""" + docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES) def test_pgvector_with_with_metadata_filters_3( pgvector: PGVector, @@ -631,6 +920,18 @@ def test_pgvector_with_with_metadata_filters_3( assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter +@pytest.mark.asyncio +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES) +async def test_async_pgvector_with_with_metadata_filters_3( + async_pgvector: PGVector, + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + """Test end to end construction and search.""" + docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) def test_pgvector_with_with_metadata_filters_4( pgvector: PGVector, @@ -642,6 +943,18 @@ def test_pgvector_with_with_metadata_filters_4( assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter +@pytest.mark.asyncio +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) +async def test_async_pgvector_with_with_metadata_filters_4( + async_pgvector: PGVector, + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + """Test end to end construction and search.""" + docs = await 
async_pgvector.asimilarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES) def test_pgvector_with_with_metadata_filters_5( pgvector: PGVector, @@ -653,6 +966,18 @@ def test_pgvector_with_with_metadata_filters_5( assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter +@pytest.mark.asyncio +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES) +async def test_async_pgvector_with_with_metadata_filters_5( + async_pgvector: PGVector, + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + """Test end to end construction and search.""" + docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize( "invalid_filter", [ From e1ad8c824283c83379d41b6d766a464524e4abd0 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Tue, 16 Apr 2024 17:19:36 +0200 Subject: [PATCH 04/18] It's possible to share the session_maker with sync mode. --- langchain_postgres/vectorstores.py | 187 ++++++++++++--------------- tests/unit_tests/test_vectorstore.py | 23 ++-- 2 files changed, 97 insertions(+), 113 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index ffce9ff6..c32739b8 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -29,6 +29,7 @@ from langchain_core.vectorstores import VectorStore from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.orm import scoped_session from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -279,7 +280,7 @@ def __init__( self, embeddings: Embeddings, *, - connection: Union[None, Connection, str] = None, + connection: Union[None, Connection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, @@ -290,7 +291,7 @@ def __init__( engine_args: Optional[dict[str, Any]] = None, use_jsonb: bool = True, create_extension: bool = True, - _async_mode: bool = False, # Tag to force the async mode + async_mode: bool = False, ) -> None: """Initialize the PGVector store. For an async version, use `PGVector.acreate()` instead. @@ -320,7 +321,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. 
""" - self._async_mode = _async_mode + self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -331,9 +332,10 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self._engine: Optional[Engine] = None self._async_engine: Optional[AsyncEngine] = None + self._async_init=False if isinstance(connection, str): - if _async_mode: + if async_mode: self._async_engine = create_async_engine( connection, **(engine_args or {}) ) @@ -342,7 +344,7 @@ def __init__( elif isinstance(connection, Engine): self._engine = connection elif isinstance(connection, AsyncEngine): - self._async_mode = True + self.async_mode = True self._async_engine = connection else: raise ValueError( @@ -350,10 +352,10 @@ def __init__( "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) self._session_maker: Union[sessionmaker, async_sessionmaker] - if self._async_mode: - self._session_maker = async_sessionmaker(bind=self._async_engine) + if self.async_mode: + self.session_maker = async_sessionmaker(bind=self._async_engine) else: - self._session_maker = sessionmaker(bind=self._engine) + self.session_maker = scoped_session(sessionmaker(bind=self._engine)) self.use_jsonb = use_jsonb self.create_extension = create_extension @@ -361,15 +363,8 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - if not _async_mode: + if not async_mode: self.__post_init__() - else: - import inspect - - assert inspect.stack()[1].function in [ - "acreate", - "__afrom", - ], "Call await PGVector.acreate() instead of PGVector(...))" def __post_init__( self, @@ -390,6 +385,10 @@ async def __apost_init__( self, ) -> None: """Async initialize the store.""" + if self._async_init: + return + self._async_init = True + if self.create_extension: await self.acreate_vector_extension() @@ -401,75 +400,14 @@ async def __apost_init__( await self.acreate_tables_if_not_exists() await self.acreate_collection() - @classmethod - async def acreate( - cls, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - ) -> PGVector: - """Async create instance - - Args: - connection: Postgres connection string. - embeddings: Any embedding function implementing - `langchain.embeddings.base.Embeddings` interface. - embedding_length: The length of the embedding vector. (default: None) - NOTE: This is not mandatory. Defining it will prevent vectors of - any other size to be added to the embeddings table but, without it, - the embeddings can't be indexed. - collection_name: The name of the collection to use. (default: langchain) - NOTE: This is not the name of the table, but the name of the collection. - The tables will be created when initializing the store (if not exists) - So, make sure the user has the right permissions to create tables. - distance_strategy: The distance strategy to use. (default: COSINE) - pre_delete_collection: If True, will delete the collection if it exists. 
- (default: False). Useful for testing. - engine_args: SQLAlchemy's create engine arguments. - use_jsonb: Use JSONB instead of JSON for metadata. (default: True) - Strongly discouraged from using JSON as it's not as efficient - for querying. - It's provided here for backwards compatibility with older versions, - and will be removed in the future. - create_extension: If True, will create the vector extension if it - doesn't exist. disabling creation is useful when using ReadOnly - Databases. - """ - self = cls( - embeddings=embeddings, - connection=connection, - embedding_length=embedding_length, - collection_name=collection_name, - collection_metadata=collection_metadata, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - logger=logger, - relevance_score_fn=relevance_score_fn, - engine_args=engine_args, - use_jsonb=use_jsonb, - create_extension=create_extension, - _async_mode=True, - ) - await self.__apost_init__() - return self - @property def embeddings(self) -> Embeddings: return self.embedding_function def create_vector_extension(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" try: - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] # The advisor lock fixes issue arising from concurrent # creation of the vector extension. # https://github.com/langchain-ai/langchain/issues/12933 @@ -487,8 +425,11 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: + assert self.async_mode,"This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + try: - async with self._session_maker() as session: + async with self.session_maker() as session: # The advisor lock fixes issue arising from concurrent # creation of the vector extension. 
# https://github.com/langchain-ai/langchain/issues/12933 @@ -504,34 +445,40 @@ async def acreate_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: Base.metadata.create_all(session.get_bind()) async def acreate_tables_if_not_exists(self) -> None: - assert self._async_engine, "Use with async mode" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: Base.metadata.drop_all(session.get_bind()) async def adrop_tables(self) -> None: - assert self._async_engine, "Use with async mode" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" if self.pre_delete_collection: self.delete_collection() - with self._session_maker() as session: + with self.session_maker() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) async def acreate_collection(self) -> None: - assert self._async_engine, "Use with async mode" - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) await self.CollectionStore.aget_or_create( @@ -555,8 +502,9 @@ async def _adelete_collection(self, session: AsyncSession) -> None: await session.delete(collection) def delete_collection(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -565,8 +513,9 @@ def delete_collection(self) -> None: session.commit() async def adelete_collection(self) -> None: - self.logger.debug("Trying to delete collection") - async with self._session_maker() as session: # type: ignore[arg-type] + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -586,7 +535,8 @@ def delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. 
""" - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -621,7 +571,9 @@ async def adelete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -645,9 +597,12 @@ async def adelete( await session.commit() def get_collection(self, session: Session) -> Any: + assert not self._async_engine, "This method must be called without async_mode" return self.CollectionStore.get_by_name(session, self.collection_name) async def aget_collection(self, session: AsyncSession) -> Any: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init return await self.CollectionStore.aget_by_name(session, self.collection_name) @classmethod @@ -717,11 +672,9 @@ async def __afrom( distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, use_jsonb=use_jsonb, - _async_mode=True, + async_mode=True, **kwargs, ) - # Second phase to create - await store.__apost_init__() await store.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -747,13 +700,14 @@ def add_embeddings( If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ + assert not self._async_engine, "This method must be called without async_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -802,13 +756,15 @@ async def aadd_embeddings( If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - async with self._session_maker() as session: # type: ignore[arg-type] + async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -858,6 +814,7 @@ def add_texts( Returns: List of ids from adding the texts into the vectorstore. """ + assert not self._async_engine, "This method must be called without async_mode" embeddings = self.embedding_function.embed_documents(list(texts)) return self.add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -882,6 +839,8 @@ async def aadd_texts( Returns: List of ids from adding the texts into the vectorstore. 
""" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -904,6 +863,7 @@ def similarity_search( Returns: List of Documents most similar to the query. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(text=query) return self.similarity_search_by_vector( embedding=embedding, @@ -928,6 +888,8 @@ async def asimilarity_search( Returns: List of Documents most similar to the query. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( embedding=embedding, @@ -951,6 +913,7 @@ def similarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -973,6 +936,8 @@ async def asimilarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -999,6 +964,7 @@ def similarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -1009,7 +975,9 @@ async def asimilarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - async with self._session_maker() as session: # type: ignore[arg-type] + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter ) @@ -1329,7 +1297,7 @@ def __query_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -1372,7 +1340,7 @@ async def __aquery_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - async with self._session_maker() as session: # type: ignore[arg-type] + async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -1425,6 +1393,7 @@ def similarity_search_by_vector( Returns: List of Documents most similar to the query vector. 
""" + assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1447,6 +1416,8 @@ async def asimilarity_search_by_vector( Returns: List of Documents most similar to the query vector. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1656,7 +1627,7 @@ async def afrom_existing_index( return the instance of the store without inserting any new embeddings """ - store = await PGVector.acreate( + store = await PGVector( connection=connection, collection_name=collection_name, embeddings=embedding, @@ -1826,6 +1797,7 @@ def max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1871,7 +1843,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter ) @@ -1917,6 +1891,7 @@ def max_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, @@ -1955,6 +1930,8 @@ async def amax_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( embedding, @@ -1994,6 +1971,7 @@ def max_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2034,6 +2012,8 @@ async def amax_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
""" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2074,6 +2054,7 @@ def max_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, @@ -2114,6 +2095,8 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( embedding, diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index e0f05e18..3b959710 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -285,7 +285,7 @@ def test_pgvector_collection_with_metadata() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with pgvector._session_maker() as session: + with pgvector.session_maker() as session: collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -297,14 +297,15 @@ def test_pgvector_collection_with_metadata() -> None: @pytest.mark.asyncio async def test_async_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" - pgvector = await PGVector.acreate( + pgvector = PGVector( collection_name="test_collection", collection_metadata={"foo": "bar"}, embeddings=FakeEmbeddingsWithAdaDimension(), connection=CONNECTION_STRING, pre_delete_collection=True, + async_mode=True, ) - async with pgvector._session_maker() as session: + async with pgvector.session_maker() as session: collection = await pgvector.aget_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -327,14 +328,14 @@ def test_pgvector_delete_docs() -> None: pre_delete_collection=True, ) vectorstore.delete(["1", "2"]) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore vectorstore.delete(["2", "3"]) # Should not raise on missing ids - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -372,7 +373,7 @@ async def test_async_pgvector_delete_docs() -> None: pre_delete_collection=True, ) await vectorstore.adelete(["1", "2"]) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -381,7 +382,7 @@ async def test_async_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == ["3"] # type: ignore await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids - async 
with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -423,7 +424,7 @@ def test_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -445,7 +446,7 @@ def test_pgvector_index_documents() -> None: vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents]) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) ordered_records = sorted(records, key=lambda x: x.id) # ignoring type error since mypy cannot determine whether @@ -499,7 +500,7 @@ async def test_async_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -525,7 +526,7 @@ async def test_async_pgvector_index_documents() -> None: documents, ids=[doc.metadata["id"] for doc in documents] ) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) From dc6187a905cef3ce5218945ecae3595f7d23951e Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Wed, 17 Apr 2024 17:48:47 +0200 Subject: [PATCH 05/18] Fix __apost_init__ --- langchain_postgres/vectorstores.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index c32739b8..3e2311d9 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import enum import logging +import threading import uuid from typing import ( Any, @@ -426,7 +428,6 @@ def create_vector_extension(self) -> None: async def acreate_vector_extension(self) -> None: assert self.async_mode,"This method must be called with async_mode" - await self.__apost_init__() # Lazy async init try: async with self.session_maker() as session: @@ -937,7 +938,7 @@ async def asimilarity_search_with_score( List of Documents most similar to the query and score for each. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter From 14d6da818ca5fc50398057e93ee65e933627fac3 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 22 Apr 2024 12:48:18 +0200 Subject: [PATCH 06/18] Add async API --- langchain_postgres/vectorstores.py | 77 +++++++++++++++++------------- pyproject.toml | 9 ++++ 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 3e2311d9..4186a82e 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,9 +1,8 @@ +# pylint: disable=too-many-lines from __future__ import annotations -import asyncio import enum import logging -import threading import uuid from typing import ( Any, @@ -17,21 +16,18 @@ Type, Union, ) +from typing import ( + cast as typing_cast, +) import numpy as np import sqlalchemy - -# try: -# from sqlalchemy.orm import declarative_base -# except ImportError: -# from sqlalchemy.ext.declarative import declarative_base from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.orm import scoped_session from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -43,6 +39,7 @@ Session, declarative_base, relationship, + scoped_session, sessionmaker, ) @@ -124,14 +121,24 @@ class CollectionStore(Base): def get_by_name( cls, session: Session, name: str ) -> Optional["CollectionStore"]: - return session.query(cls).filter(cls.name == name).first() # type: ignore + return ( + session.query(cls) + .filter(typing_cast(sqlalchemy.Column, cls.name) == name) + .first() + ) @classmethod async def aget_by_name( cls, session: AsyncSession, name: str ) -> Optional["CollectionStore"]: return ( - (await session.execute(select(CollectionStore).where(cls.name == name))) + ( + await session.execute( + select(CollectionStore).where( + typing_cast(sqlalchemy.Column, cls.name) == name + ) + ) + ) .scalars() .first() ) @@ -259,6 +266,7 @@ class PGVector(VectorStore): connection=connection_string, collection_name=collection_name, use_jsonb=True, + async_mode=False, ) @@ -276,6 +284,9 @@ class PGVector(VectorStore): You will need to recreate the tables if you are using an existing database. * A Connection object has to be provided explicitly. Connections will not be picked up automatically based on env variables. + * langchain_postgres now accept async connections. If you want to use the async + version, you need to set `async_mode=True` when initializing the store or + use an async engine. """ def __init__( @@ -299,7 +310,7 @@ def __init__( For an async version, use `PGVector.acreate()` instead. Args: - connection: Postgres connection string. + connection: Postgres connection string or (async)engine. embeddings: Any embedding function implementing `langchain.embeddings.base.Embeddings` interface. embedding_length: The length of the embedding vector. 
(default: None) @@ -334,7 +345,7 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self._engine: Optional[Engine] = None self._async_engine: Optional[AsyncEngine] = None - self._async_init=False + self._async_init = False if isinstance(connection, str): if async_mode: @@ -344,6 +355,7 @@ def __init__( else: self._engine = create_engine(url=connection, **(engine_args or {})) elif isinstance(connection, Engine): + self.async_mode = False self._engine = connection elif isinstance(connection, AsyncEngine): self.async_mode = True @@ -353,7 +365,7 @@ def __init__( "connection should be a connection string or an instance of " "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) - self._session_maker: Union[sessionmaker, async_sessionmaker] + self.session_maker: Union[scoped_session, async_sessionmaker] if self.async_mode: self.session_maker = async_sessionmaker(bind=self._async_engine) else: @@ -365,7 +377,7 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - if not async_mode: + if not self.async_mode: self.__post_init__() def __post_init__( @@ -386,7 +398,7 @@ def __post_init__( async def __apost_init__( self, ) -> None: - """Async initialize the store.""" + """Async initialize the store (use lazy approach).""" if self._async_init: return self._async_init = True @@ -427,7 +439,7 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: - assert self.async_mode,"This method must be called with async_mode" + assert self.async_mode, "This method must be called with async_mode" try: async with self.session_maker() as session: @@ -452,7 +464,7 @@ def create_tables_if_not_exists(self) -> None: async def acreate_tables_if_not_exists(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -463,7 +475,7 @@ def drop_tables(self) -> None: async def adrop_tables(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @@ -478,7 +490,7 @@ def create_collection(self) -> None: async def acreate_collection(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) @@ -515,7 +527,7 @@ def delete_collection(self) -> None: async def adelete_collection(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: @@ -573,7 +585,7 @@ async def adelete( collection_only: Only delete ids in the collection. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: if ids is not None: self.logger.debug( @@ -603,7 +615,7 @@ def get_collection(self, session: Session) -> Any: async def aget_collection(self, session: AsyncSession) -> Any: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init return await self.CollectionStore.aget_by_name(session, self.collection_name) @classmethod @@ -758,7 +770,7 @@ async def aadd_embeddings( kwargs: vectorstore specific parameters """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -841,7 +853,7 @@ async def aadd_texts( List of ids from adding the texts into the vectorstore. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -890,7 +902,7 @@ async def asimilarity_search( List of Documents most similar to the query. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( embedding=embedding, @@ -977,7 +989,7 @@ async def asimilarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter @@ -1418,7 +1430,7 @@ async def asimilarity_search_by_vector( List of Documents most similar to the query vector. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1628,12 +1640,13 @@ async def afrom_existing_index( return the instance of the store without inserting any new embeddings """ - store = await PGVector( + store = PGVector( connection=connection, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, + async_mode=True, **kwargs, ) @@ -1845,7 +1858,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter @@ -1932,7 +1945,7 @@ async def amax_marginal_relevance_search( List[Document]: List of Documents selected by maximal marginal relevance. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( embedding, @@ -2014,7 +2027,7 @@ async def amax_marginal_relevance_search_with_score( relevance to the query and score for each. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2097,7 +2110,7 @@ async def amax_marginal_relevance_search_by_vector( List[Document]: List of Documents selected by maximal marginal relevance. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( embedding, diff --git a/pyproject.toml b/pyproject.toml index 8e349f7f..ac8e6bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,3 +85,12 @@ timeout = 30 markers = [] asyncio_mode = "auto" +[tool.codespell] +skip = '*.md,.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig' +# Ignore latin etc +ignore-regex = '.*(Stati Uniti|Tense=Pres).*' +# whats is a typo but used frequently in queries so kept as is +# aapply - async apply +# unsecure - typo but part of API, decided to not bother for now +#ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' +ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' From e95df2fe5e075008a7ef0ffe943d788b00eeb7e0 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 22 Apr 2024 12:56:12 +0200 Subject: [PATCH 07/18] Rebase --- langchain_postgres/vectorstores.py | 2 ++ pyproject.toml | 7 +------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 4186a82e..b3541c3b 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -58,8 +58,10 @@ class DistanceStrategy(str, enum.Enum): Base = declarative_base() # type: Any + _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + _classes: Any = None COMPARISONS_TO_NATIVE = { diff --git a/pyproject.toml b/pyproject.toml index ac8e6bb6..d7b47b34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,11 +86,6 @@ markers = [] asyncio_mode = "auto" [tool.codespell] -skip = '*.md,.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig' -# Ignore latin etc +skip = 
'.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig' ignore-regex = '.*(Stati Uniti|Tense=Pres).*' -# whats is a typo but used frequently in queries so kept as is -# aapply - async apply -# unsecure - typo but part of API, decided to not bother for now -#ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' From 9c2035c3335e6c404902431c6505b18277179510 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 22 Apr 2024 16:36:39 +0200 Subject: [PATCH 08/18] Fix create_vector_extension with async --- langchain_postgres/vectorstores.py | 63 ++++++++++++------------------ 1 file changed, 26 insertions(+), 37 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index b3541c3b..8841d22a 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -28,7 +28,7 @@ from langchain_core.vectorstores import VectorStore from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -230,7 +230,18 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: return [doc for doc, _ in docs_and_scores] -Connection = Union[sqlalchemy.engine.Engine, str] +def _create_vector_extension(conn: Connection) -> None: + statement = sqlalchemy.text( + "BEGIN;" + "SELECT pg_advisory_xact_lock(1573678846307946496);" + "CREATE EXTENSION IF NOT EXISTS vector;" + "COMMIT;" + ) + conn.execute(statement) + conn.commit() + + +DBConnection = Union[sqlalchemy.engine.Engine, str] class PGVector(VectorStore): @@ -295,7 +306,7 @@ def __init__( self, embeddings: Embeddings, *, - connection: Union[None, Connection, Engine, AsyncEngine, str] = None, + connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, @@ -401,7 +412,7 @@ async def __apost_init__( self, ) -> None: """Async initialize the store (use lazy approach).""" - if self._async_init: + if self._async_init: # Warning: possible race condition return self._async_init = True @@ -423,50 +434,25 @@ def embeddings(self) -> Embeddings: def create_vector_extension(self) -> None: assert not self._async_engine, "This method must be called without async_mode" try: - with self.session_maker() as session: # type: ignore[arg-type] - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. 
- # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - statement = sqlalchemy.text( - "BEGIN;" - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" - "COMMIT;" - ) - session.execute(statement) - session.commit() + with self._engine.connect() as conn: + _create_vector_extension(conn) except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: assert self.async_mode, "This method must be called with async_mode" - try: - async with self.session_maker() as session: - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. - # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - await session.execute( - sqlalchemy.text("SELECT pg_advisory_xact_lock(1573678846307946496)") - ) - await session.execute( - sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") - ) - except Exception as e: - raise Exception(f"Failed to create vector extension: {e}") from e + async with self._async_engine.begin() as conn: + await conn.run_sync(_create_vector_extension) def create_tables_if_not_exists(self) -> None: assert not self._async_engine, "This method must be called without async_mode" with self.session_maker() as session: Base.metadata.create_all(session.get_bind()) + session.commit() async def acreate_tables_if_not_exists(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -474,6 +460,7 @@ def drop_tables(self) -> None: assert not self._async_engine, "This method must be called without async_mode" with self.session_maker() as session: Base.metadata.drop_all(session.get_bind()) + session.commit() async def adrop_tables(self) -> None: assert self._async_engine, "This method must be called with async_mode" @@ -489,6 +476,7 @@ def create_collection(self) -> None: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + session.commit() async def acreate_collection(self) -> None: assert self._async_engine, "This method must be called with async_mode" @@ -499,6 +487,7 @@ async def acreate_collection(self) -> None: await self.CollectionStore.aget_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + await session.commit() def _delete_collection(self, session: Session) -> None: self.logger.debug("Trying to delete collection") @@ -1607,7 +1596,7 @@ def from_existing_index( collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, **kwargs: Any, ) -> PGVector: """ @@ -1634,7 +1623,7 @@ async def afrom_existing_index( collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, **kwargs: Any, ) -> PGVector: """ @@ -1677,7 +1666,7 @@ def from_documents( documents: List[Document], embedding: Embeddings, *, - connection: 
Optional[Connection] = None, + connection: Optional[DBConnection] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, From 7cc0bb1f8d1dfdfa331548ca0f2866fb727cbb16 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 22 Apr 2024 17:00:23 +0200 Subject: [PATCH 09/18] Fix create_vector_extension with async --- langchain_postgres/vectorstores.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 8841d22a..592d07d9 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -232,15 +232,12 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: def _create_vector_extension(conn: Connection) -> None: statement = sqlalchemy.text( - "BEGIN;" "SELECT pg_advisory_xact_lock(1573678846307946496);" "CREATE EXTENSION IF NOT EXISTS vector;" - "COMMIT;" ) conn.execute(statement) conn.commit() - DBConnection = Union[sqlalchemy.engine.Engine, str] From 4cbeeb2772713a8301386752a0c366a9c6af91bc Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Thu, 25 Apr 2024 10:41:01 +0200 Subject: [PATCH 10/18] =?UTF-8?q?Re-ajuste=20les=20UI=20vis=20=C3=A0=20vis?= =?UTF-8?q?=20de=20l'async=20(partiellement)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain_postgres/vectorstores.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 592d07d9..07a9147c 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -413,14 +413,14 @@ async def __apost_init__( return self._async_init = True - if self.create_extension: - await self.acreate_vector_extension() - EmbeddingStore, CollectionStore = _get_embedding_collection_store( self._embedding_length ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore + if self.create_extension: + await self.acreate_vector_extension() + await self.acreate_tables_if_not_exists() await self.acreate_collection() From 19d3f683cbb9a05e1e48e5f618a259f9c0a66a48 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 15 Apr 2024 08:40:31 +0200 Subject: [PATCH 11/18] Add async mode --- langchain_postgres/vectorstores.py | 965 +++++++++++++-------------- tests/unit_tests/test_vectorstore.py | 100 +-- 2 files changed, 507 insertions(+), 558 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 07a9147c..d5aa6dba 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,6 +1,6 @@ -# pylint: disable=too-many-lines from __future__ import annotations +import asyncio import enum import logging import uuid @@ -11,37 +11,27 @@ Iterable, List, Optional, - Sequence, Tuple, - Type, - Union, -) -from typing import ( - cast as typing_cast, + Type, Union, ) import numpy as np import sqlalchemy +from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +# TODO: accepter l'absence de l'option async lors des imports +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy.orm import Session, relationship, sessionmaker + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + from 
langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore -from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select -from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.engine import Engine, Connection -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - AsyncSession, - async_sessionmaker, - create_async_engine, -) -from sqlalchemy.orm import ( - Session, - declarative_base, - relationship, - scoped_session, - sessionmaker, -) from langchain_postgres._utils import maximal_marginal_relevance @@ -58,10 +48,8 @@ class DistanceStrategy(str, enum.Enum): Base = declarative_base() # type: Any - _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" - _classes: Any = None COMPARISONS_TO_NATIVE = { @@ -121,36 +109,28 @@ class CollectionStore(Base): @classmethod def get_by_name( - cls, session: Session, name: str + cls, session: Session, name: str ) -> Optional["CollectionStore"]: - return ( - session.query(cls) - .filter(typing_cast(sqlalchemy.Column, cls.name) == name) - .first() - ) + return session.query(cls).filter(cls.name == name).first() # type: ignore @classmethod async def aget_by_name( - cls, session: AsyncSession, name: str + cls, session: AsyncSession, name: str ) -> Optional["CollectionStore"]: - return ( - ( - await session.execute( - select(CollectionStore).where( - typing_cast(sqlalchemy.Column, cls.name) == name - ) - ) - ) - .scalars() - .first() - ) + stmt = select(cls).filter(cls.name == name) + # return await session.execute(stmt) # FIXME + return (await session.execute(stmt)).scalars().first() # FIXME + # stmt = select(cls).filter(cls.name == name) + # result = await session.execute(stmt) + # x = result.scalars() + # return session.query(cls).filter(cls.name == name).first() @classmethod def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """Get or create a collection. Returns: @@ -163,16 +143,16 @@ def get_or_create( collection = cls(name=name, cmetadata=cmetadata) session.add(collection) - session.commit() + session.commit() # FIXME PPR semble utile created = True return collection, created @classmethod async def aget_or_create( - cls, - session: AsyncSession, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: AsyncSession, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """ Get or create a collection. @@ -230,15 +210,7 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: return [doc for doc, _ in docs_and_scores] -def _create_vector_extension(conn: Connection) -> None: - statement = sqlalchemy.text( - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" - ) - conn.execute(statement) - conn.commit() - -DBConnection = Union[sqlalchemy.engine.Engine, str] +Connection = Union[sqlalchemy.engine.Engine, str] class PGVector(VectorStore): @@ -276,7 +248,6 @@ class PGVector(VectorStore): connection=connection_string, collection_name=collection_name, use_jsonb=True, - async_mode=False, ) @@ -294,33 +265,29 @@ class PGVector(VectorStore): You will need to recreate the tables if you are using an existing database. * A Connection object has to be provided explicitly. 
Connections will not be picked up automatically based on env variables. - * langchain_postgres now accept async connections. If you want to use the async - version, you need to set `async_mode=True` when initializing the store or - use an async engine. """ def __init__( - self, - embeddings: Embeddings, - *, - connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - async_mode: bool = False, + self, + embeddings: Embeddings, + *, + connection: Union[None, Connection, str] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + async_mode: bool = False, # FIXME: a virer. Gaff aux imports sans async ) -> None: """Initialize the PGVector store. - For an async version, use `PGVector.acreate()` instead. Args: - connection: Postgres connection string or (async)engine. + connection: Postgres connection string. embeddings: Any embedding function implementing `langchain.embeddings.base.Embeddings` interface. embedding_length: The length of the embedding vector. 
(default: None) @@ -353,33 +320,24 @@ def __init__( self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn - self._engine: Optional[Engine] = None - self._async_engine: Optional[AsyncEngine] = None - self._async_init = False if isinstance(connection, str): - if async_mode: - self._async_engine = create_async_engine( - connection, **(engine_args or {}) - ) - else: - self._engine = create_engine(url=connection, **(engine_args or {})) - elif isinstance(connection, Engine): - self.async_mode = False + self._engine = self._create_engine( + connection, engine_args, async_mode) + elif isinstance(connection, sqlalchemy.engine.Engine): self._engine = connection - elif isinstance(connection, AsyncEngine): - self.async_mode = True - self._async_engine = connection else: raise ValueError( "connection should be a connection string or an instance of " - "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" + "sqlalchemy.engine.Engine" ) - self.session_maker: Union[scoped_session, async_sessionmaker] + # If the driver accept only the synchrone calls, update the async_mode + self.async_mode = not isinstance(self._engine, Engine) + self._session_maker: Union[sessionmaker, async_sessionmaker] if self.async_mode: - self.session_maker = async_sessionmaker(bind=self._async_engine) + self._session_maker = async_sessionmaker(bind=self._engine) else: - self.session_maker = scoped_session(sessionmaker(bind=self._engine)) + self._session_maker = sessionmaker(bind=self._engine) self.use_jsonb = use_jsonb self.create_extension = create_extension @@ -387,11 +345,11 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - if not self.async_mode: + if not async_mode: self.__post_init__() def __post_init__( - self, + self, ) -> None: """Initialize the store.""" if self.create_extension: @@ -406,85 +364,160 @@ def __post_init__( self.create_collection() async def __apost_init__( - self, + self, ) -> None: - """Async initialize the store (use lazy approach).""" - if self._async_init: # Warning: possible race condition - return - self._async_init = True + + """Initialize the store.""" + if self.create_extension: + await self.acreate_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( self._embedding_length ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore - if self.create_extension: - await self.acreate_vector_extension() - await self.acreate_tables_if_not_exists() await self.acreate_collection() + @classmethod + async def create(cls, + embeddings: Embeddings, + *, + connection: Optional[Connection] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + async_mode: bool = True, + ) -> PGVector: + self = cls( + embeddings=embeddings, + connection=connection, + embedding_length=embedding_length, + collection_name=collection_name, + collection_metadata=collection_metadata, + distance_strategy=distance_strategy, + 
pre_delete_collection=pre_delete_collection,
+            logger=logger,
+            relevance_score_fn=relevance_score_fn,
+            engine_args=engine_args,
+            use_jsonb=use_jsonb,
+            create_extension=create_extension,
+            async_mode=async_mode,
+        )
+        if async_mode:
+            await self.__apost_init__()
+        return self
+
+    def _create_engine(self,
+                       connection: str,
+                       engine_args: Optional[dict[str, Any]] = None,
+                       async_mode: bool = False) -> sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine:
+        if async_mode:
+            from sqlalchemy.ext.asyncio import create_async_engine
+            # FIXME: handle an async call made on a sync-only driver
+            from sqlalchemy.exc import InvalidRequestError
+            try:
+                return create_async_engine(
+                    url=connection,
+                    isolation_level="REPEATABLE READ",  # FIXME: merge with what follows?
+                    echo=True,  # FIXME: remove this
+                    **(engine_args or {})
+                )
+            except InvalidRequestError:
+                pass  # Ignore and fall back to the synchronous version
+        logging.warning("Using a synchronous SQL engine!")
+        return sqlalchemy.create_engine(url=connection,
+                                        **(engine_args or {}))
+
+    def __del__(self) -> None:
+        if isinstance(self._engine, sqlalchemy.engine.Connection):
+            if self.async_mode:
+                asyncio.run(self._engine.close())
+            else:
+                self._engine.close()
+
     @property
     def embeddings(self) -> Embeddings:
         return self.embedding_function

     def create_vector_extension(self) -> None:
-        assert not self._async_engine, "This method must be called without async_mode"
         try:
-            with self._engine.connect() as conn:
-                _create_vector_extension(conn)
+            with self._session_maker() as session:  # type: ignore[arg-type]
+                # The advisory lock fixes an issue arising from concurrent
+                # creation of the vector extension.
+                # https://github.com/langchain-ai/langchain/issues/12933
+                # For more information see:
+                # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+                statement = sqlalchemy.text(
+                    "BEGIN;"
+                    "SELECT pg_advisory_xact_lock(1573678846307946496);"
+                    "CREATE EXTENSION IF NOT EXISTS vector;"
+                    "COMMIT;"
+                )
+                session.execute(statement)
+                session.commit()
         except Exception as e:
             raise Exception(f"Failed to create vector extension: {e}") from e

     async def acreate_vector_extension(self) -> None:
-        assert self.async_mode, "This method must be called with async_mode"
-
-        async with self._async_engine.begin() as conn:
-            await conn.run_sync(_create_vector_extension)
+        try:
+            async with self._session_maker() as session:
+                # The advisory lock fixes an issue arising from concurrent
+                # creation of the vector extension.
+                # https://github.com/langchain-ai/langchain/issues/12933
+                # For more information see:
+                # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+                await session.execute(
+                    sqlalchemy.text(
+                        "SELECT pg_advisory_xact_lock(1573678846307946496)"))
+                await session.execute(
+                    sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector"))
+        except Exception as e:
+            raise Exception(f"Failed to create vector extension: {e}") from e

     def create_tables_if_not_exists(self) -> None:
-        assert not self._async_engine, "This method must be called without async_mode"
-        with self.session_maker() as session:
+        with self._session_maker() as session:
             Base.metadata.create_all(session.get_bind())
-            session.commit()

     async def acreate_tables_if_not_exists(self) -> None:
-        assert self._async_engine, "This method must be called with async_mode"
-        async with self._async_engine.begin() as conn:
-            await conn.run_sync(Base.metadata.create_all)
+        if isinstance(self._engine, sqlalchemy.ext.asyncio.engine.AsyncConnection):
+            await self._engine.run_sync(Base.metadata.create_all)
+        else:
+            async with self._engine.begin() as conn:  # FIXME: session.run_sync exists
+                await conn.run_sync(Base.metadata.create_all)
+        # async with self._amake_session() as session:
+        #     await session.run_sync(Base.metadata.create_all)

     def drop_tables(self) -> None:
-        assert not self._async_engine, "This method must be called without async_mode"
-        with self.session_maker() as session:
+        with self._session_maker() as session:
             Base.metadata.drop_all(session.get_bind())
-            session.commit()

     async def adrop_tables(self) -> None:
-        assert self._async_engine, "This method must be called with async_mode"
-        await self.__apost_init__()  # Lazy async init
-        async with self._async_engine.begin() as conn:
+        async with self._engine.begin() as conn:
             await conn.run_sync(Base.metadata.drop_all)

     def create_collection(self) -> None:
-        assert not self._async_engine, "This method must be called without async_mode"
         if self.pre_delete_collection:
             self.delete_collection()
-        with self.session_maker() as session:
+        with self._session_maker() as session:
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )
-            session.commit()

     async def acreate_collection(self) -> None:
-        assert self._async_engine, "This method must be called with async_mode"
-        await self.__apost_init__()  # Lazy async init
-        async with self.session_maker() as session:
+        async with self._session_maker() as session:
             if self.pre_delete_collection:
                 await self._adelete_collection(session)
             await self.CollectionStore.aget_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )
-            await session.commit()

     def _delete_collection(self, session: Session) -> None:
         self.logger.debug("Trying to delete collection")
@@ -494,6 +527,7 @@ def _delete_collection(self, session: Session) -> None:
             return
         session.delete(collection)

+    # FIXME: is the _adelete variant necessary?
async def _adelete_collection(self, session: AsyncSession) -> None: self.logger.debug("Trying to delete collection") collection = await self.aget_collection(session) @@ -502,10 +536,12 @@ async def _adelete_collection(self, session: AsyncSession) -> None: return await session.delete(collection) + # def delete_collection(self) -> None: + # with self._session_maker() as session: + # self._delete_collection(session) def delete_collection(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" self.logger.debug("Trying to delete collection") - with self.session_maker() as session: # type: ignore[arg-type] + with self._session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -514,9 +550,8 @@ def delete_collection(self) -> None: session.commit() async def adelete_collection(self) -> None: - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + self.logger.debug("Trying to delete collection") + with self._session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -525,10 +560,10 @@ async def adelete_collection(self) -> None: await session.commit() def delete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: """Delete vectors by ids or uuids. @@ -536,8 +571,7 @@ def delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -561,20 +595,18 @@ def delete( session.commit() async def adelete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: - """Async delete vectors by ids or uuids. + """Delete vectors by ids or uuids. Args: ids: List of ids to delete. collection_only: Only delete ids in the collection. 
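        Example (an illustrative sketch, not part of this patch; it assumes a
        ``store`` created through the async code path and ids that exist):

        ```python
        # Delete two vectors by id.
        await store.adelete(ids=["doc-1", "doc-2"])
        # Restrict the delete to ids that belong to the current collection.
        await store.adelete(ids=["doc-3"], collection_only=True)
        ```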
""" - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -598,29 +630,26 @@ async def adelete( await session.commit() def get_collection(self, session: Session) -> Any: - assert not self._async_engine, "This method must be called without async_mode" return self.CollectionStore.get_by_name(session, self.collection_name) async def aget_collection(self, session: AsyncSession) -> Any: - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init return await self.CollectionStore.aget_by_name(session, self.collection_name) @classmethod def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid4()) for _ in texts] @@ -646,19 +675,19 @@ def __from( @classmethod async def __afrom( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -673,9 +702,11 @@ async def __afrom( distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, use_jsonb=use_jsonb, - async_mode=True, + async_mode=True, # FIXME **kwargs, ) + # Second phase to create + await store.__apost_init__() await store.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -684,12 +715,12 @@ async def __afrom( return store def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Add embeddings to the vectorstore. 
@@ -697,18 +728,15 @@ def add_embeddings( texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. - ids: Optional list of ids for the documents. - If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ - assert not self._async_engine, "This method must be called without async_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self.session_maker() as session: # type: ignore[arg-type] + with self._session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -740,32 +768,28 @@ def add_embeddings( return ids async def aadd_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: - """Async add embeddings to the vectorstore. + """Add embeddings to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. - ids: Optional list of ids for the texts. - If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -797,62 +821,55 @@ async def aadd_embeddings( return ids def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of ids for the texts. - If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters Returns: List of ids from adding the texts into the vectorstore. """ - assert not self._async_engine, "This method must be called without async_mode" embeddings = self.embedding_function.embed_documents(list(texts)) return self.add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. 
- ids: Optional list of ids for the texts. - If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters Returns: List of ids from adding the texts into the vectorstore. """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -864,7 +881,6 @@ def similarity_search( Returns: List of Documents most similar to the query. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(text=query) return self.similarity_search_by_vector( embedding=embedding, @@ -873,11 +889,11 @@ def similarity_search( ) async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -889,8 +905,6 @@ async def asimilarity_search( Returns: List of Documents most similar to the query. """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( embedding=embedding, @@ -899,10 +913,10 @@ async def asimilarity_search( ) def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -914,7 +928,6 @@ def similarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -922,10 +935,10 @@ def similarity_search_with_score( return docs async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -937,8 +950,6 @@ async def asimilarity_search_with_score( Returns: List of Documents most similar to the query and score for each. 
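        Example (a hedged sketch; assumes an async-mode ``store`` with documents
        already added, and the operator-style filter shown is illustrative):

        ```python
        # Four closest documents with their distances, restricted by metadata.
        results = await store.asimilarity_search_with_score(
            "query text", k=4, filter={"page": {"$eq": "0"}}
        )
        for doc, score in results:
            print(score, doc.page_content)
        ```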
""" - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -960,28 +971,26 @@ def distance_strategy(self) -> Any: ) def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) async def asimilarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( - session=session, embedding=embedding, k=k, filter=filter - ) + session=session, + embedding=embedding, k=k, + filter=filter) return self._results_to_docs_and_scores(results) @@ -1000,9 +1009,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa return docs def _handle_field_filter( - self, - field: str, - value: Any, + self, + field: str, + value: Any, ) -> SQLColumnExpression: """Create a filter for a specific field. @@ -1112,7 +1121,8 @@ def _handle_field_filter( else: raise NotImplementedError() - def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] + def _create_filter_clause_deprecated(self, key, + value): # type: ignore[no-untyped-def] """Deprecated functionality. This is for backwards compatibility with the JSON based schema for metadata. @@ -1181,7 +1191,7 @@ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyp return filter_by_metadata def _create_filter_clause_json_deprecated( - self, filter: Any + self, filter: Any ) -> List[SQLColumnExpression]: """Convert filters from IR to SQL clauses. 
@@ -1292,13 +1302,13 @@ def _create_filter_clause(self, filters: Any) -> Any: ) def __query_collection( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> Sequence[Any]: + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: """Query the collection.""" - with self.session_maker() as session: # type: ignore[arg-type] + with self._session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -1334,55 +1344,52 @@ def __query_collection( return results async def __aquery_collection( - self, - session: AsyncSession, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> Sequence[Any]: + self, + session: AsyncSession, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: """Query the collection.""" - async with self.session_maker() as session: # type: ignore[arg-type] - collection = await self.aget_collection(session) - if not collection: - raise ValueError("Collection not found") - - filter_by = [self.EmbeddingStore.collection_id == collection.uuid] - if filter: - if self.use_jsonb: - filter_clauses = self._create_filter_clause(filter) - if filter_clauses is not None: - filter_by.append(filter_clauses) - else: - # Old way of doing things - filter_clauses = self._create_filter_clause_json_deprecated(filter) - filter_by.extend(filter_clauses) + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") + + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) - _type = self.EmbeddingStore + _type = self.EmbeddingStore - stmt = ( - select( - self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore - ) + stmt = (select( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) .join( - self.CollectionStore, - self.EmbeddingStore.collection_id == self.CollectionStore.uuid, - ) - .limit(k) - ) + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k)) - results: Sequence[Any] = (await session.execute(stmt)).all() + results: List[Any] = (await session.execute(stmt)).all() - return results + return results def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -1394,18 +1401,17 @@ def similarity_search_by_vector( Returns: List of Documents most similar to the query vector. 
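        Example (illustrative; ``store`` and ``embedding_model`` are assumed to
        exist and are not defined in this patch):

        ```python
        # Embed the query yourself, then search with the raw vector.
        query_vector = embedding_model.embed_query("What does the cat chase?")
        docs = store.similarity_search_by_vector(query_vector, k=4)
        ```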
""" - assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) return _results_to_docs(docs_and_scores) async def asimilarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -1417,8 +1423,6 @@ async def asimilarity_search_by_vector( Returns: List of Documents most similar to the query vector. """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1426,17 +1430,17 @@ async def asimilarity_search_by_vector( @classmethod def from_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1456,17 +1460,17 @@ def from_texts( @classmethod async def afrom_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1485,16 +1489,16 @@ async def afrom_texts( @classmethod def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - *, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + *, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, 
) -> PGVector: """Construct PGVector wrapper from raw documents and embeddings. @@ -1505,7 +1509,6 @@ def from_embeddings( collection_name: Name of the collection. distance_strategy: Distance strategy to use. ids: Optional list of ids for the documents. - If not provided, will generate a new id for each document. pre_delete_collection: If True, will delete the collection if it exists. **Attention**: This will delete all the documents in the existing collection. @@ -1542,15 +1545,15 @@ def from_embeddings( @classmethod async def afrom_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and pre- generated embeddings. @@ -1587,14 +1590,14 @@ async def afrom_embeddings( @classmethod def from_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[DBConnection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will @@ -1614,21 +1617,21 @@ def from_existing_index( @classmethod async def afrom_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[DBConnection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will return the instance of the store without inserting any new embeddings """ - store = PGVector( + store = cls( # FIXME: créate connection=connection, collection_name=collection_name, embeddings=embedding, @@ -1659,17 +1662,17 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: @classmethod def from_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - *, - connection: Optional[DBConnection] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + 
embedding: Embeddings, + *, + connection: Optional[Connection] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" @@ -1691,16 +1694,16 @@ def from_documents( @classmethod async def afrom_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """ Return VectorStore initialized from documents and embeddings. @@ -1729,13 +1732,13 @@ async def afrom_documents( @classmethod def connection_string_from_db_params( - cls, - driver: str, - host: str, - port: int, - database: str, - user: str, - password: str, + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, ) -> str: """Return connection string from database parameters.""" if driver != "psycopg": @@ -1770,13 +1773,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: ) def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1799,7 +1802,6 @@ def max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1816,13 +1818,13 @@ def max_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1845,12 +1847,10 @@ async def amax_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
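        Example (a sketch only; ``store`` is assumed to be an async-mode instance).
        A ``lambda_mult`` close to 1 favours similarity to the query, while values
        close to 0 favour diversity among the returned documents:

        ```python
        # Fetch 20 candidates, then keep the 4 that best balance relevance and diversity.
        docs = await store.amax_marginal_relevance_search(
            "query text", k=4, fetch_k=20, lambda_mult=0.5
        )
        ```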
""" - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: - results = await self.__aquery_collection( - session=session, embedding=embedding, k=fetch_k, filter=filter - ) + async with self._session_maker() as session: + results = await self.__aquery_collection(session=session, + embedding=embedding, k=fetch_k, + filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1866,13 +1866,13 @@ async def amax_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1893,7 +1893,6 @@ def max_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, @@ -1905,13 +1904,13 @@ def max_marginal_relevance_search( ) async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1932,8 +1931,6 @@ async def amax_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( embedding, @@ -1945,13 +1942,13 @@ async def amax_marginal_relevance_search( ) def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -1973,7 +1970,6 @@ def max_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
""" - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -1986,13 +1982,13 @@ def max_marginal_relevance_search_with_score( return docs async def amax_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -2014,8 +2010,6 @@ async def amax_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2028,13 +2022,13 @@ async def amax_marginal_relevance_search_with_score( return docs def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2056,7 +2050,6 @@ def max_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, @@ -2069,13 +2062,13 @@ def max_marginal_relevance_search_by_vector( return _results_to_docs(docs_and_scores) async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2097,17 +2090,13 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. 
""" - assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init - docs_and_scores = ( - await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) + docs_and_scores = await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, ) return _results_to_docs(docs_and_scores) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 3b959710..42da5909 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1,15 +1,16 @@ """Test PGVector functionality.""" import contextlib -from typing import Any, AsyncGenerator, Dict, Generator, List +from typing import Any, Dict, Generator, List import pytest from langchain_core.documents import Document -from sqlalchemy import select from langchain_postgres.vectorstores import ( SUPPORTED_OPERATORS, PGVector, ) +from sqlalchemy import select + from tests.unit_tests.fake_embeddings import FakeEmbeddings from tests.unit_tests.fixtures.filtering_test_cases import ( DOCUMENTS, @@ -52,7 +53,7 @@ def test_pgvector() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] - +# @pytest.mark.requires("xxx") # FIXME @pytest.mark.asyncio async def test_async_pgvector() -> None: """Test end to end construction and search.""" @@ -196,9 +197,7 @@ async def test_async_pgvector_with_filter_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score( - "foo", k=1, filter={"page": "0"} - ) + output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "0"}) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] @@ -219,7 +218,6 @@ def test_pgvector_with_filter_distant_match() -> None: (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] - @pytest.mark.asyncio async def test_async_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" @@ -233,9 +231,7 @@ async def test_async_pgvector_with_filter_distant_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score( - "foo", k=1, filter={"page": "2"} - ) + output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "2"}) assert output == [ (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] @@ -256,7 +252,6 @@ def test_pgvector_with_filter_no_match() -> None: output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] - @pytest.mark.asyncio async def test_async_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" @@ -270,9 +265,7 @@ async def test_async_pgvector_with_filter_no_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score( - "foo", k=1, filter={"page": "5"} - ) + output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] @@ -285,7 +278,7 @@ def test_pgvector_collection_with_metadata() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with pgvector.session_maker() as session: + with pgvector._session_maker() as 
session: collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -293,19 +286,17 @@ def test_pgvector_collection_with_metadata() -> None: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} - @pytest.mark.asyncio async def test_async_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" - pgvector = PGVector( + pgvector = await PGVector.create( collection_name="test_collection", collection_metadata={"foo": "bar"}, embeddings=FakeEmbeddingsWithAdaDimension(), connection=CONNECTION_STRING, pre_delete_collection=True, - async_mode=True, ) - async with pgvector.session_maker() as session: + async with pgvector._session_maker() as session: collection = await pgvector.aget_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -314,6 +305,8 @@ async def test_async_pgvector_collection_with_metadata() -> None: assert collection.cmetadata == {"foo": "bar"} + + def test_pgvector_delete_docs() -> None: """Add and delete documents.""" texts = ["foo", "bar", "baz"] @@ -328,36 +321,20 @@ def test_pgvector_delete_docs() -> None: pre_delete_collection=True, ) vectorstore.delete(["1", "2"]) - with vectorstore.session_maker() as session: + with vectorstore._session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore vectorstore.delete(["2", "3"]) # Should not raise on missing ids - with vectorstore.session_maker() as session: + with vectorstore._session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [] # type: ignore -def test_pgvector_delete_collection() -> None: - """Add and delete documents.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": str(i)} for i in range(len(texts))] - vectorstore = PGVector.from_texts( - texts=texts, - collection_name="test_collection_filter", - embedding=FakeEmbeddingsWithAdaDimension(), - metadatas=metadatas, - ids=["1", "2", "3"], - connection=CONNECTION_STRING, - pre_delete_collection=True, - ) - vectorstore.delete(collection_only=True) - - @pytest.mark.asyncio async def test_async_pgvector_delete_docs() -> None: """Add and delete documents.""" @@ -373,19 +350,15 @@ async def test_async_pgvector_delete_docs() -> None: pre_delete_collection=True, ) await vectorstore.adelete(["1", "2"]) - async with vectorstore.session_maker() as session: - records = ( - (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() - ) + async with vectorstore._session_maker() as session: + records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids - async with vectorstore.session_maker() as session: - records = ( - (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() - ) + async with vectorstore._session_maker() as session: + records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() # ignoring type error since mypy 
cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [] # type: ignore @@ -424,7 +397,7 @@ def test_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with vectorstore.session_maker() as session: + with vectorstore._session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -446,7 +419,7 @@ def test_pgvector_index_documents() -> None: vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents]) - with vectorstore.session_maker() as session: + with vectorstore._session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) ordered_records = sorted(records, key=lambda x: x.id) # ignoring type error since mypy cannot determine whether @@ -465,7 +438,6 @@ def test_pgvector_index_documents() -> None: "topic": "zoo", } - @pytest.mark.asyncio async def test_async_pgvector_index_documents() -> None: """Test adding duplicate documents results in overwrites.""" @@ -500,10 +472,8 @@ async def test_async_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - async with vectorstore.session_maker() as session: - records = ( - (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() - ) + async with vectorstore._session_maker() as session: + records = (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [ @@ -522,14 +492,10 @@ async def test_async_pgvector_index_documents() -> None: ), ] - await vectorstore.aadd_documents( - documents, ids=[doc.metadata["id"] for doc in documents] - ) + await vectorstore.aadd_documents(documents, ids=[doc.metadata["id"] for doc in documents]) - async with vectorstore.session_maker() as session: - records = ( - (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() - ) + async with vectorstore._session_maker() as session: + records = (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ordered_records = sorted(records, key=lambda x: x.id) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -663,9 +629,7 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: @pytest.mark.asyncio -async def test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> ( - None -): +async def test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: """Test searching with threshold and custom normalization function""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] @@ -741,9 +705,7 @@ async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.amax_marginal_relevance_search_with_score( - "foo", k=1, fetch_k=3 - ) + output = await docsearch.amax_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) assert output == [(Document(page_content="foo"), 0.0)] @@ -807,10 +769,9 @@ def pgvector() -> Generator[PGVector, None, None]: with get_vectorstore() as vector_store: yield vector_store - @pytest.mark.asyncio @pytest.fixture -async def async_pgvector() -> AsyncGenerator[PGVector, None]: +async def async_pgvector() -> Generator[PGVector, 
None, None]: """Create an async PGVector instance.""" store = await PGVector.afrom_documents( documents=DOCUMENTS, @@ -847,7 +808,7 @@ def get_vectorstore() -> Generator[PGVector, None, None]: @contextlib.asynccontextmanager -async def aget_vectorstore() -> AsyncGenerator[PGVector, None]: +async def aget_vectorstore() -> Generator[PGVector, None, None]: """Get a pre-populated-vectorstore""" store = await PGVector.afrom_documents( documents=DOCUMENTS, @@ -874,7 +835,6 @@ def test_pgvector_with_with_metadata_filters_1( docs = pgvector.similarity_search("meow", k=5, filter=test_filter) assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter - @pytest.mark.asyncio @pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) async def test_async_pgvector_with_with_metadata_filters_1( @@ -946,7 +906,7 @@ def test_pgvector_with_with_metadata_filters_4( @pytest.mark.asyncio @pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) -async def test_async_pgvector_with_with_metadata_filters_4( +async def test_pgvector_with_with_metadata_filters_4( async_pgvector: PGVector, test_filter: Dict[str, Any], expected_ids: List[int], From 95eebae1bd903bc20b74846c034eb8ed00383af7 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 15 Apr 2024 10:49:54 +0200 Subject: [PATCH 12/18] Fix lint --- langchain_postgres/vectorstores.py | 856 ++++++++++++++------------- tests/unit_tests/test_vectorstore.py | 79 ++- 2 files changed, 500 insertions(+), 435 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index d5aa6dba..ffce9ff6 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import enum import logging import uuid @@ -11,27 +10,38 @@ Iterable, List, Optional, + Sequence, Tuple, - Type, Union, + Type, + Union, ) import numpy as np import sqlalchemy -from sqlalchemy import SQLColumnExpression, cast, delete, func, select, Engine -from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -# TODO: accepter l'absence de l'option async lors des imports -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from sqlalchemy.orm import Session, relationship, sessionmaker - -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base +# try: +# from sqlalchemy.orm import declarative_base +# except ImportError: +# from sqlalchemy.ext.declarative import declarative_base from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore +from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Session, + declarative_base, + relationship, + sessionmaker, +) from langchain_postgres._utils import maximal_marginal_relevance @@ -109,28 +119,26 @@ class CollectionStore(Base): @classmethod def get_by_name( - cls, session: Session, name: str + cls, session: Session, name: str ) -> Optional["CollectionStore"]: return session.query(cls).filter(cls.name == name).first() # type: ignore @classmethod 
async def aget_by_name( - cls, session: AsyncSession, name: str + cls, session: AsyncSession, name: str ) -> Optional["CollectionStore"]: - stmt = select(cls).filter(cls.name == name) - # return await session.execute(stmt) # FIXME - return (await session.execute(stmt)).scalars().first() # FIXME - # stmt = select(cls).filter(cls.name == name) - # result = await session.execute(stmt) - # x = result.scalars() - # return session.query(cls).filter(cls.name == name).first() + return ( + (await session.execute(select(CollectionStore).where(cls.name == name))) + .scalars() + .first() + ) @classmethod def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """Get or create a collection. Returns: @@ -143,16 +151,16 @@ def get_or_create( collection = cls(name=name, cmetadata=cmetadata) session.add(collection) - session.commit() # FIXME PPR semble utile + session.commit() created = True return collection, created @classmethod async def aget_or_create( - cls, - session: AsyncSession, - name: str, - cmetadata: Optional[dict] = None, + cls, + session: AsyncSession, + name: str, + cmetadata: Optional[dict] = None, ) -> Tuple["CollectionStore", bool]: """ Get or create a collection. @@ -268,23 +276,24 @@ class PGVector(VectorStore): """ def __init__( - self, - embeddings: Embeddings, - *, - connection: Union[None, Connection, str] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - async_mode: bool = False, # FIXME: a virer. Gaff aux imports sans async + self, + embeddings: Embeddings, + *, + connection: Union[None, Connection, str] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + use_jsonb: bool = True, + create_extension: bool = True, + _async_mode: bool = False, # Tag to force the async mode ) -> None: """Initialize the PGVector store. + For an async version, use `PGVector.acreate()` instead. Args: connection: Postgres connection string. @@ -311,7 +320,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. 
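        Example (an illustrative sketch; the embeddings object, connection string
        and collection name are placeholders):

        ```python
        from langchain_postgres.vectorstores import PGVector

        store = PGVector(
            embeddings=my_embeddings,  # any langchain Embeddings implementation
            collection_name="my_docs",
            connection="postgresql+psycopg://user:password@localhost:5432/db",
        )
        # For the async flavour, use the dedicated factory instead of the constructor:
        # store = await PGVector.acreate(embeddings=my_embeddings, connection=...)
        ```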
""" - self.async_mode = async_mode + self._async_mode = _async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -320,22 +329,29 @@ def __init__( self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn + self._engine: Optional[Engine] = None + self._async_engine: Optional[AsyncEngine] = None if isinstance(connection, str): - self._engine = self._create_engine( - connection, engine_args, async_mode) - elif isinstance(connection, sqlalchemy.engine.Engine): + if _async_mode: + self._async_engine = create_async_engine( + connection, **(engine_args or {}) + ) + else: + self._engine = create_engine(url=connection, **(engine_args or {})) + elif isinstance(connection, Engine): self._engine = connection + elif isinstance(connection, AsyncEngine): + self._async_mode = True + self._async_engine = connection else: raise ValueError( "connection should be a connection string or an instance of " - "sqlalchemy.engine.Engine" + "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) - # If the driver accept only the synchrone calls, update the async_mode - self.async_mode = not isinstance(self._engine, Engine) self._session_maker: Union[sessionmaker, async_sessionmaker] - if self.async_mode: - self._session_maker = async_sessionmaker(bind=self._engine) + if self._async_mode: + self._session_maker = async_sessionmaker(bind=self._async_engine) else: self._session_maker = sessionmaker(bind=self._engine) @@ -345,11 +361,18 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - if not async_mode: + if not _async_mode: self.__post_init__() + else: + import inspect + + assert inspect.stack()[1].function in [ + "acreate", + "__afrom", + ], "Call await PGVector.acreate() instead of PGVector(...))" def __post_init__( - self, + self, ) -> None: """Initialize the store.""" if self.create_extension: @@ -364,10 +387,9 @@ def __post_init__( self.create_collection() async def __apost_init__( - self, + self, ) -> None: - - """Initialize the store.""" + """Async initialize the store.""" if self.create_extension: await self.acreate_vector_extension() @@ -380,22 +402,49 @@ async def __apost_init__( await self.acreate_collection() @classmethod - async def create(cls, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - async_mode: bool = True, - ) -> PGVector: + async def acreate( + cls, + embeddings: Embeddings, + *, + connection: Optional[Connection] = None, + embedding_length: Optional[int] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + engine_args: Optional[dict[str, Any]] = None, + 
use_jsonb: bool = True, + create_extension: bool = True, + ) -> PGVector: + """Async create instance + + Args: + connection: Postgres connection string. + embeddings: Any embedding function implementing + `langchain.embeddings.base.Embeddings` interface. + embedding_length: The length of the embedding vector. (default: None) + NOTE: This is not mandatory. Defining it will prevent vectors of + any other size to be added to the embeddings table but, without it, + the embeddings can't be indexed. + collection_name: The name of the collection to use. (default: langchain) + NOTE: This is not the name of the table, but the name of the collection. + The tables will be created when initializing the store (if not exists) + So, make sure the user has the right permissions to create tables. + distance_strategy: The distance strategy to use. (default: COSINE) + pre_delete_collection: If True, will delete the collection if it exists. + (default: False). Useful for testing. + engine_args: SQLAlchemy's create engine arguments. + use_jsonb: Use JSONB instead of JSON for metadata. (default: True) + Strongly discouraged from using JSON as it's not as efficient + for querying. + It's provided here for backwards compatibility with older versions, + and will be removed in the future. + create_extension: If True, will create the vector extension if it + doesn't exist. disabling creation is useful when using ReadOnly + Databases. + """ self = cls( embeddings=embeddings, connection=connection, @@ -409,40 +458,11 @@ async def create(cls, engine_args=engine_args, use_jsonb=use_jsonb, create_extension=create_extension, - async_mode=async_mode, + _async_mode=True, ) - if async_mode: - await self.__apost_init__() + await self.__apost_init__() return self - def _create_engine(self, - connection: str, - engine_args: Optional[dict[str, Any]] = None, - async_mode: bool = False) -> sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine: - if async_mode: - from sqlalchemy.ext.asyncio import create_async_engine - # FIXME: gérer appel async sur un sync - from sqlalchemy.exc import InvalidRequestError - try: - return create_async_engine( - url=connection, - isolation_level="REPEATABLE READ", # FIXME: merge avec la suite ? 
- echo=True, # FIXME: a virer - **(engine_args or {}) - ) - except InvalidRequestError: - pass # Ignore and return the synchrone version - logging.warning("Use a synchrone SQL engine !") - return sqlalchemy.create_engine(url=connection, - **(engine_args or {})) - - def __del__(self) -> None: - if isinstance(self._engine, sqlalchemy.engine.Connection): - if self.async_mode: - asyncio.run(self._engine.close()) - else: - self._engine.close() - @property def embeddings(self) -> Embeddings: return self.embedding_function @@ -475,10 +495,11 @@ async def acreate_vector_extension(self) -> None: # For more information see: # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS await session.execute( - sqlalchemy.text( - "SELECT pg_advisory_xact_lock(1573678846307946496)")) + sqlalchemy.text("SELECT pg_advisory_xact_lock(1573678846307946496)") + ) await session.execute( - sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")) + sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") + ) except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e @@ -487,20 +508,17 @@ def create_tables_if_not_exists(self) -> None: Base.metadata.create_all(session.get_bind()) async def acreate_tables_if_not_exists(self) -> None: - if isinstance(self._engine, sqlalchemy.ext.asyncio.engine.AsyncConnection): - await self._engine.run_sync(Base.metadata.create_all) - else: - async with self._engine.begin() as conn: # FIXME: session.run_sync existe - await conn.run_sync(Base.metadata.create_all) - # async with self._amake_session() as session: - # await session.run_sync(Base.metadata.create_all) + assert self._async_engine, "Use with async mode" + async with self._async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: with self._session_maker() as session: Base.metadata.drop_all(session.get_bind()) async def adrop_tables(self) -> None: - async with self._engine.begin() as conn: + assert self._async_engine, "Use with async mode" + async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: @@ -512,6 +530,7 @@ def create_collection(self) -> None: ) async def acreate_collection(self) -> None: + assert self._async_engine, "Use with async mode" async with self._session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) @@ -527,7 +546,6 @@ def _delete_collection(self, session: Session) -> None: return session.delete(collection) - # FIXME: necessaire le _adelete ? 
async def _adelete_collection(self, session: AsyncSession) -> None: self.logger.debug("Trying to delete collection") collection = await self.aget_collection(session) @@ -536,9 +554,6 @@ async def _adelete_collection(self, session: AsyncSession) -> None: return await session.delete(collection) - # def delete_collection(self) -> None: - # with self._session_maker() as session: - # self._delete_collection(session) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") with self._session_maker() as session: # type: ignore[arg-type] @@ -551,7 +566,7 @@ def delete_collection(self) -> None: async def adelete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] + async with self._session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -560,10 +575,10 @@ async def adelete_collection(self) -> None: await session.commit() def delete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: """Delete vectors by ids or uuids. @@ -595,12 +610,12 @@ def delete( session.commit() async def adelete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: - """Delete vectors by ids or uuids. + """Async delete vectors by ids or uuids. Args: ids: List of ids to delete. @@ -637,19 +652,19 @@ async def aget_collection(self, session: AsyncSession) -> Any: @classmethod def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid4()) for _ in texts] @@ -675,19 +690,19 @@ def __from( @classmethod async def __afrom( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - connection: Optional[str] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool 
= True, + **kwargs: Any, ) -> PGVector: if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -702,7 +717,7 @@ async def __afrom( distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, use_jsonb=use_jsonb, - async_mode=True, # FIXME + _async_mode=True, **kwargs, ) # Second phase to create @@ -715,12 +730,12 @@ async def __afrom( return store def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Add embeddings to the vectorstore. @@ -728,6 +743,8 @@ def add_embeddings( texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ if ids is None: @@ -768,19 +785,21 @@ def add_embeddings( return ids async def aadd_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: - """Add embeddings to the vectorstore. + """Async add embeddings to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ if ids is None: @@ -821,17 +840,19 @@ async def aadd_embeddings( return ids def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters Returns: @@ -843,17 +864,19 @@ def add_texts( ) async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. 
kwargs: vectorstore specific parameters Returns: @@ -865,11 +888,11 @@ async def aadd_texts( ) def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -889,11 +912,11 @@ def similarity_search( ) async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -913,10 +936,10 @@ async def asimilarity_search( ) def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -935,10 +958,10 @@ def similarity_search_with_score( return docs async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -971,26 +994,25 @@ def distance_strategy(self) -> Any: ) def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) async def asimilarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: async with self._session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( - session=session, - embedding=embedding, k=k, - filter=filter) + session=session, embedding=embedding, k=k, filter=filter + ) return self._results_to_docs_and_scores(results) @@ -1009,9 +1031,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa return docs def _handle_field_filter( - self, - field: str, - value: Any, + self, + field: str, + value: Any, ) -> SQLColumnExpression: """Create a filter for a specific field. @@ -1121,8 +1143,7 @@ def _handle_field_filter( else: raise NotImplementedError() - def _create_filter_clause_deprecated(self, key, - value): # type: ignore[no-untyped-def] + def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] """Deprecated functionality. This is for backwards compatibility with the JSON based schema for metadata. @@ -1191,7 +1212,7 @@ def _create_filter_clause_deprecated(self, key, return filter_by_metadata def _create_filter_clause_json_deprecated( - self, filter: Any + self, filter: Any ) -> List[SQLColumnExpression]: """Convert filters from IR to SQL clauses. 
@@ -1302,11 +1323,11 @@ def _create_filter_clause(self, filters: Any) -> Any: ) def __query_collection( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> Sequence[Any]: """Query the collection.""" with self._session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) @@ -1344,52 +1365,55 @@ def __query_collection( return results async def __aquery_collection( - self, - session: AsyncSession, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: + self, + session: AsyncSession, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> Sequence[Any]: """Query the collection.""" - collection = await self.aget_collection(session) - if not collection: - raise ValueError("Collection not found") - - filter_by = [self.EmbeddingStore.collection_id == collection.uuid] - if filter: - if self.use_jsonb: - filter_clauses = self._create_filter_clause(filter) - if filter_clauses is not None: - filter_by.append(filter_clauses) - else: - # Old way of doing things - filter_clauses = self._create_filter_clause_json_deprecated(filter) - filter_by.extend(filter_clauses) + async with self._session_maker() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") - _type = self.EmbeddingStore + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) - stmt = (select( - self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore - ) + _type = self.EmbeddingStore + + stmt = ( + select( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) .join( - self.CollectionStore, - self.EmbeddingStore.collection_id == self.CollectionStore.uuid, - ) - .limit(k)) + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + ) - results: List[Any] = (await session.execute(stmt)).all() + results: Sequence[Any] = (await session.execute(stmt)).all() - return results + return results def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -1407,11 +1431,11 @@ def similarity_search_by_vector( return _results_to_docs(docs_and_scores) async def asimilarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. 
@@ -1430,17 +1454,17 @@ async def asimilarity_search_by_vector( @classmethod def from_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1460,17 +1484,17 @@ def from_texts( @classmethod async def afrom_texts( - cls: Type[PGVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return VectorStore initialized from documents and embeddings.""" embeddings = embedding.embed_documents(list(texts)) @@ -1489,16 +1513,16 @@ async def afrom_texts( @classmethod def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - *, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + *, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and embeddings. @@ -1509,6 +1533,7 @@ def from_embeddings( collection_name: Name of the collection. distance_strategy: Distance strategy to use. ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. pre_delete_collection: If True, will delete the collection if it exists. **Attention**: This will delete all the documents in the existing collection. 
@@ -1545,15 +1570,15 @@ def from_embeddings( @classmethod async def afrom_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, ) -> PGVector: """Construct PGVector wrapper from raw documents and pre- generated embeddings. @@ -1590,14 +1615,14 @@ async def afrom_embeddings( @classmethod def from_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will @@ -1617,27 +1642,26 @@ def from_existing_index( @classmethod async def afrom_existing_index( - cls: Type[PGVector], - embedding: Embeddings, - *, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - connection: Optional[Connection] = None, - **kwargs: Any, + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[Connection] = None, + **kwargs: Any, ) -> PGVector: """ Get instance of an existing PGVector store.This method will return the instance of the store without inserting any new embeddings """ - store = cls( # FIXME: créate + store = await PGVector.acreate( connection=connection, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, - async_mode=True, **kwargs, ) @@ -1662,17 +1686,17 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: @classmethod def from_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - *, - connection: Optional[Connection] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + *, + connection: Optional[Connection] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """Return 
VectorStore initialized from documents and embeddings.""" @@ -1694,16 +1718,16 @@ def from_documents( @classmethod async def afrom_documents( - cls: Type[PGVector], - documents: List[Document], - embedding: Embeddings, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - *, - use_jsonb: bool = True, - **kwargs: Any, + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, ) -> PGVector: """ Return VectorStore initialized from documents and embeddings. @@ -1732,13 +1756,13 @@ async def afrom_documents( @classmethod def connection_string_from_db_params( - cls, - driver: str, - host: str, - port: int, - database: str, - user: str, - password: str, + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, ) -> str: """Return connection string from database parameters.""" if driver != "psycopg": @@ -1773,13 +1797,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: ) def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1818,13 +1842,13 @@ def max_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score to embedding vector. @@ -1848,9 +1872,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ async with self._session_maker() as session: - results = await self.__aquery_collection(session=session, - embedding=embedding, k=fetch_k, - filter=filter) + results = await self.__aquery_collection( + session=session, embedding=embedding, k=fetch_k, filter=filter + ) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1866,13 +1890,13 @@ async def amax_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. 
@@ -1904,13 +1928,13 @@ def max_marginal_relevance_search( ) async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1942,13 +1966,13 @@ async def amax_marginal_relevance_search( ) def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -1982,13 +2006,13 @@ def max_marginal_relevance_search_with_score( return docs async def amax_marginal_relevance_search_with_score( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score. @@ -2022,13 +2046,13 @@ async def amax_marginal_relevance_search_with_score( return docs def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2062,13 +2086,13 @@ def max_marginal_relevance_search_by_vector( return _results_to_docs(docs_and_scores) async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance to embedding vector. @@ -2090,13 +2114,15 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. 
""" - docs_and_scores = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, + docs_and_scores = ( + await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) ) return _results_to_docs(docs_and_scores) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 42da5909..e0f05e18 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1,16 +1,15 @@ """Test PGVector functionality.""" import contextlib -from typing import Any, Dict, Generator, List +from typing import Any, AsyncGenerator, Dict, Generator, List import pytest from langchain_core.documents import Document +from sqlalchemy import select from langchain_postgres.vectorstores import ( SUPPORTED_OPERATORS, PGVector, ) -from sqlalchemy import select - from tests.unit_tests.fake_embeddings import FakeEmbeddings from tests.unit_tests.fixtures.filtering_test_cases import ( DOCUMENTS, @@ -53,7 +52,7 @@ def test_pgvector() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] -# @pytest.mark.requires("xxx") # FIXME + @pytest.mark.asyncio async def test_async_pgvector() -> None: """Test end to end construction and search.""" @@ -197,7 +196,9 @@ async def test_async_pgvector_with_filter_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "0"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "0"} + ) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] @@ -218,6 +219,7 @@ def test_pgvector_with_filter_distant_match() -> None: (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] + @pytest.mark.asyncio async def test_async_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" @@ -231,7 +233,9 @@ async def test_async_pgvector_with_filter_distant_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "2"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "2"} + ) assert output == [ (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] @@ -252,6 +256,7 @@ def test_pgvector_with_filter_no_match() -> None: output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] + @pytest.mark.asyncio async def test_async_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" @@ -265,7 +270,9 @@ async def test_async_pgvector_with_filter_no_match() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.asimilarity_search_with_score("foo", k=1, filter={"page": "5"}) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "5"} + ) assert output == [] @@ -286,10 +293,11 @@ def test_pgvector_collection_with_metadata() -> None: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} + @pytest.mark.asyncio async def test_async_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" - pgvector = await PGVector.create( 
+ pgvector = await PGVector.acreate( collection_name="test_collection", collection_metadata={"foo": "bar"}, embeddings=FakeEmbeddingsWithAdaDimension(), @@ -305,8 +313,6 @@ async def test_async_pgvector_collection_with_metadata() -> None: assert collection.cmetadata == {"foo": "bar"} - - def test_pgvector_delete_docs() -> None: """Add and delete documents.""" texts = ["foo", "bar", "baz"] @@ -335,6 +341,22 @@ def test_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == [] # type: ignore +def test_pgvector_delete_collection() -> None: + """Add and delete documents.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + vectorstore = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + vectorstore.delete(collection_only=True) + + @pytest.mark.asyncio async def test_async_pgvector_delete_docs() -> None: """Add and delete documents.""" @@ -351,14 +373,18 @@ async def test_async_pgvector_delete_docs() -> None: ) await vectorstore.adelete(["1", "2"]) async with vectorstore._session_maker() as session: - records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids async with vectorstore._session_maker() as session: - records=(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [] # type: ignore @@ -438,6 +464,7 @@ def test_pgvector_index_documents() -> None: "topic": "zoo", } + @pytest.mark.asyncio async def test_async_pgvector_index_documents() -> None: """Test adding duplicate documents results in overwrites.""" @@ -473,7 +500,9 @@ async def test_async_pgvector_index_documents() -> None: pre_delete_collection=True, ) async with vectorstore._session_maker() as session: - records = (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == [ @@ -492,10 +521,14 @@ async def test_async_pgvector_index_documents() -> None: ), ] - await vectorstore.aadd_documents(documents, ids=[doc.metadata["id"] for doc in documents]) + await vectorstore.aadd_documents( + documents, ids=[doc.metadata["id"] for doc in documents] + ) async with vectorstore._session_maker() as session: - records = (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) ordered_records = sorted(records, key=lambda x: x.id) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -629,7 +662,9 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: @pytest.mark.asyncio -async def 
test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: +async def test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> ( + None +): """Test searching with threshold and custom normalization function""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] @@ -705,7 +740,9 @@ async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None connection=CONNECTION_STRING, pre_delete_collection=True, ) - output = await docsearch.amax_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) + output = await docsearch.amax_marginal_relevance_search_with_score( + "foo", k=1, fetch_k=3 + ) assert output == [(Document(page_content="foo"), 0.0)] @@ -769,9 +806,10 @@ def pgvector() -> Generator[PGVector, None, None]: with get_vectorstore() as vector_store: yield vector_store + @pytest.mark.asyncio @pytest.fixture -async def async_pgvector() -> Generator[PGVector, None, None]: +async def async_pgvector() -> AsyncGenerator[PGVector, None]: """Create an async PGVector instance.""" store = await PGVector.afrom_documents( documents=DOCUMENTS, @@ -808,7 +846,7 @@ def get_vectorstore() -> Generator[PGVector, None, None]: @contextlib.asynccontextmanager -async def aget_vectorstore() -> Generator[PGVector, None, None]: +async def aget_vectorstore() -> AsyncGenerator[PGVector, None]: """Get a pre-populated-vectorstore""" store = await PGVector.afrom_documents( documents=DOCUMENTS, @@ -835,6 +873,7 @@ def test_pgvector_with_with_metadata_filters_1( docs = pgvector.similarity_search("meow", k=5, filter=test_filter) assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter + @pytest.mark.asyncio @pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) async def test_async_pgvector_with_with_metadata_filters_1( @@ -906,7 +945,7 @@ def test_pgvector_with_with_metadata_filters_4( @pytest.mark.asyncio @pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) -async def test_pgvector_with_with_metadata_filters_4( +async def test_async_pgvector_with_with_metadata_filters_4( async_pgvector: PGVector, test_filter: Dict[str, Any], expected_ids: List[int], From e46d5e276fd11705ccd8a5c085b687905a9d5345 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Tue, 16 Apr 2024 17:19:36 +0200 Subject: [PATCH 13/18] It's possible to share the session_maker with sync mode. 
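
The session factory is now exposed as `session_maker` (a `scoped_session` around a
`sessionmaker` in sync mode, an `async_sessionmaker` in async mode), so callers can
reuse the store's own sessions for ad-hoc queries. A minimal sketch of the intended
sync usage, assuming a reachable Postgres with the `vector` extension; `ToyEmbeddings`,
the DSN and the collection name below are placeholders, not part of the library:

```python
from typing import List

from langchain_core.embeddings import Embeddings
from sqlalchemy import select

from langchain_postgres.vectorstores import PGVector


class ToyEmbeddings(Embeddings):
    """Placeholder embeddings for this sketch; use a real embedding model in practice."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 0.0, 1.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 0.0, 1.0]


store = PGVector(
    embeddings=ToyEmbeddings(),
    connection="postgresql+psycopg://user:pass@localhost:5432/db",  # placeholder DSN
    collection_name="my_docs",
)

store.add_texts(["hello", "world"])

# Reuse the store's session factory for a custom query against the embedding table.
with store.session_maker() as session:
    rows = session.execute(select(store.EmbeddingStore)).scalars().all()
    print(f"{len(rows)} rows in langchain_pg_embedding")
```

The async side is symmetrical: build the store with `async_mode=True` (or pass an
`AsyncEngine`) and use `async with store.session_maker() as session: ...`.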
--- langchain_postgres/vectorstores.py | 187 ++++++++++++--------------- tests/unit_tests/test_vectorstore.py | 23 ++-- 2 files changed, 97 insertions(+), 113 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index ffce9ff6..c32739b8 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -29,6 +29,7 @@ from langchain_core.vectorstores import VectorStore from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.orm import scoped_session from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -279,7 +280,7 @@ def __init__( self, embeddings: Embeddings, *, - connection: Union[None, Connection, str] = None, + connection: Union[None, Connection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, @@ -290,7 +291,7 @@ def __init__( engine_args: Optional[dict[str, Any]] = None, use_jsonb: bool = True, create_extension: bool = True, - _async_mode: bool = False, # Tag to force the async mode + async_mode: bool = False, ) -> None: """Initialize the PGVector store. For an async version, use `PGVector.acreate()` instead. @@ -320,7 +321,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. """ - self._async_mode = _async_mode + self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -331,9 +332,10 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self._engine: Optional[Engine] = None self._async_engine: Optional[AsyncEngine] = None + self._async_init=False if isinstance(connection, str): - if _async_mode: + if async_mode: self._async_engine = create_async_engine( connection, **(engine_args or {}) ) @@ -342,7 +344,7 @@ def __init__( elif isinstance(connection, Engine): self._engine = connection elif isinstance(connection, AsyncEngine): - self._async_mode = True + self.async_mode = True self._async_engine = connection else: raise ValueError( @@ -350,10 +352,10 @@ def __init__( "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) self._session_maker: Union[sessionmaker, async_sessionmaker] - if self._async_mode: - self._session_maker = async_sessionmaker(bind=self._async_engine) + if self.async_mode: + self.session_maker = async_sessionmaker(bind=self._async_engine) else: - self._session_maker = sessionmaker(bind=self._engine) + self.session_maker = scoped_session(sessionmaker(bind=self._engine)) self.use_jsonb = use_jsonb self.create_extension = create_extension @@ -361,15 +363,8 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. 
raise NotImplementedError("use_jsonb=False is no longer supported.") - if not _async_mode: + if not async_mode: self.__post_init__() - else: - import inspect - - assert inspect.stack()[1].function in [ - "acreate", - "__afrom", - ], "Call await PGVector.acreate() instead of PGVector(...))" def __post_init__( self, @@ -390,6 +385,10 @@ async def __apost_init__( self, ) -> None: """Async initialize the store.""" + if self._async_init: + return + self._async_init = True + if self.create_extension: await self.acreate_vector_extension() @@ -401,75 +400,14 @@ async def __apost_init__( await self.acreate_tables_if_not_exists() await self.acreate_collection() - @classmethod - async def acreate( - cls, - embeddings: Embeddings, - *, - connection: Optional[Connection] = None, - embedding_length: Optional[int] = None, - collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - engine_args: Optional[dict[str, Any]] = None, - use_jsonb: bool = True, - create_extension: bool = True, - ) -> PGVector: - """Async create instance - - Args: - connection: Postgres connection string. - embeddings: Any embedding function implementing - `langchain.embeddings.base.Embeddings` interface. - embedding_length: The length of the embedding vector. (default: None) - NOTE: This is not mandatory. Defining it will prevent vectors of - any other size to be added to the embeddings table but, without it, - the embeddings can't be indexed. - collection_name: The name of the collection to use. (default: langchain) - NOTE: This is not the name of the table, but the name of the collection. - The tables will be created when initializing the store (if not exists) - So, make sure the user has the right permissions to create tables. - distance_strategy: The distance strategy to use. (default: COSINE) - pre_delete_collection: If True, will delete the collection if it exists. - (default: False). Useful for testing. - engine_args: SQLAlchemy's create engine arguments. - use_jsonb: Use JSONB instead of JSON for metadata. (default: True) - Strongly discouraged from using JSON as it's not as efficient - for querying. - It's provided here for backwards compatibility with older versions, - and will be removed in the future. - create_extension: If True, will create the vector extension if it - doesn't exist. disabling creation is useful when using ReadOnly - Databases. - """ - self = cls( - embeddings=embeddings, - connection=connection, - embedding_length=embedding_length, - collection_name=collection_name, - collection_metadata=collection_metadata, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - logger=logger, - relevance_score_fn=relevance_score_fn, - engine_args=engine_args, - use_jsonb=use_jsonb, - create_extension=create_extension, - _async_mode=True, - ) - await self.__apost_init__() - return self - @property def embeddings(self) -> Embeddings: return self.embedding_function def create_vector_extension(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" try: - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] # The advisor lock fixes issue arising from concurrent # creation of the vector extension. 
# https://github.com/langchain-ai/langchain/issues/12933 @@ -487,8 +425,11 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: + assert self.async_mode,"This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + try: - async with self._session_maker() as session: + async with self.session_maker() as session: # The advisor lock fixes issue arising from concurrent # creation of the vector extension. # https://github.com/langchain-ai/langchain/issues/12933 @@ -504,34 +445,40 @@ async def acreate_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: Base.metadata.create_all(session.get_bind()) async def acreate_tables_if_not_exists(self) -> None: - assert self._async_engine, "Use with async mode" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: Base.metadata.drop_all(session.get_bind()) async def adrop_tables(self) -> None: - assert self._async_engine, "Use with async mode" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" if self.pre_delete_collection: self.delete_collection() - with self._session_maker() as session: + with self.session_maker() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) async def acreate_collection(self) -> None: - assert self._async_engine, "Use with async mode" - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) await self.CollectionStore.aget_or_create( @@ -555,8 +502,9 @@ async def _adelete_collection(self, session: AsyncSession) -> None: await session.delete(collection) def delete_collection(self) -> None: + assert not self._async_engine, "This method must be called without async_mode" self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -565,8 +513,9 @@ def delete_collection(self) -> None: session.commit() async def adelete_collection(self) -> None: - self.logger.debug("Trying to delete collection") - async with self._session_maker() as session: # type: ignore[arg-type] + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with 
self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -586,7 +535,8 @@ def delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - with self._session_maker() as session: + assert not self._async_engine, "This method must be called without async_mode" + with self.session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -621,7 +571,9 @@ async def adelete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -645,9 +597,12 @@ async def adelete( await session.commit() def get_collection(self, session: Session) -> Any: + assert not self._async_engine, "This method must be called without async_mode" return self.CollectionStore.get_by_name(session, self.collection_name) async def aget_collection(self, session: AsyncSession) -> Any: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init return await self.CollectionStore.aget_by_name(session, self.collection_name) @classmethod @@ -717,11 +672,9 @@ async def __afrom( distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, use_jsonb=use_jsonb, - _async_mode=True, + async_mode=True, **kwargs, ) - # Second phase to create - await store.__apost_init__() await store.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -747,13 +700,14 @@ def add_embeddings( If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ + assert not self._async_engine, "This method must be called without async_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -802,13 +756,15 @@ async def aadd_embeddings( If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - async with self._session_maker() as session: # type: ignore[arg-type] + async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -858,6 +814,7 @@ def add_texts( Returns: List of ids from adding the texts into the vectorstore. """ + assert not self._async_engine, "This method must be called without async_mode" embeddings = self.embedding_function.embed_documents(list(texts)) return self.add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -882,6 +839,8 @@ async def aadd_texts( Returns: List of ids from adding the texts into the vectorstore. 
""" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -904,6 +863,7 @@ def similarity_search( Returns: List of Documents most similar to the query. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(text=query) return self.similarity_search_by_vector( embedding=embedding, @@ -928,6 +888,8 @@ async def asimilarity_search( Returns: List of Documents most similar to the query. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( embedding=embedding, @@ -951,6 +913,7 @@ def similarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -973,6 +936,8 @@ async def asimilarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -999,6 +964,7 @@ def similarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -1009,7 +975,9 @@ async def asimilarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - async with self._session_maker() as session: # type: ignore[arg-type] + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter ) @@ -1329,7 +1297,7 @@ def __query_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - with self._session_maker() as session: # type: ignore[arg-type] + with self.session_maker() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -1372,7 +1340,7 @@ async def __aquery_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - async with self._session_maker() as session: # type: ignore[arg-type] + async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -1425,6 +1393,7 @@ def similarity_search_by_vector( Returns: List of Documents most similar to the query vector. 
""" + assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1447,6 +1416,8 @@ async def asimilarity_search_by_vector( Returns: List of Documents most similar to the query vector. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1656,7 +1627,7 @@ async def afrom_existing_index( return the instance of the store without inserting any new embeddings """ - store = await PGVector.acreate( + store = await PGVector( connection=connection, collection_name=collection_name, embeddings=embedding, @@ -1826,6 +1797,7 @@ def max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1871,7 +1843,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - async with self._session_maker() as session: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self.session_maker() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter ) @@ -1917,6 +1891,7 @@ def max_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, @@ -1955,6 +1930,8 @@ async def amax_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( embedding, @@ -1994,6 +1971,7 @@ def max_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2034,6 +2012,8 @@ async def amax_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
""" + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2074,6 +2054,7 @@ def max_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, @@ -2114,6 +2095,8 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( embedding, diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index e0f05e18..3b959710 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -285,7 +285,7 @@ def test_pgvector_collection_with_metadata() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with pgvector._session_maker() as session: + with pgvector.session_maker() as session: collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -297,14 +297,15 @@ def test_pgvector_collection_with_metadata() -> None: @pytest.mark.asyncio async def test_async_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" - pgvector = await PGVector.acreate( + pgvector = PGVector( collection_name="test_collection", collection_metadata={"foo": "bar"}, embeddings=FakeEmbeddingsWithAdaDimension(), connection=CONNECTION_STRING, pre_delete_collection=True, + async_mode=True, ) - async with pgvector._session_maker() as session: + async with pgvector.session_maker() as session: collection = await pgvector.aget_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -327,14 +328,14 @@ def test_pgvector_delete_docs() -> None: pre_delete_collection=True, ) vectorstore.delete(["1", "2"]) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # type: ignore vectorstore.delete(["2", "3"]) # Should not raise on missing ids - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -372,7 +373,7 @@ async def test_async_pgvector_delete_docs() -> None: pre_delete_collection=True, ) await vectorstore.adelete(["1", "2"]) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -381,7 +382,7 @@ async def test_async_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == ["3"] # type: ignore await vectorstore.adelete(["2", "3"]) # Should not raise on missing ids - async 
with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -423,7 +424,7 @@ def test_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable @@ -445,7 +446,7 @@ def test_pgvector_index_documents() -> None: vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents]) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) ordered_records = sorted(records, key=lambda x: x.id) # ignoring type error since mypy cannot determine whether @@ -499,7 +500,7 @@ async def test_async_pgvector_index_documents() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) @@ -525,7 +526,7 @@ async def test_async_pgvector_index_documents() -> None: documents, ids=[doc.metadata["id"] for doc in documents] ) - async with vectorstore._session_maker() as session: + async with vectorstore.session_maker() as session: records = ( (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() ) From dc5c637f4d40a53319c6a0adff2f95011077d15f Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 6 May 2024 15:54:40 +0200 Subject: [PATCH 14/18] reformat --- langchain_postgres/vectorstores.py | 145 +++++++++++++++-------------- 1 file changed, 74 insertions(+), 71 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index c32739b8..8f8f7a35 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines from __future__ import annotations import enum @@ -15,22 +16,19 @@ Type, Union, ) +from typing import ( + cast as typing_cast, +) import numpy as np import sqlalchemy - -# try: -# from sqlalchemy.orm import declarative_base -# except ImportError: -# from sqlalchemy.ext.declarative import declarative_base from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.orm import scoped_session -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Connection, Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -41,6 +39,7 @@ Session, declarative_base, relationship, + scoped_session, sessionmaker, ) @@ -59,8 +58,10 @@ class DistanceStrategy(str, enum.Enum): Base = declarative_base() # type: Any + _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + _classes: Any = None COMPARISONS_TO_NATIVE = { @@ -122,14 +123,24 @@ class CollectionStore(Base): def get_by_name( cls, session: Session, name: str ) -> Optional["CollectionStore"]: - return session.query(cls).filter(cls.name == name).first() # type: ignore + return 
( + session.query(cls) + .filter(typing_cast(sqlalchemy.Column, cls.name) == name) + .first() + ) @classmethod async def aget_by_name( cls, session: AsyncSession, name: str ) -> Optional["CollectionStore"]: return ( - (await session.execute(select(CollectionStore).where(cls.name == name))) + ( + await session.execute( + select(CollectionStore).where( + typing_cast(sqlalchemy.Column, cls.name) == name + ) + ) + ) .scalars() .first() ) @@ -219,7 +230,16 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: return [doc for doc, _ in docs_and_scores] -Connection = Union[sqlalchemy.engine.Engine, str] +def _create_vector_extension(conn: Connection) -> None: + statement = sqlalchemy.text( + "SELECT pg_advisory_xact_lock(1573678846307946496);" + "CREATE EXTENSION IF NOT EXISTS vector;" + ) + conn.execute(statement) + conn.commit() + + +DBConnection = Union[sqlalchemy.engine.Engine, str] class PGVector(VectorStore): @@ -257,6 +277,7 @@ class PGVector(VectorStore): connection=connection_string, collection_name=collection_name, use_jsonb=True, + async_mode=False, ) @@ -274,13 +295,16 @@ class PGVector(VectorStore): You will need to recreate the tables if you are using an existing database. * A Connection object has to be provided explicitly. Connections will not be picked up automatically based on env variables. + * langchain_postgres now accept async connections. If you want to use the async + version, you need to set `async_mode=True` when initializing the store or + use an async engine. """ def __init__( self, embeddings: Embeddings, *, - connection: Union[None, Connection, Engine, AsyncEngine, str] = None, + connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, @@ -297,7 +321,7 @@ def __init__( For an async version, use `PGVector.acreate()` instead. Args: - connection: Postgres connection string. + connection: Postgres connection string or (async)engine. embeddings: Any embedding function implementing `langchain.embeddings.base.Embeddings` interface. embedding_length: The length of the embedding vector. (default: None) @@ -332,7 +356,7 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self._engine: Optional[Engine] = None self._async_engine: Optional[AsyncEngine] = None - self._async_init=False + self._async_init = False if isinstance(connection, str): if async_mode: @@ -342,6 +366,7 @@ def __init__( else: self._engine = create_engine(url=connection, **(engine_args or {})) elif isinstance(connection, Engine): + self.async_mode = False self._engine = connection elif isinstance(connection, AsyncEngine): self.async_mode = True @@ -351,7 +376,7 @@ def __init__( "connection should be a connection string or an instance of " "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) - self._session_maker: Union[sessionmaker, async_sessionmaker] + self.session_maker: Union[scoped_session, async_sessionmaker] if self.async_mode: self.session_maker = async_sessionmaker(bind=self._async_engine) else: @@ -363,7 +388,7 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. 
raise NotImplementedError("use_jsonb=False is no longer supported.") - if not async_mode: + if not self.async_mode: self.__post_init__() def __post_init__( @@ -384,19 +409,19 @@ def __post_init__( async def __apost_init__( self, ) -> None: - """Async initialize the store.""" - if self._async_init: + """Async initialize the store (use lazy approach).""" + if self._async_init: # Warning: possible race condition return self._async_init = True - if self.create_extension: - await self.acreate_vector_extension() - EmbeddingStore, CollectionStore = _get_embedding_collection_store( self._embedding_length ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore + if self.create_extension: + await self.acreate_vector_extension() + await self.acreate_tables_if_not_exists() await self.acreate_collection() @@ -407,51 +432,25 @@ def embeddings(self) -> Embeddings: def create_vector_extension(self) -> None: assert not self._async_engine, "This method must be called without async_mode" try: - with self.session_maker() as session: # type: ignore[arg-type] - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. - # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - statement = sqlalchemy.text( - "BEGIN;" - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" - "COMMIT;" - ) - session.execute(statement) - session.commit() + with self._engine.connect() as conn: + _create_vector_extension(conn) except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: - assert self.async_mode,"This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + assert self.async_mode, "This method must be called with async_mode" - try: - async with self.session_maker() as session: - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. 
- # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - await session.execute( - sqlalchemy.text("SELECT pg_advisory_xact_lock(1573678846307946496)") - ) - await session.execute( - sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") - ) - except Exception as e: - raise Exception(f"Failed to create vector extension: {e}") from e + async with self._async_engine.begin() as conn: + await conn.run_sync(_create_vector_extension) def create_tables_if_not_exists(self) -> None: assert not self._async_engine, "This method must be called without async_mode" with self.session_maker() as session: Base.metadata.create_all(session.get_bind()) + session.commit() async def acreate_tables_if_not_exists(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -459,10 +458,11 @@ def drop_tables(self) -> None: assert not self._async_engine, "This method must be called without async_mode" with self.session_maker() as session: Base.metadata.drop_all(session.get_bind()) + session.commit() async def adrop_tables(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self._async_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @@ -474,16 +474,18 @@ def create_collection(self) -> None: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + session.commit() async def acreate_collection(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: if self.pre_delete_collection: await self._adelete_collection(session) await self.CollectionStore.aget_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + await session.commit() def _delete_collection(self, session: Session) -> None: self.logger.debug("Trying to delete collection") @@ -514,7 +516,7 @@ def delete_collection(self) -> None: async def adelete_collection(self) -> None: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: @@ -572,7 +574,7 @@ async def adelete( collection_only: Only delete ids in the collection. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: if ids is not None: self.logger.debug( @@ -602,7 +604,7 @@ def get_collection(self, session: Session) -> Any: async def aget_collection(self, session: AsyncSession) -> Any: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init return await self.CollectionStore.aget_by_name(session, self.collection_name) @classmethod @@ -757,7 +759,7 @@ async def aadd_embeddings( kwargs: vectorstore specific parameters """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -840,7 +842,7 @@ async def aadd_texts( List of ids from adding the texts into the vectorstore. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -889,7 +891,7 @@ async def asimilarity_search( List of Documents most similar to the query. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( embedding=embedding, @@ -937,7 +939,7 @@ async def asimilarity_search_with_score( List of Documents most similar to the query and score for each. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter @@ -976,7 +978,7 @@ async def asimilarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter @@ -1417,7 +1419,7 @@ async def asimilarity_search_by_vector( List of Documents most similar to the query vector. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init docs_and_scores = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) @@ -1592,7 +1594,7 @@ def from_existing_index( collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, **kwargs: Any, ) -> PGVector: """ @@ -1619,7 +1621,7 @@ async def afrom_existing_index( collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, **kwargs: Any, ) -> PGVector: """ @@ -1627,12 +1629,13 @@ async def afrom_existing_index( return the instance of the store without inserting any new embeddings """ - store = await PGVector( + store = PGVector( connection=connection, collection_name=collection_name, embeddings=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, + async_mode=True, **kwargs, ) @@ -1661,7 +1664,7 @@ def from_documents( documents: List[Document], embedding: Embeddings, *, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1844,7 +1847,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init async with self.session_maker() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter @@ -1931,7 +1934,7 @@ async def amax_marginal_relevance_search( List[Document]: List of Documents selected by maximal marginal relevance. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( embedding, @@ -2013,7 +2016,7 @@ async def amax_marginal_relevance_search_with_score( relevance to the query and score for each. """ assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2096,7 +2099,7 @@ async def amax_marginal_relevance_search_by_vector( List[Document]: List of Documents selected by maximal marginal relevance. 
""" assert self._async_engine, "This method must be called with async_mode" - await self.__apost_init__() # Lazy async init + await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( embedding, From 045acfb0cb6ed3632ac606171077b1904740d11e Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Thu, 16 May 2024 12:07:27 +0200 Subject: [PATCH 15/18] Fix lint --- langchain_postgres/vectorstores.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 8f8f7a35..f00107c4 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -431,6 +431,7 @@ def embeddings(self) -> Embeddings: def create_vector_extension(self) -> None: assert not self._async_engine, "This method must be called without async_mode" + assert self._engine, "engine not found" try: with self._engine.connect() as conn: _create_vector_extension(conn) @@ -439,6 +440,7 @@ def create_vector_extension(self) -> None: async def acreate_vector_extension(self) -> None: assert self.async_mode, "This method must be called with async_mode" + assert self._async_engine, "_async_engine not found" async with self._async_engine.begin() as conn: await conn.run_sync(_create_vector_extension) From 3957a793613517a5404b012c44a8d66c03f74046 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Tue, 21 May 2024 11:05:37 +0200 Subject: [PATCH 16/18] Rebase --- tests/unit_tests/test_vectorstore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 3b959710..5a4f8772 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1005,6 +1005,7 @@ def test_validate_operators() -> None: "$and", "$between", "$eq", + "$exists", "$gt", "$gte", "$ilike", From dd7d7cd200a355261101bcad105b7da6386952ab Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 27 May 2024 08:29:04 +0200 Subject: [PATCH 17/18] Align the code with SQLChatMessageHistory --- langchain_postgres/vectorstores.py | 86 ++++++++++++++-------------- tests/unit_tests/test_vectorstore.py | 1 - 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index f00107c4..659b6f32 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,13 +1,16 @@ # pylint: disable=too-many-lines from __future__ import annotations +import contextlib import enum import logging import uuid from typing import ( Any, + AsyncGenerator, Callable, Dict, + Generator, Iterable, List, Optional, @@ -430,7 +433,6 @@ def embeddings(self) -> Embeddings: return self.embedding_function def create_vector_extension(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" assert self._engine, "engine not found" try: with self._engine.connect() as conn: @@ -439,15 +441,13 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: - assert self.async_mode, "This method must be called with async_mode" assert self._async_engine, "_async_engine not found" async with self._async_engine.begin() as conn: await conn.run_sync(_create_vector_extension) def create_tables_if_not_exists(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() 
as session: + with self._make_sync_session() as session: Base.metadata.create_all(session.get_bind()) session.commit() @@ -457,8 +457,7 @@ async def acreate_tables_if_not_exists(self) -> None: await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._make_sync_session() as session: Base.metadata.drop_all(session.get_bind()) session.commit() @@ -469,19 +468,17 @@ async def adrop_tables(self) -> None: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" if self.pre_delete_collection: self.delete_collection() - with self.session_maker() as session: + with self._make_sync_session() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) session.commit() async def acreate_collection(self) -> None: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: if self.pre_delete_collection: await self._adelete_collection(session) await self.CollectionStore.aget_or_create( @@ -490,7 +487,6 @@ async def acreate_collection(self) -> None: await session.commit() def _delete_collection(self, session: Session) -> None: - self.logger.debug("Trying to delete collection") collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -498,7 +494,6 @@ def _delete_collection(self, session: Session) -> None: session.delete(collection) async def _adelete_collection(self, session: AsyncSession) -> None: - self.logger.debug("Trying to delete collection") collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -506,9 +501,7 @@ async def _adelete_collection(self, session: AsyncSession) -> None: await session.delete(collection) def delete_collection(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - self.logger.debug("Trying to delete collection") - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -517,14 +510,13 @@ def delete_collection(self) -> None: session.commit() async def adelete_collection(self) -> None: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") return - await session.adelete(collection) + await session.delete(collection) await session.commit() def delete( @@ -539,8 +531,7 @@ def delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._make_sync_session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -575,9 +566,8 @@ async def adelete( ids: List of ids to delete. 
collection_only: Only delete ids in the collection. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -704,14 +694,14 @@ def add_embeddings( If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called with sync_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -760,7 +750,6 @@ async def aadd_embeddings( If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -768,7 +757,7 @@ async def aadd_embeddings( if not metadatas: metadatas = [{} for _ in texts] - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -843,7 +832,6 @@ async def aadd_texts( Returns: List of ids from adding the texts into the vectorstore. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( @@ -892,7 +880,6 @@ async def asimilarity_search( Returns: List of Documents most similar to the query. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( @@ -940,7 +927,6 @@ async def asimilarity_search_with_score( Returns: List of Documents most similar to the query and score for each. 
""" - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( @@ -979,9 +965,8 @@ async def asimilarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter ) @@ -1301,7 +1286,7 @@ def __query_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -1322,7 +1307,7 @@ def __query_collection( results: List[Any] = ( session.query( self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore + self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) @@ -1344,7 +1329,7 @@ async def __aquery_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -1365,7 +1350,7 @@ async def __aquery_collection( stmt = ( select( self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore + self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) @@ -1848,9 +1833,8 @@ async def amax_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter ) @@ -1896,7 +1880,6 @@ def max_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, @@ -1935,7 +1918,6 @@ async def amax_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( @@ -1976,7 +1958,6 @@ def max_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
""" - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2017,7 +1998,6 @@ async def amax_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( @@ -2059,7 +2039,6 @@ def max_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, @@ -2100,7 +2079,6 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( @@ -2114,3 +2092,25 @@ async def amax_marginal_relevance_search_by_vector( ) return _results_to_docs(docs_and_scores) + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[Session, None, None]: + """Make an async session.""" + if self.async_mode: + raise ValueError( + "Attempting to use a sync method in when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with self.session_maker() as session: + yield typing_cast(Session, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method in when sync mode is turned on. " + "Please use the corresponding async method instead." + ) + async with self.session_maker() as session: + yield typing_cast(AsyncSession, session) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 5a4f8772..5ad81169 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -18,7 +18,6 @@ TYPE_3_FILTERING_TEST_CASES, TYPE_4_FILTERING_TEST_CASES, TYPE_5_FILTERING_TEST_CASES, - TYPE_6_FILTERING_TEST_CASES, ) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING From e8bf4d177c0aad437895de2674eb6139664785d9 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 10 Jun 2024 16:47:52 +0200 Subject: [PATCH 18/18] Fix test --- tests/unit_tests/test_vectorstore.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 5ad81169..fcba8efc 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -990,6 +990,8 @@ async def test_async_pgvector_with_with_metadata_filters_5( {"$and": {}}, {"$between": {}}, {"$eq": {}}, + {"$exists": {}}, + {"$exists": 1}, ], ) def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None: