From 9be659b69a918a3856ab1e9fdf051566ee90c1c3 Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Mon, 7 Apr 2025 18:58:34 -0700 Subject: [PATCH 1/5] fix: remove string filters and parameterize filters --- langchain_postgres/v2/async_vectorstore.py | 183 ++++++++++++------ langchain_postgres/v2/vectorstores.py | 28 +-- .../v2/test_async_pg_vectorstore_search.py | 86 ++++---- .../v2/test_pg_vectorstore_search.py | 88 ++++----- 4 files changed, 218 insertions(+), 167 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index bae31ba2..e70e06a1 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -3,7 +3,6 @@ import copy import json -import re import uuid from typing import Any, Callable, Iterable, Optional, Sequence @@ -112,7 +111,9 @@ def __init__( self.schema_name = schema_name self.content_column = content_column self.embedding_column = embedding_column - self.metadata_columns = metadata_columns if metadata_columns is not None else [] + self.metadata_columns = ( + metadata_columns if metadata_columns is not None else [] + ) self.id_column = id_column self.metadata_json_column = metadata_json_column self.distance_strategy = distance_strategy @@ -175,7 +176,8 @@ async def create( stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name" async with engine._pool.connect() as conn: result = await conn.execute( - text(stmt), {"table_name": table_name, "schema_name": schema_name} + text(stmt), + {"table_name": table_name, "schema_name": schema_name}, ) result_map = result.mappings() results = result_map.fetchall() @@ -187,21 +189,27 @@ async def create( if id_column not in columns: raise ValueError(f"Id column, {id_column}, does not exist.") if content_column not in columns: - raise ValueError(f"Content column, {content_column}, does not exist.") + raise ValueError( + f"Content column, {content_column}, does not exist." + ) content_type = columns[content_column] if content_type != "text" and "char" not in content_type: raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) if embedding_column not in columns: - raise ValueError(f"Embedding column, {embedding_column}, does not exist.") + raise ValueError( + f"Embedding column, {embedding_column}, does not exist." + ) if columns[embedding_column] != "USER-DEFINED": raise ValueError( f"Embedding column, {embedding_column}, is not type Vector." 
) metadata_json_column = ( - None if metadata_json_column not in columns else metadata_json_column + None + if metadata_json_column not in columns + else metadata_json_column ) # If using metadata_columns check to make sure column exists @@ -264,10 +272,14 @@ async def aadd_embeddings( metadatas = [{} for _ in texts] # Check for inline embedding capability - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) + inline_embed_func = getattr( + self.embedding_service, "embed_query_inline", None + ) can_inline_embed = callable(inline_embed_func) # Insert embeddings - for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): + for id, content, embedding, metadata in zip( + ids, texts, embeddings, metadatas + ): metadata_col_names = ( ", " + ", ".join(f'"{col}"' for col in self.metadata_columns) if len(self.metadata_columns) > 0 @@ -336,11 +348,15 @@ async def aadd_texts( :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. """ # Check for inline embedding query - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) + inline_embed_func = getattr( + self.embedding_service, "embed_query_inline", None + ) if callable(inline_embed_func): embeddings: list[list[float]] = [[] for _ in list(texts)] else: - embeddings = await self.embedding_service.aembed_documents(list(texts)) + embeddings = await self.embedding_service.aembed_documents( + list(texts) + ) ids = await self.aadd_embeddings( texts, embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -362,7 +378,9 @@ async def aadd_documents( metadatas = [doc.metadata for doc in documents] if not ids: ids = [doc.id for doc in documents] - ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + ids = await self.aadd_texts( + texts, metadatas=metadatas, ids=ids, **kwargs + ) return ids async def adelete( @@ -535,7 +553,7 @@ async def __query_collection( embedding: list[float], *, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> Sequence[RowMapping]: """Perform similarity search query on database.""" @@ -553,16 +571,24 @@ async def __query_collection( column_names = ", ".join(f'"{col}"' for col in columns) + safe_filter = None + filter_dict = None if filter and isinstance(filter, dict): - filter = self._create_filter_clause(filter) - filter = f"WHERE {filter}" if filter else "" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) + safe_filter, filter_dict = self._create_filter_clause(filter) + filter = f"WHERE {safe_filter}" if safe_filter else "" + inline_embed_func = getattr( + self.embedding_service, "embed_query_inline", None + ) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore else: query_embedding = f"{[float(dimension) for dimension in embedding]}" - stmt = f'SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;' + stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance + FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k; + """ param_dict = {"query_embedding": query_embedding, "k": k} + if 
filter_dict: + param_dict.update(filter_dict) if self.index_query_options: async with self.engine.connect() as conn: # Set each query option individually @@ -583,11 +609,13 @@ async def asimilarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) + inline_embed_func = getattr( + self.embedding_service, "embed_query_inline", None + ) embedding = ( [] if callable(inline_embed_func) @@ -614,11 +642,13 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) + inline_embed_func = getattr( + self.embedding_service, "embed_query_inline", None + ) embedding = ( [] if callable(inline_embed_func) @@ -635,7 +665,7 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -649,7 +679,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" @@ -685,7 +715,7 @@ async def amax_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -706,7 +736,7 @@ async def amax_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -729,7 +759,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" @@ -740,7 +770,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( k = k if k else self.k fetch_k = fetch_k if fetch_k else self.fetch_k lambda_mult = lambda_mult if lambda_mult else self.lambda_mult - embedding_list = [json.loads(row[self.embedding_column]) for row in results] + embedding_list = [ + json.loads(row[self.embedding_column]) for row in results + ] mmr_selected = utils.maximal_marginal_relevance( np.array(embedding, dtype=np.float32), embedding_list, @@ -768,7 +800,9 @@ async def amax_marginal_relevance_search_with_score_by_vector( ) ) - return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] + return [ + r for i, r 
in enumerate(documents_with_scores) if i in mmr_selected + ] async def aapply_vector_index( self, @@ -786,12 +820,16 @@ async def aapply_vector_index( if index.extension_name: async with self.engine.connect() as conn: await conn.execute( - text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}") + text( + f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}" + ) ) await conn.commit() function = index.get_index_function() - filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" + filter = ( + f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" + ) params = "WITH " + index.index_options() if name is None: if index.name == None: @@ -834,7 +872,7 @@ async def is_valid_index( ) -> bool: """Check if index exists in the table.""" index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f""" + query = """ SELECT tablename, indexname FROM pg_indexes WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name; @@ -898,7 +936,7 @@ def _handle_field_filter( *, field: str, value: Any, - ) -> str: + ) -> tuple[str, dict]: """Create a filter for a specific field. Args: @@ -951,15 +989,21 @@ def _handle_field_filter( if operator in COMPARISONS_TO_NATIVE: # Then we implement an equality filter # native is trusted input - if isinstance(filter_value, str): - filter_value = f"'{filter_value}'" + # if isinstance(filter_value, str): + # filter_value = f"'{filter_value}'" native = COMPARISONS_TO_NATIVE[operator] - return f"({field} {native} {filter_value})" + id = str(uuid.uuid4()).split("-")[0] + return f"{field} {native} :{field}_{id}", { + f"{field}_{id}": filter_value + } elif operator == "$between": # Use AND with two comparisons low, high = filter_value - return f"({field} BETWEEN {low} AND {high})" + return f"({field} BETWEEN :{field}_low AND :{field}_high)", { + f"{field}_low": low, + f"{field}_high": high, + } elif operator in {"$in", "$nin", "$like", "$ilike"}: # We'll do force coercion to text if operator in {"$in", "$nin"}: @@ -975,15 +1019,21 @@ def _handle_field_filter( ) if operator in {"$in"}: - values = str(tuple(val for val in filter_value)) - return f"({field} IN {values})" + return f"{field} = ANY(:{field}_in)", { + f"{field}_in": filter_value + } elif operator in {"$nin"}: - values = str(tuple(val for val in filter_value)) - return f"({field} NOT IN {values})" + return f"{field} <> ALL (:{field}_nin)", { + f"{field}_nin": filter_value + } elif operator in {"$like"}: - return f"({field} LIKE '{filter_value}')" + return f"({field} LIKE :{field}_like)", { + f"{field}_like": filter_value + } elif operator in {"$ilike"}: - return f"({field} ILIKE '{filter_value}')" + return f"({field} ILIKE :{field}_ilike)", { + f"{field}_ilike": filter_value + } else: raise NotImplementedError() elif operator == "$exists": @@ -994,13 +1044,13 @@ def _handle_field_filter( ) else: if filter_value: - return f"({field} IS NOT NULL)" + return f"({field} IS NOT NULL)", {} else: - return f"({field} IS NULL)" + return f"({field} IS NULL)", {} else: raise NotImplementedError() - def _create_filter_clause(self, filters: Any) -> str: + def _create_filter_clause(self, filters: Any) -> tuple[str, dict]: """Create LangChain filter representation to matching SQL where clauses Args: @@ -1037,7 +1087,11 @@ def _create_filter_clause(self, filters: Any) -> str: op = key[1:].upper() # Extract the operator filter_clause = [self._create_filter_clause(el) for el in value] if len(filter_clause) > 1: - return f"({f' {op} 
'.join(filter_clause)})" + all_clauses = [clause[0] for clause in filter_clause] + params = {} + for clause in filter_clause: + params.update(clause[1]) + return f"({f' {op} '.join(all_clauses)})", params elif len(filter_clause) == 1: return filter_clause[0] else: @@ -1050,11 +1104,17 @@ def _create_filter_clause(self, filters: Any) -> str: not_conditions = [ self._create_filter_clause(item) for item in value ] - not_stmts = [f"NOT {condition}" for condition in not_conditions] - return f"({' AND '.join(not_stmts)})" + all_clauses = [clause[0] for clause in not_conditions] + params = {} + for clause in not_conditions: + params.update(clause[1]) + not_stmts = [ + f"NOT {condition}" for condition in all_clauses + ] + return f"({' AND '.join(not_stmts)})", params elif isinstance(value, dict): - not_ = self._create_filter_clause(value) - return f"(NOT {not_})" + not_, params = self._create_filter_clause(value) + return f"(NOT {not_})", params else: raise ValueError( f"Invalid filter condition. Expected a dictionary " @@ -1074,10 +1134,15 @@ def _create_filter_clause(self, filters: Any) -> str: ) # These should all be fields and combined using an $and operator and_ = [ - self._handle_field_filter(field=k, value=v) for k, v in filters.items() + self._handle_field_filter(field=k, value=v) + for k, v in filters.items() ] if len(and_) > 1: - return f"({' AND '.join(and_)})" + all_clauses = [clause[0] for clause in and_] + params = {} + for clause in and_: + params.update(clause[1]) + return f"({' AND '.join(all_clauses)})", params elif len(and_) == 1: return and_[0] else: @@ -1086,7 +1151,7 @@ def _create_filter_clause(self, filters: Any) -> str: "but got an empty dictionary" ) else: - return "" + return "", {} def get_by_ids(self, ids: Sequence[str]) -> list[Document]: raise NotImplementedError( @@ -1168,7 +1233,7 @@ def similarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -1179,7 +1244,7 @@ def similarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( @@ -1190,7 +1255,7 @@ def similarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -1201,7 +1266,7 @@ def similarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( @@ -1214,7 +1279,7 @@ def max_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -1227,7 +1292,7 @@ def max_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -1240,7 +1305,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, 
fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( diff --git a/langchain_postgres/v2/vectorstores.py b/langchain_postgres/v2/vectorstores.py index 7f71d108..1dc1be97 100644 --- a/langchain_postgres/v2/vectorstores.py +++ b/langchain_postgres/v2/vectorstores.py @@ -567,7 +567,7 @@ def similarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -579,7 +579,7 @@ async def asimilarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -602,7 +602,7 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -614,7 +614,7 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -626,7 +626,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" @@ -642,7 +642,7 @@ async def amax_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -658,7 +658,7 @@ async def amax_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -674,7 +674,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" @@ -688,7 +688,7 @@ def similarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -700,7 +700,7 @@ def similarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector 
similarity search.""" @@ -712,7 +712,7 @@ def similarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on vector.""" @@ -728,7 +728,7 @@ def max_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -744,7 +744,7 @@ def max_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -760,7 +760,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index dc3a771b..0f629e9b 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -33,13 +33,15 @@ ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] docs = [ - Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) + Document(page_content=texts[i], metadata=metadatas[i]) + for i in range(len(texts)) ] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] filter_docs = [ - Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) + Document(page_content=texts[i], metadata=METADATAS[i]) + for i in range(len(texts)) ] @@ -68,7 +70,7 @@ async def engine(self) -> AsyncIterator[PGEngine]: yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") - await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") + # await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -85,7 +87,9 @@ async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: yield vs @pytest_asyncio.fixture(scope="class") - async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: + async def vs_custom( + self, engine: PGEngine + ) -> AsyncIterator[AsyncPGVectorStore]: await engine._ainit_vectorstore_table( CUSTOM_TABLE, VECTOR_SIZE, @@ -149,25 +153,24 @@ async def vs_custom_filter( await vs_custom_filter.aadd_documents(filter_docs, ids=ids) yield vs_custom_filter - async def test_asimilarity_search(self, vs: AsyncPGVectorStore) -> None: - results = await vs.asimilarity_search("foo", k=1) - assert len(results) == 1 - assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") - assert results == 
[Document(page_content="bar", id=ids[1])] - - async def test_asimilarity_search_score(self, vs: AsyncPGVectorStore) -> None: + async def test_asimilarity_search_score( + self, vs: AsyncPGVectorStore + ) -> None: results = await vs.asimilarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_asimilarity_search_by_vector(self, vs: AsyncPGVectorStore) -> None: + async def test_asimilarity_search_by_vector( + self, vs: AsyncPGVectorStore + ) -> None: embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = await vs.asimilarity_search_with_score_by_vector(embedding=embedding) + result = await vs.asimilarity_search_with_score_by_vector( + embedding=embedding + ) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 @@ -220,14 +223,6 @@ async def test_similarity_search_with_relevance_scores_threshold_euclidean( assert len(results) == 1 assert results[0][0] == Document(page_content="foo", id=ids[0]) - async def test_amax_marginal_relevance_search(self, vs: AsyncPGVectorStore) -> None: - results = await vs.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar", id=ids[1]) - results = await vs.amax_marginal_relevance_search( - "bar", filter="content = 'boo'" - ) - assert results[0] == Document(page_content="boo", id=ids[3]) - async def test_amax_marginal_relevance_search_vector( self, vs: AsyncPGVectorStore ) -> None: @@ -249,16 +244,9 @@ async def test_amax_marginal_relevance_search_vector_score( ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - async def test_similarity_search(self, vs_custom: AsyncPGVectorStore) -> None: - results = await vs_custom.asimilarity_search("foo", k=1) - assert len(results) == 1 - assert results == [Document(page_content="foo", id=ids[0])] - results = await vs_custom.asimilarity_search( - "foo", k=1, filter="mycontent = 'bar'" - ) - assert results == [Document(page_content="bar", id=ids[1])] - - async def test_similarity_search_score(self, vs_custom: AsyncPGVectorStore) -> None: + async def test_similarity_search_score( + self, vs_custom: AsyncPGVectorStore + ) -> None: results = await vs_custom.asimilarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) @@ -277,34 +265,30 @@ async def test_similarity_search_by_vector( assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 - async def test_max_marginal_relevance_search( - self, vs_custom: AsyncPGVectorStore - ) -> None: - results = await vs_custom.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar", id=ids[1]) - results = await vs_custom.amax_marginal_relevance_search( - "bar", filter="mycontent = 'boo'" - ) - assert results[0] == Document(page_content="boo", id=ids[3]) - async def test_max_marginal_relevance_search_vector( self, vs_custom: AsyncPGVectorStore ) -> None: embedding = embeddings_service.embed_query("bar") - results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding) + results = await vs_custom.amax_marginal_relevance_search_by_vector( + embedding + ) assert results[0] == Document(page_content="bar", id=ids[1]) async def test_max_marginal_relevance_search_vector_score( self, vs_custom: AsyncPGVectorStore ) -> None: 
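# Editor's sketch (illustrative, not part of the patch): with raw SQL string
# filters removed, the deleted string-filter assertions are expressed with
# structured dicts instead. This reuses the fixtures defined earlier in this
# module (vs_custom, ids) and assumes "$eq" is one of the operators in
# COMPARISONS_TO_NATIVE; a bare value, e.g. {"mycontent": "boo"}, is assumed
# to be equivalent shorthand for equality.
async def example_mmr_with_dict_filter(vs_custom: AsyncPGVectorStore) -> None:
    # Same intent as the removed filter="mycontent = 'boo'", but the value is
    # bound as a query parameter instead of being spliced into the SQL text.
    results = await vs_custom.amax_marginal_relevance_search(
        "bar", filter={"mycontent": {"$eq": "boo"}}
    )
    assert results[0] == Document(page_content="boo", id=ids[3])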
embedding = embeddings_service.embed_query("bar") - results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding + results = ( + await vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding + ) ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding, lambda_mult=0.75, fetch_k=10 + results = ( + await vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ) ) assert results[0][0] == Document(page_content="bar", id=ids[1]) @@ -314,7 +298,9 @@ async def test_aget_by_ids(self, vs: AsyncPGVectorStore) -> None: assert results[0] == Document(page_content="foo", id=ids[0]) - async def test_aget_by_ids_custom_vs(self, vs_custom: AsyncPGVectorStore) -> None: + async def test_aget_by_ids_custom_vs( + self, vs_custom: AsyncPGVectorStore + ) -> None: test_ids = [ids[0]] results = await vs_custom.aget_by_ids(ids=test_ids) @@ -336,4 +322,6 @@ async def test_vectorstore_with_metadata_filters( docs = await vs_custom_filter.asimilarity_search( "meow", k=5, filter=test_filter ) - assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + assert [ + doc.metadata["code"] for doc in docs + ] == expected_ids, test_filter diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 28d0f07f..d114885e 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -21,7 +21,9 @@ DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_FILTER_TABLE_SYNC = "custom_filter_sync" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE_SYNC = "custom_filter_sync" + str(uuid.uuid4()).replace( + "-", "_" +) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -34,10 +36,12 @@ ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] docs = [ - Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) + Document(page_content=texts[i], metadata=metadatas[i]) + for i in range(len(texts)) ] filter_docs = [ - Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) + Document(page_content=texts[i], metadata=METADATAS[i]) + for i in range(len(texts)) ] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -95,7 +99,9 @@ async def engine_sync(self) -> AsyncIterator[PGEngine]: await engine.close() @pytest_asyncio.fixture(scope="class") - async def vs_custom(self, engine_sync: PGEngine) -> AsyncIterator[PGVectorStore]: + async def vs_custom( + self, engine_sync: PGEngine + ) -> AsyncIterator[PGVectorStore]: engine_sync.init_vectorstore_table( CUSTOM_TABLE, VECTOR_SIZE, @@ -122,7 +128,9 @@ async def vs_custom(self, engine_sync: PGEngine) -> AsyncIterator[PGVectorStore] yield vs_custom @pytest_asyncio.fixture(scope="class") - async def vs_custom_filter(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]: + async def vs_custom_filter( + self, engine: PGEngine + ) -> AsyncIterator[PGVectorStore]: await engine.ainit_vectorstore_table( CUSTOM_FILTER_TABLE, VECTOR_SIZE, @@ -158,25 +166,22 @@ async def vs_custom_filter(self, engine: PGEngine) -> 
AsyncIterator[PGVectorStor await vs_custom_filter.aadd_documents(filter_docs, ids=ids) yield vs_custom_filter - async def test_asimilarity_search(self, vs: PGVectorStore) -> None: - results = await vs.asimilarity_search("foo", k=1) - assert len(results) == 1 - assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") - assert results == [Document(page_content="bar", id=ids[1])] - async def test_asimilarity_search_score(self, vs: PGVectorStore) -> None: results = await vs.asimilarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_asimilarity_search_by_vector(self, vs: PGVectorStore) -> None: + async def test_asimilarity_search_by_vector( + self, vs: PGVectorStore + ) -> None: embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = await vs.asimilarity_search_with_score_by_vector(embedding=embedding) + result = await vs.asimilarity_search_with_score_by_vector( + embedding=embedding + ) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 @@ -222,14 +227,6 @@ async def test_similarity_search_with_relevance_scores_threshold_euclidean( assert len(results) == 1 assert results[0][0] == Document(page_content="foo", id=ids[0]) - async def test_amax_marginal_relevance_search(self, vs: PGVectorStore) -> None: - results = await vs.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar", id=ids[1]) - results = await vs.amax_marginal_relevance_search( - "bar", filter="content = 'boo'" - ) - assert results[0] == Document(page_content="boo", id=ids[3]) - async def test_amax_marginal_relevance_search_vector( self, vs: PGVectorStore ) -> None: @@ -257,7 +254,9 @@ async def test_aget_by_ids(self, vs: PGVectorStore) -> None: assert results[0] == Document(page_content="foo", id=ids[0]) - async def test_aget_by_ids_custom_vs(self, vs_custom: PGVectorStore) -> None: + async def test_aget_by_ids_custom_vs( + self, vs_custom: PGVectorStore + ) -> None: test_ids = [ids[0]] results = await vs_custom.aget_by_ids(ids=test_ids) @@ -274,7 +273,9 @@ async def test_vectorstore_with_metadata_filters( docs = await vs_custom_filter.asimilarity_search( "meow", k=5, filter=test_filter ) - assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + assert [ + doc.metadata["code"] for doc in docs + ] == expected_ids, test_filter @pytest.mark.enable_socket @@ -284,11 +285,15 @@ async def engine_sync(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") - await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}") + await aexecute( + engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}" + ) await engine.close() @pytest_asyncio.fixture(scope="class") - async def vs_custom(self, engine_sync: PGEngine) -> AsyncIterator[PGVectorStore]: + async def vs_custom( + self, engine_sync: PGEngine + ) -> AsyncIterator[PGVectorStore]: engine_sync.init_vectorstore_table( DEFAULT_TABLE_SYNC, VECTOR_SIZE, @@ -354,36 +359,25 @@ async def vs_custom_filter_sync( vs_custom_filter_sync.add_documents(filter_docs, ids=ids) yield vs_custom_filter_sync - def test_similarity_search(self, vs_custom: 
PGVectorStore) -> None: - results = vs_custom.similarity_search("foo", k=1) - assert len(results) == 1 - assert results == [Document(page_content="foo", id=ids[0])] - results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'") - assert results == [Document(page_content="bar", id=ids[1])] - def test_similarity_search_score(self, vs_custom: PGVectorStore) -> None: results = vs_custom.similarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - def test_similarity_search_by_vector(self, vs_custom: PGVectorStore) -> None: + def test_similarity_search_by_vector( + self, vs_custom: PGVectorStore + ) -> None: embedding = embeddings_service.embed_query("foo") results = vs_custom.similarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = vs_custom.similarity_search_with_score_by_vector(embedding=embedding) + result = vs_custom.similarity_search_with_score_by_vector( + embedding=embedding + ) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 - def test_max_marginal_relevance_search(self, vs_custom: PGVectorStore) -> None: - results = vs_custom.max_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar", id=ids[1]) - results = vs_custom.max_marginal_relevance_search( - "bar", filter="mycontent = 'boo'" - ) - assert results[0] == Document(page_content="boo", id=ids[3]) - def test_max_marginal_relevance_search_vector( self, vs_custom: PGVectorStore ) -> None: @@ -420,8 +414,12 @@ def test_sync_vectorstore_with_metadata_filters( ) -> None: """Test end to end construction and search.""" - docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) - assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + docs = vs_custom_filter_sync.similarity_search( + "meow", k=5, filter=test_filter + ) + assert [ + doc.metadata["code"] for doc in docs + ] == expected_ids, test_filter @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES) def test_metadata_filter_negative_tests( From 637d80edc0efe3569800de40a16a8c4c9fd939a8 Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Mon, 7 Apr 2025 19:03:47 -0700 Subject: [PATCH 2/5] lint --- langchain_postgres/v2/async_vectorstore.py | 87 +++++-------------- .../v2/test_async_pg_vectorstore_search.py | 50 +++-------- .../v2/test_pg_vectorstore_search.py | 58 ++++--------- 3 files changed, 51 insertions(+), 144 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index e70e06a1..1fdd9ea8 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -111,9 +111,7 @@ def __init__( self.schema_name = schema_name self.content_column = content_column self.embedding_column = embedding_column - self.metadata_columns = ( - metadata_columns if metadata_columns is not None else [] - ) + self.metadata_columns = metadata_columns if metadata_columns is not None else [] self.id_column = id_column self.metadata_json_column = metadata_json_column self.distance_strategy = distance_strategy @@ -189,27 +187,21 @@ async def create( if id_column not in columns: raise ValueError(f"Id column, {id_column}, does not exist.") if content_column not in columns: - raise ValueError( - f"Content column, {content_column}, does not exist." 
- ) + raise ValueError(f"Content column, {content_column}, does not exist.") content_type = columns[content_column] if content_type != "text" and "char" not in content_type: raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) if embedding_column not in columns: - raise ValueError( - f"Embedding column, {embedding_column}, does not exist." - ) + raise ValueError(f"Embedding column, {embedding_column}, does not exist.") if columns[embedding_column] != "USER-DEFINED": raise ValueError( f"Embedding column, {embedding_column}, is not type Vector." ) metadata_json_column = ( - None - if metadata_json_column not in columns - else metadata_json_column + None if metadata_json_column not in columns else metadata_json_column ) # If using metadata_columns check to make sure column exists @@ -272,14 +264,10 @@ async def aadd_embeddings( metadatas = [{} for _ in texts] # Check for inline embedding capability - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) can_inline_embed = callable(inline_embed_func) # Insert embeddings - for id, content, embedding, metadata in zip( - ids, texts, embeddings, metadatas - ): + for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): metadata_col_names = ( ", " + ", ".join(f'"{col}"' for col in self.metadata_columns) if len(self.metadata_columns) > 0 @@ -348,15 +336,11 @@ async def aadd_texts( :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. """ # Check for inline embedding query - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if callable(inline_embed_func): embeddings: list[list[float]] = [[] for _ in list(texts)] else: - embeddings = await self.embedding_service.aembed_documents( - list(texts) - ) + embeddings = await self.embedding_service.aembed_documents(list(texts)) ids = await self.aadd_embeddings( texts, embeddings, metadatas=metadatas, ids=ids, **kwargs @@ -378,9 +362,7 @@ async def aadd_documents( metadatas = [doc.metadata for doc in documents] if not ids: ids = [doc.id for doc in documents] - ids = await self.aadd_texts( - texts, metadatas=metadatas, ids=ids, **kwargs - ) + ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) return ids async def adelete( @@ -576,9 +558,7 @@ async def __query_collection( if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) filter = f"WHERE {safe_filter}" if safe_filter else "" - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore else: @@ -613,9 +593,7 @@ async def asimilarity_search( **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) embedding = ( [] if callable(inline_embed_func) @@ -646,9 +624,7 @@ async def asimilarity_search_with_score( **kwargs: Any, ) -> list[tuple[Document, float]]: 
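# Editor's sketch (illustrative, not part of the patch): how the parameterized
# statement assembled in __query_collection is executed. A simplified
# reconstruction assuming pgvector's l2_distance()/<-> pair, a SQLAlchemy
# AsyncConnection `conn`, and the (clause, params) tuple returned by
# _create_filter_clause; the table and column names are placeholders.
from sqlalchemy import text

async def example_execute(conn, clause: str, params: dict) -> None:
    stmt = (
        'SELECT *, l2_distance("embedding", :query_embedding) AS distance '
        f'FROM "public"."docs" {"WHERE " + clause if clause else ""} '
        'ORDER BY "embedding" <-> :query_embedding LIMIT :k;'
    )
    # Every user-supplied value -- the embedding, k, and all filter values --
    # travels in the parameter dict and is never interpolated into the SQL.
    result = await conn.execute(
        text(stmt), {"query_embedding": "[1.0, 2.0, 3.0]", "k": 4, **params}
    )
    print(result.mappings().fetchall())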
"""Return docs and distance scores selected by similarity search on query.""" - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) embedding = ( [] if callable(inline_embed_func) @@ -770,9 +746,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k = k if k else self.k fetch_k = fetch_k if fetch_k else self.fetch_k lambda_mult = lambda_mult if lambda_mult else self.lambda_mult - embedding_list = [ - json.loads(row[self.embedding_column]) for row in results - ] + embedding_list = [json.loads(row[self.embedding_column]) for row in results] mmr_selected = utils.maximal_marginal_relevance( np.array(embedding, dtype=np.float32), embedding_list, @@ -800,9 +774,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( ) ) - return [ - r for i, r in enumerate(documents_with_scores) if i in mmr_selected - ] + return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] async def aapply_vector_index( self, @@ -820,16 +792,12 @@ async def aapply_vector_index( if index.extension_name: async with self.engine.connect() as conn: await conn.execute( - text( - f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}" - ) + text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}") ) await conn.commit() function = index.get_index_function() - filter = ( - f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" - ) + filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" params = "WITH " + index.index_options() if name is None: if index.name == None: @@ -993,9 +961,7 @@ def _handle_field_filter( # filter_value = f"'{filter_value}'" native = COMPARISONS_TO_NATIVE[operator] id = str(uuid.uuid4()).split("-")[0] - return f"{field} {native} :{field}_{id}", { - f"{field}_{id}": filter_value - } + return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value} elif operator == "$between": # Use AND with two comparisons low, high = filter_value @@ -1019,17 +985,11 @@ def _handle_field_filter( ) if operator in {"$in"}: - return f"{field} = ANY(:{field}_in)", { - f"{field}_in": filter_value - } + return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value} elif operator in {"$nin"}: - return f"{field} <> ALL (:{field}_nin)", { - f"{field}_nin": filter_value - } + return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value} elif operator in {"$like"}: - return f"({field} LIKE :{field}_like)", { - f"{field}_like": filter_value - } + return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value} elif operator in {"$ilike"}: return f"({field} ILIKE :{field}_ilike)", { f"{field}_ilike": filter_value @@ -1108,9 +1068,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]: params = {} for clause in not_conditions: params.update(clause[1]) - not_stmts = [ - f"NOT {condition}" for condition in all_clauses - ] + not_stmts = [f"NOT {condition}" for condition in all_clauses] return f"({' AND '.join(not_stmts)})", params elif isinstance(value, dict): not_, params = self._create_filter_clause(value) @@ -1134,8 +1092,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]: ) # These should all be fields and combined using an $and operator and_ = [ - self._handle_field_filter(field=k, value=v) - for k, v in filters.items() + self._handle_field_filter(field=k, value=v) for k, v in filters.items() ] if len(and_) > 1: all_clauses = [clause[0] for clause in and_] diff --git 
a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index 0f629e9b..4d705ff3 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -33,15 +33,13 @@ ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] docs = [ - Document(page_content=texts[i], metadata=metadatas[i]) - for i in range(len(texts)) + Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] filter_docs = [ - Document(page_content=texts[i], metadata=METADATAS[i]) - for i in range(len(texts)) + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] @@ -87,9 +85,7 @@ async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: yield vs @pytest_asyncio.fixture(scope="class") - async def vs_custom( - self, engine: PGEngine - ) -> AsyncIterator[AsyncPGVectorStore]: + async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: await engine._ainit_vectorstore_table( CUSTOM_TABLE, VECTOR_SIZE, @@ -153,24 +149,18 @@ async def vs_custom_filter( await vs_custom_filter.aadd_documents(filter_docs, ids=ids) yield vs_custom_filter - async def test_asimilarity_search_score( - self, vs: AsyncPGVectorStore - ) -> None: + async def test_asimilarity_search_score(self, vs: AsyncPGVectorStore) -> None: results = await vs.asimilarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_asimilarity_search_by_vector( - self, vs: AsyncPGVectorStore - ) -> None: + async def test_asimilarity_search_by_vector(self, vs: AsyncPGVectorStore) -> None: embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = await vs.asimilarity_search_with_score_by_vector( - embedding=embedding - ) + result = await vs.asimilarity_search_with_score_by_vector(embedding=embedding) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 @@ -244,9 +234,7 @@ async def test_amax_marginal_relevance_search_vector_score( ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - async def test_similarity_search_score( - self, vs_custom: AsyncPGVectorStore - ) -> None: + async def test_similarity_search_score(self, vs_custom: AsyncPGVectorStore) -> None: results = await vs_custom.asimilarity_search_with_score("foo") assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) @@ -269,26 +257,20 @@ async def test_max_marginal_relevance_search_vector( self, vs_custom: AsyncPGVectorStore ) -> None: embedding = embeddings_service.embed_query("bar") - results = await vs_custom.amax_marginal_relevance_search_by_vector( - embedding - ) + results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding) assert results[0] == Document(page_content="bar", id=ids[1]) async def test_max_marginal_relevance_search_vector_score( self, vs_custom: AsyncPGVectorStore ) -> None: embedding = embeddings_service.embed_query("bar") - results = ( - await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding - ) + results = await 
vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - results = ( - await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding, lambda_mult=0.75, fetch_k=10 - ) + results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 ) assert results[0][0] == Document(page_content="bar", id=ids[1]) @@ -298,9 +280,7 @@ async def test_aget_by_ids(self, vs: AsyncPGVectorStore) -> None: assert results[0] == Document(page_content="foo", id=ids[0]) - async def test_aget_by_ids_custom_vs( - self, vs_custom: AsyncPGVectorStore - ) -> None: + async def test_aget_by_ids_custom_vs(self, vs_custom: AsyncPGVectorStore) -> None: test_ids = [ids[0]] results = await vs_custom.aget_by_ids(ids=test_ids) @@ -322,6 +302,4 @@ async def test_vectorstore_with_metadata_filters( docs = await vs_custom_filter.asimilarity_search( "meow", k=5, filter=test_filter ) - assert [ - doc.metadata["code"] for doc in docs - ] == expected_ids, test_filter + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index d114885e..379f5295 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -21,9 +21,7 @@ DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_FILTER_TABLE_SYNC = "custom_filter_sync" + str(uuid.uuid4()).replace( - "-", "_" -) +CUSTOM_FILTER_TABLE_SYNC = "custom_filter_sync" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -36,12 +34,10 @@ ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] docs = [ - Document(page_content=texts[i], metadata=metadatas[i]) - for i in range(len(texts)) + Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] filter_docs = [ - Document(page_content=texts[i], metadata=METADATAS[i]) - for i in range(len(texts)) + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -99,9 +95,7 @@ async def engine_sync(self) -> AsyncIterator[PGEngine]: await engine.close() @pytest_asyncio.fixture(scope="class") - async def vs_custom( - self, engine_sync: PGEngine - ) -> AsyncIterator[PGVectorStore]: + async def vs_custom(self, engine_sync: PGEngine) -> AsyncIterator[PGVectorStore]: engine_sync.init_vectorstore_table( CUSTOM_TABLE, VECTOR_SIZE, @@ -128,9 +122,7 @@ async def vs_custom( yield vs_custom @pytest_asyncio.fixture(scope="class") - async def vs_custom_filter( - self, engine: PGEngine - ) -> AsyncIterator[PGVectorStore]: + async def vs_custom_filter(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]: await engine.ainit_vectorstore_table( CUSTOM_FILTER_TABLE, VECTOR_SIZE, @@ -172,16 +164,12 @@ async def test_asimilarity_search_score(self, vs: PGVectorStore) -> None: assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_asimilarity_search_by_vector( - self, vs: PGVectorStore - ) -> None: + async def 
test_asimilarity_search_by_vector(self, vs: PGVectorStore) -> None: embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = await vs.asimilarity_search_with_score_by_vector( - embedding=embedding - ) + result = await vs.asimilarity_search_with_score_by_vector(embedding=embedding) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 @@ -254,9 +242,7 @@ async def test_aget_by_ids(self, vs: PGVectorStore) -> None: assert results[0] == Document(page_content="foo", id=ids[0]) - async def test_aget_by_ids_custom_vs( - self, vs_custom: PGVectorStore - ) -> None: + async def test_aget_by_ids_custom_vs(self, vs_custom: PGVectorStore) -> None: test_ids = [ids[0]] results = await vs_custom.aget_by_ids(ids=test_ids) @@ -273,9 +259,7 @@ async def test_vectorstore_with_metadata_filters( docs = await vs_custom_filter.asimilarity_search( "meow", k=5, filter=test_filter ) - assert [ - doc.metadata["code"] for doc in docs - ] == expected_ids, test_filter + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter @pytest.mark.enable_socket @@ -285,15 +269,11 @@ async def engine_sync(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") - await aexecute( - engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}" - ) + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}") await engine.close() @pytest_asyncio.fixture(scope="class") - async def vs_custom( - self, engine_sync: PGEngine - ) -> AsyncIterator[PGVectorStore]: + async def vs_custom(self, engine_sync: PGEngine) -> AsyncIterator[PGVectorStore]: engine_sync.init_vectorstore_table( DEFAULT_TABLE_SYNC, VECTOR_SIZE, @@ -365,16 +345,12 @@ def test_similarity_search_score(self, vs_custom: PGVectorStore) -> None: assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - def test_similarity_search_by_vector( - self, vs_custom: PGVectorStore - ) -> None: + def test_similarity_search_by_vector(self, vs_custom: PGVectorStore) -> None: embedding = embeddings_service.embed_query("foo") results = vs_custom.similarity_search_by_vector(embedding) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - result = vs_custom.similarity_search_with_score_by_vector( - embedding=embedding - ) + result = vs_custom.similarity_search_with_score_by_vector(embedding=embedding) assert result[0][0] == Document(page_content="foo", id=ids[0]) assert result[0][1] == 0 @@ -414,12 +390,8 @@ def test_sync_vectorstore_with_metadata_filters( ) -> None: """Test end to end construction and search.""" - docs = vs_custom_filter_sync.similarity_search( - "meow", k=5, filter=test_filter - ) - assert [ - doc.metadata["code"] for doc in docs - ] == expected_ids, test_filter + docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES) def test_metadata_filter_negative_tests( From 2d6565b19a12e836348b776e2f8c5bc655f08bca Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Mon, 7 Apr 2025 19:07:48 -0700 Subject: [PATCH 3/5] lint --- langchain_postgres/v2/async_vectorstore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index 1fdd9ea8..e99e2455 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -557,14 +557,14 @@ async def __query_collection( filter_dict = None if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) - filter = f"WHERE {safe_filter}" if safe_filter else "" + param_filter = f"WHERE {safe_filter}" if safe_filter else "" inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore else: query_embedding = f"{[float(dimension) for dimension in embedding]}" stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance - FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k; + FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k; """ param_dict = {"query_embedding": query_embedding, "k": k} if filter_dict: From bbbe10eb64e4603b4b22c68647e6821f08a64aec Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Mon, 7 Apr 2025 19:20:24 -0700 Subject: [PATCH 4/5] Update test_async_pg_vectorstore_search.py --- tests/unit_tests/v2/test_async_pg_vectorstore_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index 4d705ff3..8e5e371c 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -68,7 +68,7 @@ async def engine(self) -> AsyncIterator[PGEngine]: yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") - # await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") From 4a6a48bf96836fed7a87eea4630aa079733f7f6e Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Mon, 7 Apr 2025 19:25:30 -0700 Subject: [PATCH 5/5] Update async_vectorstore.py --- langchain_postgres/v2/async_vectorstore.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index e99e2455..dde47f40 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -957,8 +957,6 @@ def _handle_field_filter( if operator in COMPARISONS_TO_NATIVE: # Then we implement an equality filter # native is trusted input - # if isinstance(filter_value, str): - # filter_value = f"'{filter_value}'" native = COMPARISONS_TO_NATIVE[operator] id = str(uuid.uuid4()).split("-")[0] return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
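# Editor's sketch (illustrative): the core technique this series lands,
# reduced to a self-contained, hedged example. The operator table below is an
# assumption -- the real COMPARISONS_TO_NATIVE mapping is not shown in the
# diff -- but the (clause, params) shape mirrors _handle_field_filter above.
import uuid

COMPARISONS_TO_NATIVE = {
    "$eq": "=", "$ne": "!=", "$lt": "<", "$lte": "<=", "$gt": ">", "$gte": ">=",
}

def handle_field_filter(field: str, value: object) -> tuple[str, dict]:
    """Compile one field filter into a SQL fragment plus bound parameters."""
    if not isinstance(value, dict):
        value = {"$eq": value}  # a bare value is shorthand for equality
    operator, filter_value = next(iter(value.items()))
    suffix = str(uuid.uuid4()).split("-")[0]  # unique placeholder name
    if operator in COMPARISONS_TO_NATIVE:
        native = COMPARISONS_TO_NATIVE[operator]
        return f"{field} {native} :{field}_{suffix}", {f"{field}_{suffix}": filter_value}
    if operator == "$in":
        return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
    raise NotImplementedError(operator)

# A hostile value can no longer escape into the SQL text: it stays in the
# parameter dict, where the driver treats it strictly as a value.
clause, params = handle_field_filter("page", "0'; DROP TABLE docs; --")
print(clause)   # page = :page_xxxxxxxx  (suffix varies per call)
print(params)   # {'page_xxxxxxxx': "0'; DROP TABLE docs; --"}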