diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index 329702f7..264fd346 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -14,6 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine from .engine import PGEngine +from .hybrid_search_config import HybridSearchConfig from .indexes import ( DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, @@ -77,6 +78,7 @@ def __init__( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ): """AsyncPGVectorStore constructor. Args: @@ -95,6 +97,7 @@ def __init__( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: @@ -119,6 +122,7 @@ def __init__( self.fetch_k = fetch_k self.lambda_mult = lambda_mult self.index_query_options = index_query_options + self.hybrid_search_config = hybrid_search_config @classmethod async def create( @@ -139,6 +143,7 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance. @@ -158,6 +163,7 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. 
+ hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: AsyncPGVectorStore @@ -193,6 +199,15 @@ async def create( raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) + if hybrid_search_config: + tsv_column_name = ( + hybrid_search_config.tsv_column + if hybrid_search_config.tsv_column + else content_column + "_tsv" + ) + if tsv_column_name not in columns or columns[tsv_column_name] != "tsvector": + # mark tsv_column as empty because there is no TSV column in table + hybrid_search_config.tsv_column = "" if embedding_column not in columns: raise ValueError(f"Embedding column, {embedding_column}, does not exist.") if columns[embedding_column] != "USER-DEFINED": @@ -236,6 +251,7 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) @property @@ -273,7 +289,12 @@ async def aadd_embeddings( if len(self.metadata_columns) > 0 else "" ) - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' + hybrid_search_column = ( + f', "{self.hybrid_search_config.tsv_column}"' + if self.hybrid_search_config and self.hybrid_search_config.tsv_column + else "" + ) + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}' values = { "id": id, "content": content, @@ -284,6 +305,14 @@ async def aadd_embeddings( if not embedding and can_inline_embed: values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore + if self.hybrid_search_config and self.hybrid_search_config.tsv_column: + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + values_stmt += f", 
to_tsvector({lang} :tsv_content)" + values["tsv_content"] = content # Add metadata extra = copy.deepcopy(metadata) for metadata_column in self.metadata_columns: @@ -308,6 +337,9 @@ async def aadd_embeddings( upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' + if self.hybrid_search_config and self.hybrid_search_config.tsv_column: + upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"' + if self.metadata_json_column: upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' @@ -408,6 +440,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from texts. @@ -453,6 +486,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) return vs @@ -478,6 +512,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from documents. 
@@ -524,6 +559,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -538,16 +574,30 @@ async def __query_collection( filter: Optional[dict] = None, **kwargs: Any, ) -> Sequence[RowMapping]: - """Perform similarity search query on database.""" - k = k if k else self.k + """ + Perform similarity search (or hybrid search) query on database. + Queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column + and adding GIN index. + """ + if not k: + k = ( + max( + self.k, + self.hybrid_search_config.primary_top_k, + self.hybrid_search_config.secondary_top_k, + ) + if self.hybrid_search_config + else self.k + ) operator = self.distance_strategy.operator search_function = self.distance_strategy.search_function - columns = self.metadata_columns + [ + columns = [ self.id_column, self.content_column, self.embedding_column, - ] + ] + self.metadata_columns if self.metadata_json_column: columns.append(self.metadata_json_column) @@ -557,7 +607,7 @@ async def __query_collection( filter_dict = None if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) - param_filter = f"WHERE {safe_filter}" if safe_filter else "" + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore @@ -565,8 +615,9 @@ async def __query_collection( else: query_embedding = f"{[float(dimension) for dimension in embedding]}" embedding_data_string = ":query_embedding" - stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance 
- FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; + where_filters = f"WHERE {safe_filter}" if safe_filter else "" + dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance + FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; """ param_dict = {"query_embedding": query_embedding, "k": k} if filter_dict: @@ -577,15 +628,51 @@ async def __query_collection( for query_option in self.index_query_options.to_parameter(): query_options_stmt = f"SET LOCAL {query_option};" await conn.execute(text(query_options_stmt)) - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) result_map = result.mappings() - results = result_map.fetchall() + dense_results = result_map.fetchall() else: async with self.engine.connect() as conn: - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) + result_map = result.mappings() + dense_results = result_map.fetchall() + + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + fts_query = ( + hybrid_search_config.fts_query + if hybrid_search_config and hybrid_search_config.fts_query + else kwargs.get("fts_query", "") + ) + if hybrid_search_config and fts_query: + hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k + # do the sparse query + lang = ( + f"'{hybrid_search_config.tsv_lang}'," + if hybrid_search_config.tsv_lang + else "" + ) + query_tsv = f"plainto_tsquery({lang} :fts_query)" + param_dict["fts_query"] = fts_query + if hybrid_search_config.tsv_column: + content_tsv = f'"{hybrid_search_config.tsv_column}"' + else: + content_tsv = f'to_tsvector({lang} "{self.content_column}")' + and_filters = f"AND 
async def aapply_hybrid_search_index(
    self,
    concurrently: bool = False,
) -> None:
    """Create the full-text-search index described by the hybrid search config.

    The index is built over the configured TSV column when one is set;
    otherwise it is an expression index over
    ``to_tsvector(<tsv_lang>, "<content_column>")``, which is the same
    expression the sparse query evaluates when no TSV column exists.

    Args:
        concurrently (bool): Build the index with ``CREATE INDEX CONCURRENTLY``.
            Defaults to False.

    Raises:
        ValueError: If no hybrid search config is set on the store, or the
            config has an empty ``index_type`` or ``index_name``.
    """
    config = self.hybrid_search_config
    if not config or not config.index_type or not config.index_name:
        # no index needs to be created
        raise ValueError("Hybrid Search Config cannot create index.")

    # Column identifiers are double-quoted, matching how the insert and
    # query statements reference them, so mixed-case or reserved-word
    # column names resolve to the column that was actually created.
    if config.tsv_column:
        indexed_expression = f'"{config.tsv_column}"'
    else:
        lang = f"'{config.tsv_lang}'," if config.tsv_lang else ""
        indexed_expression = f'to_tsvector({lang} "{self.content_column}")'

    concurrency_clause = "CONCURRENTLY" if concurrently else ""
    tsv_index_query = (
        f"CREATE INDEX {concurrency_clause} {config.index_name} "
        f'ON "{self.schema_name}"."{self.table_name}" '
        f"USING {config.index_type}({indexed_expression});"
    )

    async with self.engine.connect() as conn:
        if concurrently:
            # CREATE INDEX CONCURRENTLY cannot run inside a transaction
            # block, so the statement is issued on an autocommit connection.
            autocommit_conn = await conn.execution_options(
                isolation_level="AUTOCOMMIT"
            )
            await autocommit_conn.execute(text(tsv_index_query))
        else:
            await conn.execute(text(tsv_index_query))
            await conn.commit()
-156,6 +158,7 @@ async def _ainit_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -178,6 +181,8 @@ async def _ainit_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Default: None. Raises: :class:`DuplicateTableError `: if table already exists. @@ -186,6 +191,7 @@ async def _ainit_vectorstore_table( schema_name = self._escape_postgres_identifier(schema_name) table_name = self._escape_postgres_identifier(table_name) + hybrid_search_default_column_name = content_column + "_tsv" content_column = self._escape_postgres_identifier(content_column) embedding_column = self._escape_postgres_identifier(embedding_column) if metadata_columns is None: @@ -226,10 +232,22 @@ async def _ainit_vectorstore_table( id_data_type = id_column["data_type"] id_column_name = id_column["name"] + hybrid_search_column = "" # Default is no TSV column for hybrid search + if hybrid_search_config: + hybrid_search_column_name = ( + hybrid_search_config.tsv_column or hybrid_search_default_column_name + ) + hybrid_search_column_name = self._escape_postgres_identifier( + hybrid_search_column_name + ) + hybrid_search_config.tsv_column = hybrid_search_column_name + hybrid_search_column = f',"{self._escape_postgres_identifier(hybrid_search_column_name)}" TSVECTOR NOT NULL' + query = f"""CREATE TABLE "{schema_name}"."{table_name}"( "{id_column_name}" {id_data_type} PRIMARY KEY, "{content_column}" TEXT NOT NULL, - "{embedding_column}" vector({vector_size}) NOT NULL""" + "{embedding_column}" vector({vector_size}) NOT NULL + {hybrid_search_column}""" for column in 
metadata_columns: if isinstance(column, Column): nullable = "NOT NULL" if not column.nullable else "" @@ -258,6 +276,7 @@ async def ainit_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -280,6 +299,10 @@ async def ainit_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Note that queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column and adding GIN index. + Default: None. """ await self._run_as_async( self._ainit_vectorstore_table( @@ -293,6 +316,7 @@ async def ainit_vectorstore_table( id_column=id_column, overwrite_existing=overwrite_existing, store_metadata=store_metadata, + hybrid_search_config=hybrid_search_config, ) ) @@ -309,6 +333,7 @@ def init_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -331,6 +356,10 @@ def init_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Note that queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column and adding GIN index. + Default: None. 
from abc import ABC
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence

if TYPE_CHECKING:
    # RowMapping is referenced in annotations only, so it is imported for
    # type checkers and not at runtime.
    from sqlalchemy import RowMapping


def weighted_sum_ranking(
    primary_search_results: Sequence["RowMapping"],
    secondary_search_results: Sequence["RowMapping"],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """Merge two result sets by a weighted sum of their scores.

    Every row is a mapping whose first value is the document id and whose
    last value is the score. Rows are merged by id: a document present in
    both inputs scores ``primary_weight * primary + secondary_weight *
    secondary``; a document present in one input scores only that term.

    The combined score is sorted in descending order, so both inputs are
    treated as "higher is better". The raw scores are combined as-is, with
    no normalisation.

    Args:
        primary_search_results: Rows from the primary search.
        secondary_search_results: Rows from the secondary search.
        primary_results_weight: Weight applied to primary scores.
            Defaults to 0.5.
        secondary_results_weight: Weight applied to secondary scores.
            Defaults to 0.5.
        fetch_top_k: Maximum number of merged rows to return. Defaults to 4.

    Returns:
        Copies of the merged rows as dicts, with ``"distance"`` replaced by
        the weighted score, best first, truncated to ``fetch_top_k``.
    """
    # doc_id -> copy of the row with "distance" holding the combined score
    weighted_scores: dict[str, dict[str, Any]] = {}

    for row in primary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        merged = dict(row)
        merged["distance"] = primary_results_weight * float(values[-1])  # type: ignore
        weighted_scores[doc_id] = merged

    # Secondary rows add onto an existing primary score, or start a new entry.
    for row in secondary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        carried = (
            weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
        )
        merged = dict(row)
        merged["distance"] = float(values[-1]) * secondary_results_weight + carried  # type: ignore
        weighted_scores[doc_id] = merged

    ranked_results = sorted(
        weighted_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


def _accumulate_reciprocal_ranks(
    rrf_scores: dict[str, dict[str, Any]],
    search_results: Iterable["RowMapping"],
    rrf_k: float,
    lower_is_better: bool,
) -> None:
    """Add ``1 / (rank + rrf_k)`` to ``rrf_scores`` for each row, best row first."""
    ordered = sorted(
        search_results,
        key=lambda item: item["distance"],
        reverse=not lower_is_better,
    )
    for rank, row in enumerate(ordered):
        doc_id = str(list(row.values())[0])  # first value is doc_id
        previous = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
        fused = dict(row)
        fused["distance"] = previous + 1.0 / (rank + rrf_k)
        rrf_scores[doc_id] = fused


def reciprocal_rank_fusion(
    primary_search_results: Sequence["RowMapping"],
    secondary_search_results: Sequence["RowMapping"],
    rrf_k: float = 60,
    fetch_top_k: int = 4,
    primary_lower_is_better: bool = False,
) -> Sequence[dict[str, Any]]:
    """Merge two result sets with Reciprocal Rank Fusion (RRF).

    Each input is ranked by its ``"distance"`` value (rank 0 is the best
    row) and every row contributes ``1 / (rank + rrf_k)`` to its document's
    fused score. Documents are identified by the first value of the row.

    Args:
        primary_search_results: Rows from the primary search; each must
            expose a ``"distance"`` key.
        secondary_search_results: Rows from the secondary search; each must
            expose a ``"distance"`` key. Larger values rank first.
        rrf_k: The RRF smoothing constant; must be positive. Defaults to 60.
        fetch_top_k: Maximum number of merged rows to return. Defaults to 4.
        primary_lower_is_better: Set to True when the primary ``"distance"``
            is a true distance (smaller means more similar, e.g. cosine or
            L2 distance) so the closest row receives the top rank. The
            default, False, keeps the previous behaviour of ranking larger
            values first.

    Returns:
        Copies of the merged rows as dicts, with ``"distance"`` replaced by
        the RRF score, best first, truncated to ``fetch_top_k``.

    Raises:
        ValueError: If ``rrf_k`` is not positive.
    """
    if rrf_k <= 0:
        # rank starts at 0, so a non-positive rrf_k would divide by zero or
        # produce negative contributions.
        raise ValueError(f"rrf_k must be positive, got {rrf_k}.")

    rrf_scores: dict[str, dict[str, Any]] = {}
    _accumulate_reciprocal_ranks(
        rrf_scores, primary_search_results, rrf_k, primary_lower_is_better
    )
    _accumulate_reciprocal_ranks(rrf_scores, secondary_search_results, rrf_k, False)

    # Sort the fused rows by RRF score in descending order
    ranked_results = sorted(
        rrf_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


@dataclass
class HybridSearchConfig(ABC):
    """Configuration for hybrid (dense vector + full-text) search.

    Queries might be slow if the hybrid search column does not exist.
    For best hybrid search performance, consider creating a TSV column
    and adding GIN index.

    Attributes:
        tsv_column: Name of the TSVECTOR column to search. When empty, the
            vector store applies ``to_tsvector`` to the content column at
            query time.
        tsv_lang: Text search configuration passed to ``to_tsvector`` and
            ``plainto_tsquery``.
        fts_query: Full-text query string. When empty, the vector store's
            similarity-search methods fill it in from the search query.
        fusion_function: Callable that merges the primary (dense) and
            secondary (full-text) result rows; it is invoked with
            ``fusion_function_parameters`` as keyword arguments.
        fusion_function_parameters: Extra keyword arguments for
            ``fusion_function``. The vector store sets ``"fetch_top_k"``
            before calling it.
        primary_top_k: Result-count setting for the primary (dense) search.
        secondary_top_k: Result-count setting for the secondary (full-text)
            search.
        index_name: Name of the full-text index to create.
        index_type: Index access method used for the full-text index.
    """

    tsv_column: Optional[str] = ""
    tsv_lang: Optional[str] = "pg_catalog.english"
    fts_query: Optional[str] = ""
    fusion_function: Callable[
        [Sequence["RowMapping"], Sequence["RowMapping"], Any], Sequence[Any]
    ] = weighted_sum_ranking
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
    primary_top_k: int = 4
    secondary_top_k: int = 4
    index_name: str = "langchain_tsv_index"
    index_type: str = "GIN"
Returns: PGVectorStore @@ -98,6 +101,7 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = await engine._run_as_async(coro) return cls(cls.__create_key, engine, vs) @@ -120,6 +124,7 @@ def create_sync( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> PGVectorStore: """Create an PGVectorStore instance. @@ -140,6 +145,7 @@ def create_sync( fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20. lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: PGVectorStore @@ -160,6 +166,7 @@ def create_sync( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = engine._run_as_sync(coro) return cls(cls.__create_key, engine, vs) @@ -301,6 +308,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. @@ -324,6 +332,7 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. 
+ hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -347,6 +356,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids) return vs @@ -371,6 +381,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -393,6 +404,7 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -417,6 +429,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_documents(documents, ids=ids) return vs @@ -442,6 +455,7 @@ def from_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. 
@@ -465,6 +479,7 @@ def from_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -488,6 +503,7 @@ def from_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_texts(texts, metadatas=metadatas, ids=ids) @@ -513,6 +529,7 @@ def from_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -535,6 +552,7 @@ def from_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. 
@@ -558,6 +576,7 @@ def from_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_documents(documents, ids=ids) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py index 5fa3d252..8585bcd0 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py @@ -10,15 +10,13 @@ from langchain_postgres import PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore -from langchain_postgres.v2.indexes import ( - DistanceStrategy, - HNSWIndex, - IVFFlatIndex, -) +from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig +from langchain_postgres.v2.indexes import DistanceStrategy, HNSWIndex, IVFFlatIndex from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING uuid_str = str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE = "default" + uuid_str +DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str DEFAULT_INDEX_NAME = "index" + uuid_str VECTOR_SIZE = 768 SIMPLE_TABLE = "default_table" @@ -55,8 +53,10 @@ class TestIndex: async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") - await aexecute(engine, f"DROP TABLE IF EXISTS {SIMPLE_TABLE}") + + await engine._adrop_table(DEFAULT_TABLE) + await engine._adrop_table(DEFAULT_HYBRID_TABLE) + await engine._adrop_table(SIMPLE_TABLE) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -73,7 +73,9 @@ async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: yield vs async def test_apply_default_name_vector_index(self, engine: PGEngine) -> None: - await engine._ainit_vectorstore_table(SIMPLE_TABLE, VECTOR_SIZE) + await engine._ainit_vectorstore_table( + SIMPLE_TABLE, 
VECTOR_SIZE, overwrite_existing=True + ) vs = await AsyncPGVectorStore.create( engine, embedding_service=embeddings_service, @@ -92,6 +94,61 @@ async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None: assert await vs.is_valid_index(DEFAULT_INDEX_NAME) await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + async def test_aapply_vector_index_non_hybrid_search_vs( + self, vs: AsyncPGVectorStore + ) -> None: + with pytest.raises(ValueError): + await vs.aapply_hybrid_search_index() + + async def test_aapply_hybrid_search_index_table_without_tsv_column( + self, engine: PGEngine, vs: AsyncPGVectorStore + ) -> None: + # overwriting vs to get a hybrid vs + tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), + ) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_hybrid_search_index() + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + + async def test_aapply_hybrid_search_index_table_with_tsv_column( + self, engine: PGEngine + ) -> None: + tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str + config = HybridSearchConfig( + tsv_column="tsv_column", + tsv_lang="pg_catalog.english", + index_name=tsv_index_name, + ) + await engine._ainit_vectorstore_table( + DEFAULT_HYBRID_TABLE, + VECTOR_SIZE, + hybrid_search_config=config, + ) + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_HYBRID_TABLE, + hybrid_search_config=config, + ) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_hybrid_search_index() + assert await 
vs.is_valid_index(tsv_index_name) + await vs.areindex(tsv_index_name) + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + async def test_areindex(self, vs: AsyncPGVectorStore) -> None: if not await vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex(name=DEFAULT_INDEX_NAME) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index 72f91d80..16c70fdd 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -10,6 +10,11 @@ from langchain_postgres import Column, PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( FILTERING_TEST_CASES, @@ -19,6 +24,8 @@ DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE1 = "test_table_hybrid1" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE2 = "test_table_hybrid2" + str(uuid.uuid4()).replace("-", "_") CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 sync_method_exception_str = "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead." @@ -41,6 +48,18 @@ filter_docs = [ Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] +# Documents designed for hybrid search testing +hybrid_docs_content = { + "hs_doc_apple_fruit": "An apple is a sweet and edible fruit produced by an apple tree. 
Apples are very common.", + "hs_doc_apple_tech": "Apple Inc. is a multinational technology company. Their latest tech is amazing.", + "hs_doc_orange_fruit": "The orange is the fruit of various citrus species. Oranges are tasty.", + "hs_doc_generic_tech": "Technology drives innovation in the modern world. Tech is evolving.", + "hs_doc_unrelated_cat": "A fluffy cat sat on a mat quietly observing a mouse.", +} +hybrid_docs = [ + Document(page_content=content, metadata={"doc_id_key": key}) + for key, content in hybrid_docs_content.items() +] def get_env_var(key: str, desc: str) -> str: @@ -69,6 +88,8 @@ async def engine(self) -> AsyncIterator[PGEngine]: await engine.adrop_table(DEFAULT_TABLE) await engine.adrop_table(CUSTOM_TABLE) await engine.adrop_table(CUSTOM_FILTER_TABLE) + await engine.adrop_table(HYBRID_SEARCH_TABLE1) + await engine.adrop_table(HYBRID_SEARCH_TABLE2) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -111,6 +132,51 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore] await vs_custom.aadd_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_with_tsv_column( + self, engine: PGEngine + ) -> AsyncIterator[AsyncPGVectorStore]: + hybrid_search_config = HybridSearchConfig( + tsv_column="mycontent_tsv", + tsv_lang="pg_catalog.english", + fts_query="my_fts_query", + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 10, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE1, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + metadata_json_column="mymetadata", # ignored + store_metadata=False, + hybrid_search_config=hybrid_search_config, + ) + + vs_custom = await AsyncPGVectorStore.create( + engine, + 
embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE1, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_json_column="mymetadata", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=hybrid_search_config, + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + @pytest_asyncio.fixture(scope="class") async def vs_custom_filter( self, engine: PGEngine @@ -303,3 +369,360 @@ async def test_vectorstore_with_metadata_filters( "meow", k=5, filter=test_filter ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None: + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + + async def test_asimilarity_hybrid_search_rrk(self, vs: AsyncPGVectorStore) -> None: + results = await vs.asimilarity_search( + "foo", + k=1, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + 
fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 100, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + async def test_hybrid_search_weighted_sum_default( + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: + """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" + query = "apple" # Should match "apple" in FTS and vector + + # The vs_hybrid_search_with_tsv_column instance is already configured for hybrid search. + # Default fusion is weighted_sum_ranking with 0.5/0.5 weights. + # fts_query will default to the main query. + results_with_scores = ( + await vs_hybrid_search_with_tsv_column.asimilarity_search_with_score( + query, k=3 + ) + ) + + assert len(results_with_scores) > 1 + result_ids = [doc.metadata["doc_id_key"] for doc, score in results_with_scores] + + # Expect "hs_doc_apple_fruit" and "hs_doc_apple_tech" to be highly ranked. 
+ assert "hs_doc_apple_fruit" in result_ids + + # Scores should be floats (fused scores) + for doc, score in results_with_scores: + assert isinstance(score, float) + + # Check if sorted by score (descending for weighted_sum_ranking with positive scores) + assert results_with_scores[0][1] >= results_with_scores[1][1] + + async def test_hybrid_search_weighted_sum_vector_bias( + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: + """Test weighted sum with higher weight for vector results.""" + query = "Apple Inc technology" # More specific for vector similarity + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", # Must match table setup + fusion_function_parameters={ + "primary_results_weight": 0.8, # Vector bias + "secondary_results_weight": 0.2, + }, + # fts_query will default to main query + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) > 0 + assert result_ids[0] == "hs_doc_orange_fruit" + + async def test_hybrid_search_weighted_sum_fts_bias( + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: + """Test weighted sum with higher weight for FTS results.""" + query = "fruit common tasty" # Strong FTS signal for fruit docs + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.01, + "secondary_results_weight": 0.99, # FTS bias + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + assert "hs_doc_apple_fruit" in result_ids + + async def test_hybrid_search_reciprocal_rank_fusion( + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: + """Test hybrid search 
with Reciprocal Rank Fusion.""" + query = "technology company" + + # Configure RRF. primary_top_k and secondary_top_k control inputs to fusion. + # fusion_function_parameters.fetch_top_k controls output count from RRF. + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=reciprocal_rank_fusion, + primary_top_k=3, # How many dense results to consider + secondary_top_k=3, # How many sparse results to consider + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 2, + }, # RRF specific params + ) + # The `k` in asimilarity_search here is the final desired number of results, + # which should align with fusion_function_parameters.fetch_top_k for RRF. + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + # "hs_doc_apple_tech" (FTS: technology, company; Vector: Apple Inc technology) + # "hs_doc_generic_tech" (FTS: technology; Vector: Technology drives innovation) + # RRF should combine these ranks. "hs_doc_apple_tech" is likely higher. + assert "hs_doc_apple_tech" in result_ids + assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal + + async def test_hybrid_search_explicit_fts_query( + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: + """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" + main_vector_query = "Apple Inc." 
# For vector search + fts_specific_query = "fruit" # For FTS + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_specific_query, # Override FTS query + fusion_function_parameters={ # Using default weighted_sum_ranking + "primary_results_weight": 0.5, + "secondary_results_weight": 0.5, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + main_vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Vector search for "Apple Inc.": hs_doc_apple_tech + # FTS search for "fruit": hs_doc_apple_fruit, hs_doc_orange_fruit + # Combined: hs_doc_apple_fruit (strong FTS) and hs_doc_apple_tech (strong vector) are candidates. + # "hs_doc_apple_fruit" might get a boost if "Apple Inc." vector has some similarity to "apple fruit" doc. + assert len(result_ids) > 0 + assert ( + "hs_doc_apple_fruit" in result_ids + or "hs_doc_apple_tech" in result_ids + or "hs_doc_orange_fruit" in result_ids + ) + + async def test_hybrid_search_with_filter( + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: + """Test hybrid search with a metadata filter applied.""" + query = "apple" + # Filter to only include "tech" related apple docs using metadata + # Assuming metadata_columns=["doc_id_key"] was set up for vs_hybrid_search_with_tsv_column + doc_filter = {"doc_id_key": {"$eq": "hs_doc_apple_tech"}} + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, filter=doc_filter, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(results) == 1 + assert result_ids[0] == "hs_doc_apple_tech" + + async def test_hybrid_search_fts_empty_results( + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: + """Test when FTS query yields no results, should fall back to vector search.""" + vector_query = "apple" 
+ no_match_fts_query = "zzyyxx_gibberish_term_for_fts_nomatch" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=no_match_fts_query, + fusion_function_parameters={ + "primary_results_weight": 0.6, + "secondary_results_weight": 0.4, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on vector search for "apple" + assert len(result_ids) > 0 + assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids + # NOTE(review): the assertion below pins the unrelated cat document as the top result, not an apple document - confirm this is intended + assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat") + + async def test_hybrid_search_vector_empty_results_effectively( + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: + """Test when vector query is very dissimilar to docs, should rely on FTS.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both.
+ vector_query_far_off = "supercalifragilisticexpialidocious_vector_nomatch" + fts_query_match = "orange fruit" # Should match hs_doc_orange_fruit + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.4, + "secondary_results_weight": 0.6, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids) == 1 + assert result_ids[0] == "hs_doc_generic_tech" + + async def test_hybrid_search_without_tsv_column( + self, + engine: PGEngine, + ) -> None: + """Test hybrid search without a TSV column.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. 
+ vector_query_far_off = "apple iphone tech is better designed than macs" + fts_query_match = "apple fruit" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=config, + ) + + vs_with_tsv_column = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + await vs_with_tsv_column.aadd_documents(hybrid_docs) + + config = HybridSearchConfig( + tsv_column="", # no TSV column + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.9, + "secondary_results_weight": 0.1, + }, + ) + vs_without_tsv_column = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + + results_with_tsv_column = await vs_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + results_without_tsv_column = await vs_without_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids_with_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_with_tsv_column + ] + 
result_ids_without_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_without_tsv_column + ] + + # Expect the same top result with and without a TSV column (FTS query is "apple fruit") + assert len(result_ids_with_tsv_column) == 1 + assert len(result_ids_without_tsv_column) == 1 + assert result_ids_with_tsv_column[0] == "hs_doc_apple_tech" + assert result_ids_without_tsv_column[0] == "hs_doc_apple_tech" diff --git a/tests/unit_tests/v2/test_engine.py b/tests/unit_tests/v2/test_engine.py index cdd051a0..66f299aa 100644 --- a/tests/unit_tests/v2/test_engine.py +++ b/tests/unit_tests/v2/test_engine.py @@ -11,14 +11,17 @@ from sqlalchemy.pool import NullPool from langchain_postgres import Column, PGEngine +from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE = "hybrid" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TYPEDDICT_TABLE = "custom_td" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE = "custom_int_id" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE_SYNC = "custom_sync" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE_SYNC = "hybrid_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TYPEDDICT_TABLE_SYNC = "custom_td_sync" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE_SYNC = "custom_int_id_sync" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 @@ -68,10 +71,11 @@ async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING, **kwargs) yield engine - await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"') - await aexecute(engine, f'DROP TABLE "{CUSTOM_TYPEDDICT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE
"{INT_ID_CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{HYBRID_SEARCH_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE}"') await engine.close() async def test_init_table(self, engine: PGEngine) -> None: @@ -110,6 +114,31 @@ async def test_init_table_custom(self, engine: PGEngine) -> None: for row in results: assert row in expected + async def test_init_table_hybrid_search(self, engine: PGEngine) -> None: + await engine.ainit_vectorstore_table( + HYBRID_SEARCH_TABLE, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected + async def test_invalid_typed_dict(self, engine: PGEngine) -> None: with pytest.raises(TypeError): await engine.ainit_vectorstore_table( @@ -230,10 +259,11 @@ class TestEngineSync: async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE 
"{DEFAULT_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{CUSTOM_TYPEDDICT_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{HYBRID_SEARCH_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE_SYNC}"') await engine.close() async def test_init_table(self, engine: PGEngine) -> None: @@ -269,6 +299,31 @@ async def test_init_table_custom(self, engine: PGEngine) -> None: for row in results: assert row in expected + async def test_init_table_hybrid_search(self, engine: PGEngine) -> None: + engine.init_vectorstore_table( + HYBRID_SEARCH_TABLE_SYNC, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected + async def test_invalid_typed_dict(self, engine: PGEngine) -> None: with pytest.raises(TypeError): engine.init_vectorstore_table( diff --git a/tests/unit_tests/v2/test_hybrid_search_config.py 
b/tests/unit_tests/v2/test_hybrid_search_config.py new file mode 100644 index 00000000..7ea000ef --- /dev/null +++ b/tests/unit_tests/v2/test_hybrid_search_config.py @@ -0,0 +1,229 @@ +import pytest + +from langchain_postgres.v2.hybrid_search_config import ( + reciprocal_rank_fusion, + weighted_sum_ranking, +) + + +# Helper to create mock input items that mimic RowMapping for the fusion functions +def get_row(doc_id: str, score: float, content: str = "content") -> dict: + """ + Simulates a RowMapping-like dictionary. + The fusion functions expect to extract doc_id as the first value and + the initial score/distance as the last value when casting values from RowMapping. + They then operate on dictionaries, using the 'distance' key for the fused score. + """ + # Python dicts maintain insertion order (Python 3.7+). + # This structure ensures list(row.values())[0] is doc_id and + # list(row.values())[-1] is score. + return {"id_val": doc_id, "content_field": content, "distance": score} + + +class TestWeightedSumRanking: + def test_empty_inputs(self) -> None: + results = weighted_sum_ranking([], []) + assert results == [] + + def test_primary_only(self) -> None: + primary = [get_row("p1", 0.8), get_row("p2", 0.6)] + # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3 + results = weighted_sum_ranking( # type: ignore + primary, # type: ignore + [], + primary_results_weight=0.5, + secondary_results_weight=0.5, + ) + assert len(results) == 2 + assert results[0]["id_val"] == "p1" + assert results[0]["distance"] == pytest.approx(0.4) + assert results[1]["id_val"] == "p2" + assert results[1]["distance"] == pytest.approx(0.3) + + def test_secondary_only(self) -> None: + secondary = [get_row("s1", 0.9), get_row("s2", 0.7)] + # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35 + results = weighted_sum_ranking( + [], + secondary, # type: ignore + primary_results_weight=0.5, + secondary_results_weight=0.5, + ) + assert len(results) == 2 + assert 
results[0]["id_val"] == "s1" + assert results[0]["distance"] == pytest.approx(0.45) + assert results[1]["id_val"] == "s2" + assert results[1]["distance"] == pytest.approx(0.35) + + def test_mixed_results_default_weights(self) -> None: + primary = [get_row("common", 0.8), get_row("p_only", 0.7)] + secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] + # Weights are 0.5, 0.5 + # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85 + # p_only_score = (0.7 * 0.5) = 0.35 + # s_only_score = (0.6 * 0.5) = 0.30 + # Order: common (0.85), p_only (0.35), s_only (0.30) + + results = weighted_sum_ranking(primary, secondary) # type: ignore + assert len(results) == 3 + assert results[0]["id_val"] == "common" + assert results[0]["distance"] == pytest.approx(0.85) + assert results[1]["id_val"] == "p_only" + assert results[1]["distance"] == pytest.approx(0.35) + assert results[2]["id_val"] == "s_only" + assert results[2]["distance"] == pytest.approx(0.30) + + def test_mixed_results_custom_weights(self) -> None: + primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2 + secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4 + # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6 + + results = weighted_sum_ranking( + primary, # type: ignore + secondary, # type: ignore + primary_results_weight=0.2, + secondary_results_weight=0.8, + ) + assert len(results) == 1 + assert results[0]["id_val"] == "d1" + assert results[0]["distance"] == pytest.approx(0.6) + + def test_fetch_top_k(self) -> None: + primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] + # Scores: 1.0, 0.9, 0.8, 0.7, 0.6 + # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3 + results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore + assert len(results) == 2 + assert results[0]["id_val"] == "p0" + assert results[0]["distance"] == pytest.approx(0.5) + assert results[1]["id_val"] == "p1" + assert results[1]["distance"] == pytest.approx(0.45) + + +class TestReciprocalRankFusion: + def 
test_empty_inputs(self) -> None: + results = reciprocal_rank_fusion([], []) + assert results == [] + + def test_primary_only(self) -> None: + primary = [ + get_row("p1", 0.8), + get_row("p2", 0.6), + ] # p1 rank 0, p2 rank 1 + rrf_k = 60 + # p1_score = 1 / (0 + 60) + # p2_score = 1 / (1 + 60) + results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore + assert len(results) == 2 + assert results[0]["id_val"] == "p1" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "p2" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_secondary_only(self) -> None: + secondary = [ + get_row("s1", 0.9), + get_row("s2", 0.7), + ] # s1 rank 0, s2 rank 1 + rrf_k = 60 + results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore + assert len(results) == 2 + assert results[0]["id_val"] == "s1" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "s2" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_mixed_results_default_k(self) -> None: + primary = [get_row("common", 0.8), get_row("p_only", 0.7)] + secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] + rrf_k = 60 + # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k + # p_only_score = (1/(1+k))_prim = 1/(k+1) + # s_only_score = (1/(1+k))_sec = 1/(k+1) + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore + assert len(results) == 3 + assert results[0]["id_val"] == "common" + assert results[0]["distance"] == pytest.approx(2.0 / rrf_k) + # Check the next two elements, their order might vary due to tie in score + next_ids = {results[1]["id_val"], results[2]["id_val"]} + next_scores = {results[1]["distance"], results[2]["distance"]} + assert next_ids == {"p_only", "s_only"} + for score in next_scores: + assert score == pytest.approx(1.0 / (1 + rrf_k)) + + def test_fetch_top_k_rrf(self) -> None: + primary = 
[get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] + rrf_k = 1 + results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore + assert len(results) == 2 + assert results[0]["id_val"] == "p0" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "p1" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_rrf_content_preservation(self) -> None: + primary = [get_row("doc1", 0.9, content="Primary Content")] + secondary = [get_row("doc1", 0.8, content="Secondary Content")] + # RRF processes primary then secondary. If a doc is in both, + # the content from the secondary list will overwrite primary's. + results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore + assert len(results) == 1 + assert results[0]["id_val"] == "doc1" + assert results[0]["content_field"] == "Secondary Content" + + # If only in primary + results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore + assert results_prim_only[0]["content_field"] == "Primary Content" + + def test_reordering_from_inputs_rrf(self) -> None: + """ + Tests that RRF fused ranking can be different from both primary and secondary + input rankings. 
+ Primary Order: A, B, C + Secondary Order: C, B, A + Fused Order: (A, C) tied, then B + """ + primary = [ + get_row("docA", 0.9), + get_row("docB", 0.8), + get_row("docC", 0.1), + ] + secondary = [ + get_row("docC", 0.9), + get_row("docB", 0.5), + get_row("docA", 0.2), + ] + rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation + # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3 + # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1 + # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3 + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore + assert len(results) == 3 + assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"} + assert results[0]["distance"] == pytest.approx(4.0 / 3.0) + assert results[1]["distance"] == pytest.approx(4.0 / 3.0) + assert results[2]["id_val"] == "docB" + assert results[2]["distance"] == pytest.approx(1.0) + + def test_reordering_from_inputs_weighted_sum(self) -> None: + """ + Tests that the fused ranking can differ from the primary input ranking and that + fused scores differ from the original scores of both inputs. + Primary Order: A (0.9), B (0.7) + Secondary Order: B (0.8), A (0.2) + Fusion (0.5/0.5 weights): + docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55 + docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75 + Expected Fused Order: docB (0.75), docA (0.55) + The fused order differs from Primary (A, B) and matches Secondary (B, A); the fused + scores differ from both inputs because each score is weighted before summing.
+ """ + primary = [get_row("docA", 0.9), get_row("docB", 0.7)] + secondary = [get_row("docB", 0.8), get_row("docA", 0.2)] + + results = weighted_sum_ranking(primary, secondary) # type: ignore + assert len(results) == 2 + assert results[0]["id_val"] == "docB" + assert results[0]["distance"] == pytest.approx(0.75) + assert results[1]["id_val"] == "docA" + assert results[1]["distance"] == pytest.approx(0.55) diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 39276edd..7815a25a 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -9,6 +9,11 @@ from sqlalchemy import text from langchain_postgres import Column, PGEngine, PGVectorStore +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( FILTERING_TEST_CASES, @@ -261,6 +266,37 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + async def test_asimilarity_hybrid_search(self, vs: PGVectorStore) -> None: + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + 
secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + @pytest.mark.enable_socket class TestVectorStoreSearchSync: @@ -398,4 +434,30 @@ def test_metadata_filter_negative_tests( self, vs_custom_filter_sync: PGVectorStore, test_filter: dict ) -> None: with pytest.raises((ValueError, NotImplementedError)): - vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) + docs = vs_custom_filter_sync.similarity_search( + "meow", k=5, filter=test_filter + ) + + def test_similarity_hybrid_search(self, vs_custom: PGVectorStore) -> None: + results = vs_custom.similarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = vs_custom.similarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + results = vs_custom.similarity_search( + "foo", + k=1, + filter={"mycontent": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert results == [Document(page_content="foo", id=ids[0])]