From 22088a1e090226f82abec461d0e2372ed74ce45c Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 19 May 2025 08:10:41 +0000 Subject: [PATCH 01/15] feat: Added Hybrid Search Config and Tests [1/N] --- langchain_postgres/v2/hybrid_search_config.py | 143 ++++++++++++ .../v2/test_hybrid_search_config.py | 220 ++++++++++++++++++ 2 files changed, 363 insertions(+) create mode 100644 langchain_postgres/v2/hybrid_search_config.py create mode 100644 tests/unit_tests/v2/test_hybrid_search_config.py diff --git a/langchain_postgres/v2/hybrid_search_config.py b/langchain_postgres/v2/hybrid_search_config.py new file mode 100644 index 00000000..9e7286e0 --- /dev/null +++ b/langchain_postgres/v2/hybrid_search_config.py @@ -0,0 +1,143 @@ +from abc import ABC +from dataclasses import dataclass, field +from typing import Any, Callable, Optional, Sequence + +from sqlalchemy import RowMapping + + +def weighted_sum_ranking( + primary_search_results: Sequence[RowMapping], + secondary_search_results: Sequence[RowMapping], + primary_results_weight: float = 0.5, + secondary_results_weight: float = 0.5, + fetch_top_k: int = 4, +) -> Sequence[dict[str, Any]]: + """ + Ranks documents using a weighted sum of scores from two sources. + + Args: + primary_search_results: A list of (document, distance) tuples from + the primary search. + secondary_search_results: A list of (document, distance) tuples from + the secondary search. + primary_results_weight: The weight for the primary source's scores. + Defaults to 0.5. + secondary_results_weight: The weight for the secondary source's scores. + Defaults to 0.5. + fetch_top_k: The number of documents to fetch after merging the results. + Defaults to 4. + + Returns: + A list of (document, distance) tuples, sorted by weighted_score in + descending order. + """ + + # stores computed metric with provided distance metric and weights + weighted_scores: dict[str, dict[str, Any]] = {} + + # Process results from primary source + for row in primary_search_results: + values = list(row.values()) + doc_id = str(values[0]) # first value is doc_id + distance = float(values[-1]) # type: ignore # last value is distance + row_values = dict(row) + row_values["distance"] = primary_results_weight * distance + weighted_scores[doc_id] = row_values + + # Process results from secondary source, + # adding to existing scores or creating new ones + for row in secondary_search_results: + values = list(row.values()) + doc_id = str(values[0]) # first value is doc_id + distance = float(values[-1]) # type: ignore # last value is distance + primary_score = ( + weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0 + ) + row_values = dict(row) + row_values["distance"] = distance * secondary_results_weight + primary_score + weighted_scores[doc_id] = row_values + + # Sort the results by weighted score in descending order + ranked_results = sorted( + weighted_scores.values(), key=lambda item: item["distance"], reverse=True + ) + return ranked_results[:fetch_top_k] + + +def reciprocal_rank_fusion( + primary_search_results: Sequence[RowMapping], + secondary_search_results: Sequence[RowMapping], + rrf_k: float = 60, + fetch_top_k: int = 4, +) -> Sequence[dict[str, Any]]: + """ + Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two sources. + + Args: + primary_search_results: A list of (document, distance) tuples from + the primary search. + secondary_search_results: A list of (document, distance) tuples from + the secondary search. + rrf_k: The RRF parameter k. + Defaults to 60. + fetch_top_k: The number of documents to fetch after merging the results. + Defaults to 4. + + Returns: + A list of (document_id, rrf_score) tuples, sorted by rrf_score + in descending order. + """ + rrf_scores: dict[str, dict[str, Any]] = {} + + # Process results from primary source + for rank, row in enumerate( + sorted(primary_search_results, key=lambda item: item["distance"], reverse=True) + ): + values = list(row.values()) + doc_id = str(values[0]) + row_values = dict(row) + primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0 + primary_score += 1.0 / (rank + rrf_k) + row_values["distance"] = primary_score + rrf_scores[doc_id] = row_values + + # Process results from secondary source + for rank, row in enumerate( + sorted( + secondary_search_results, key=lambda item: item["distance"], reverse=True + ) + ): + values = list(row.values()) + doc_id = str(values[0]) + row_values = dict(row) + secondary_score = ( + rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0 + ) + secondary_score += 1.0 / (rank + rrf_k) + row_values["distance"] = secondary_score + rrf_scores[doc_id] = row_values + + # Sort the results by rrf score in descending order + # Sort the results by weighted score in descending order + ranked_results = sorted( + rrf_scores.values(), key=lambda item: item["distance"], reverse=True + ) + # Extract only the RowMapping for the top results + return ranked_results[:fetch_top_k] + + +@dataclass +class HybridSearchConfig(ABC): + """Google AlloyDB Vector Store Hybrid Search Config.""" + + tsv_column: Optional[str] = "" + tsv_lang: Optional[str] = "pg_catalog.english" + fts_query: Optional[str] = "" + fusion_function: Callable[ + [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any] + ] = weighted_sum_ranking # Updated default + fusion_function_parameters: dict[str, Any] = field(default_factory=dict) + primary_top_k: int = 4 + secondary_top_k: int = 4 + index_name: str = "langchain_tsv_index" + index_type: str = "GIN" diff --git a/tests/unit_tests/v2/test_hybrid_search_config.py b/tests/unit_tests/v2/test_hybrid_search_config.py new file mode 100644 index 00000000..366f81d4 --- /dev/null +++ b/tests/unit_tests/v2/test_hybrid_search_config.py @@ -0,0 +1,220 @@ +import pytest + +from langchain_postgres.v2.hybrid_search_config import (reciprocal_rank_fusion, + weighted_sum_ranking) + + +# Helper to create mock input items that mimic RowMapping for the fusion functions +def get_row(doc_id: str, score: float, content: str = "content") -> dict: + """ + Simulates a RowMapping-like dictionary. + The fusion functions expect to extract doc_id as the first value and + the initial score/distance as the last value when casting values from RowMapping. + They then operate on dictionaries, using the 'distance' key for the fused score. + """ + # Python dicts maintain insertion order (Python 3.7+). + # This structure ensures list(row.values())[0] is doc_id and + # list(row.values())[-1] is score. + return {"id_val": doc_id, "content_field": content, "distance": score} + + +class TestWeightedSumRanking: + def test_empty_inputs(self): + results = weighted_sum_ranking([], []) + assert results == [] + + def test_primary_only(self): + primary = [get_row("p1", 0.8), get_row("p2", 0.6)] + # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3 + results = weighted_sum_ranking( + primary, [], primary_results_weight=0.5, secondary_results_weight=0.5 + ) + assert len(results) == 2 + assert results[0]["id_val"] == "p1" + assert results[0]["distance"] == pytest.approx(0.4) + assert results[1]["id_val"] == "p2" + assert results[1]["distance"] == pytest.approx(0.3) + + def test_secondary_only(self): + secondary = [get_row("s1", 0.9), get_row("s2", 0.7)] + # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35 + results = weighted_sum_ranking( + [], secondary, primary_results_weight=0.5, secondary_results_weight=0.5 + ) + assert len(results) == 2 + assert results[0]["id_val"] == "s1" + assert results[0]["distance"] == pytest.approx(0.45) + assert results[1]["id_val"] == "s2" + assert results[1]["distance"] == pytest.approx(0.35) + + def test_mixed_results_default_weights(self): + primary = [get_row("common", 0.8), get_row("p_only", 0.7)] + secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] + # Weights are 0.5, 0.5 + # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85 + # p_only_score = (0.7 * 0.5) = 0.35 + # s_only_score = (0.6 * 0.5) = 0.30 + # Order: common (0.85), p_only (0.35), s_only (0.30) + + results = weighted_sum_ranking(primary, secondary) + assert len(results) == 3 + assert results[0]["id_val"] == "common" + assert results[0]["distance"] == pytest.approx(0.85) + assert results[1]["id_val"] == "p_only" + assert results[1]["distance"] == pytest.approx(0.35) + assert results[2]["id_val"] == "s_only" + assert results[2]["distance"] == pytest.approx(0.30) + + def test_mixed_results_custom_weights(self): + primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2 + secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4 + # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6 + + results = weighted_sum_ranking( + primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8 + ) + assert len(results) == 1 + assert results[0]["id_val"] == "d1" + assert results[0]["distance"] == pytest.approx(0.6) + + def test_fetch_top_k(self): + primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] + # Scores: 1.0, 0.9, 0.8, 0.7, 0.6 + # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3 + secondary = [] + results = weighted_sum_ranking(primary, secondary, fetch_top_k=2) + assert len(results) == 2 + assert results[0]["id_val"] == "p0" + assert results[0]["distance"] == pytest.approx(0.5) + assert results[1]["id_val"] == "p1" + assert results[1]["distance"] == pytest.approx(0.45) + + +class TestReciprocalRankFusion: + def test_empty_inputs(self): + results = reciprocal_rank_fusion([], []) + assert results == [] + + def test_primary_only(self): + primary = [ + get_row("p1", 0.8), + get_row("p2", 0.6), + ] # p1 rank 0, p2 rank 1 + rrf_k = 60 + # p1_score = 1 / (0 + 60) + # p2_score = 1 / (1 + 60) + results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) + assert len(results) == 2 + assert results[0]["id_val"] == "p1" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "p2" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_secondary_only(self): + secondary = [ + get_row("s1", 0.9), + get_row("s2", 0.7), + ] # s1 rank 0, s2 rank 1 + rrf_k = 60 + results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) + assert len(results) == 2 + assert results[0]["id_val"] == "s1" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "s2" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_mixed_results_default_k(self): + primary = [get_row("common", 0.8), get_row("p_only", 0.7)] + secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] + rrf_k = 60 + # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k + # p_only_score = (1/(1+k))_prim = 1/(k+1) + # s_only_score = (1/(1+k))_sec = 1/(k+1) + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) + assert len(results) == 3 + assert results[0]["id_val"] == "common" + assert results[0]["distance"] == pytest.approx(2.0 / rrf_k) + # Check the next two elements, their order might vary due to tie in score + next_ids = {results[1]["id_val"], results[2]["id_val"]} + next_scores = {results[1]["distance"], results[2]["distance"]} + assert next_ids == {"p_only", "s_only"} + for score in next_scores: + assert score == pytest.approx(1.0 / (1 + rrf_k)) + + def test_fetch_top_k_rrf(self): + primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] + secondary = [] + rrf_k = 1 + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2) + assert len(results) == 2 + assert results[0]["id_val"] == "p0" + assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) + assert results[1]["id_val"] == "p1" + assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) + + def test_rrf_content_preservation(self): + primary = [get_row("doc1", 0.9, content="Primary Content")] + secondary = [get_row("doc1", 0.8, content="Secondary Content")] + # RRF processes primary then secondary. If a doc is in both, + # the content from the secondary list will overwrite primary's. + results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) + assert len(results) == 1 + assert results[0]["id_val"] == "doc1" + assert results[0]["content_field"] == "Secondary Content" + + # If only in primary + results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) + assert results_prim_only[0]["content_field"] == "Primary Content" + + def test_reordering_from_inputs_rrf(self): + """ + Tests that RRF fused ranking can be different from both primary and secondary + input rankings. + Primary Order: A, B, C + Secondary Order: C, B, A + Fused Order: (A, C) tied, then B + """ + primary = [ + get_row("docA", 0.9), + get_row("docB", 0.8), + get_row("docC", 0.1), + ] + secondary = [ + get_row("docC", 0.9), + get_row("docB", 0.5), + get_row("docA", 0.2), + ] + rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation + # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3 + # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1 + # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3 + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) + assert len(results) == 3 + assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"} + assert results[0]["distance"] == pytest.approx(4.0 / 3.0) + assert results[1]["distance"] == pytest.approx(4.0 / 3.0) + assert results[2]["id_val"] == "docB" + assert results[2]["distance"] == pytest.approx(1.0) + + def test_reordering_from_inputs_weighted_sum(self): + """ + Tests that the fused ranking can be different from both primary and secondary + input rankings. + Primary Order: A (0.9), B (0.7) + Secondary Order: B (0.8), A (0.2) + Fusion (0.5/0.5 weights): + docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55 + docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75 + Expected Fused Order: docB (0.75), docA (0.55) + This is different from Primary (A,B) and Secondary (B,A) in terms of + original score, but the fusion logic changes the effective contribution). + """ + primary = [get_row("docA", 0.9), get_row("docB", 0.7)] + secondary = [get_row("docB", 0.8), get_row("docA", 0.2)] + + results = weighted_sum_ranking(primary, secondary) + assert len(results) == 2 + assert results[0]["id_val"] == "docB" + assert results[0]["distance"] == pytest.approx(0.75) + assert results[1]["id_val"] == "docA" + assert results[1]["distance"] == pytest.approx(0.55) From 30942ff95476c5e2d0c6a93a3ea6c6f2ab0899f2 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 19 May 2025 21:52:43 +0000 Subject: [PATCH 02/15] feat: create hybrid search capable vector store table [2/N] --- langchain_postgres/v2/engine.py | 28 +++++++++++- tests/unit_tests/v2/test_engine.py | 71 ++++++++++++++++++++++++++---- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/langchain_postgres/v2/engine.py b/langchain_postgres/v2/engine.py index c2a0d931..755b5ff0 100644 --- a/langchain_postgres/v2/engine.py +++ b/langchain_postgres/v2/engine.py @@ -9,6 +9,8 @@ from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from .hybrid_search_config import HybridSearchConfig + T = TypeVar("T") @@ -156,6 +158,7 @@ async def _ainit_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -178,6 +181,8 @@ async def _ainit_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Default: None. Raises: :class:`DuplicateTableError `: if table already exists. @@ -186,6 +191,7 @@ async def _ainit_vectorstore_table( schema_name = self._escape_postgres_identifier(schema_name) table_name = self._escape_postgres_identifier(table_name) + hybrid_search_default_column_name = content_column + "_tsv" content_column = self._escape_postgres_identifier(content_column) embedding_column = self._escape_postgres_identifier(embedding_column) if metadata_columns is None: @@ -226,10 +232,22 @@ async def _ainit_vectorstore_table( id_data_type = id_column["data_type"] id_column_name = id_column["name"] + hybrid_search_column = "" # Default is no TSV column for hybrid search + if hybrid_search_config: + hybrid_search_column_name = ( + hybrid_search_config.tsv_column or hybrid_search_default_column_name + ) + hybrid_search_column_name = self._escape_postgres_identifier( + hybrid_search_column_name + ) + hybrid_search_config.tsv_column = hybrid_search_column_name + hybrid_search_column = f',"{self._escape_postgres_identifier(hybrid_search_column_name)}" TSVECTOR NOT NULL' + query = f"""CREATE TABLE "{schema_name}"."{table_name}"( "{id_column_name}" {id_data_type} PRIMARY KEY, "{content_column}" TEXT NOT NULL, - "{embedding_column}" vector({vector_size}) NOT NULL""" + "{embedding_column}" vector({vector_size}) NOT NULL + {hybrid_search_column}""" for column in metadata_columns: if isinstance(column, Column): nullable = "NOT NULL" if not column.nullable else "" @@ -258,6 +276,7 @@ async def ainit_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -280,6 +299,8 @@ async def ainit_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Default: None. """ await self._run_as_async( self._ainit_vectorstore_table( @@ -293,6 +314,7 @@ async def ainit_vectorstore_table( id_column=id_column, overwrite_existing=overwrite_existing, store_metadata=store_metadata, + hybrid_search_config=hybrid_search_config, ) ) @@ -309,6 +331,7 @@ def init_vectorstore_table( id_column: Union[str, Column, ColumnDict] = "langchain_id", overwrite_existing: bool = False, store_metadata: bool = True, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> None: """ Create a table for saving of vectors to be used with PGVectorStore. @@ -331,6 +354,8 @@ def init_vectorstore_table( overwrite_existing (bool): Whether to drop existing table. Default: False. store_metadata (bool): Whether to store metadata in the table. Default: True. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Default: None. """ self._run_as_sync( self._ainit_vectorstore_table( @@ -344,6 +369,7 @@ def init_vectorstore_table( id_column=id_column, overwrite_existing=overwrite_existing, store_metadata=store_metadata, + hybrid_search_config=hybrid_search_config, ) ) diff --git a/tests/unit_tests/v2/test_engine.py b/tests/unit_tests/v2/test_engine.py index cdd051a0..66f299aa 100644 --- a/tests/unit_tests/v2/test_engine.py +++ b/tests/unit_tests/v2/test_engine.py @@ -11,14 +11,17 @@ from sqlalchemy.pool import NullPool from langchain_postgres import Column, PGEngine +from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE = "hybrid" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TYPEDDICT_TABLE = "custom_td" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE = "custom_int_id" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE_SYNC = "custom_sync" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE_SYNC = "hybrid_sync" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TYPEDDICT_TABLE_SYNC = "custom_td_sync" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE_SYNC = "custom_int_id_sync" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 @@ -68,10 +71,11 @@ async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING, **kwargs) yield engine - await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"') - await aexecute(engine, f'DROP TABLE "{CUSTOM_TYPEDDICT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{HYBRID_SEARCH_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE}"') await engine.close() async def test_init_table(self, engine: PGEngine) -> None: @@ -110,6 +114,31 @@ async def test_init_table_custom(self, engine: PGEngine) -> None: for row in results: assert row in expected + async def test_init_table_hybrid_search(self, engine: PGEngine) -> None: + await engine.ainit_vectorstore_table( + HYBRID_SEARCH_TABLE, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected + async def test_invalid_typed_dict(self, engine: PGEngine) -> None: with pytest.raises(TypeError): await engine.ainit_vectorstore_table( @@ -230,10 +259,11 @@ class TestEngineSync: async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{CUSTOM_TYPEDDICT_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{HYBRID_SEARCH_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE_SYNC}"') await engine.close() async def test_init_table(self, engine: PGEngine) -> None: @@ -269,6 +299,31 @@ async def test_init_table_custom(self, engine: PGEngine) -> None: for row in results: assert row in expected + async def test_init_table_hybrid_search(self, engine: PGEngine) -> None: + engine.init_vectorstore_table( + HYBRID_SEARCH_TABLE_SYNC, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected + async def test_invalid_typed_dict(self, engine: PGEngine) -> None: with pytest.raises(TypeError): engine.init_vectorstore_table( From e641575193134162dca4dac245865a20b91f7e94 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 19 May 2025 21:57:50 +0000 Subject: [PATCH 03/15] feat: adds hybrid search for async VS interface [3/N] --- langchain_postgres/v2/async_vectorstore.py | 171 +++++++- .../v2/test_async_pg_vectorstore_index.py | 77 +++- .../v2/test_async_pg_vectorstore_search.py | 364 +++++++++++++++++- 3 files changed, 581 insertions(+), 31 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index 11e5ff99..9045da42 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -14,14 +14,10 @@ from sqlalchemy.ext.asyncio import AsyncEngine from .engine import PGEngine -from .indexes import ( - DEFAULT_DISTANCE_STRATEGY, - DEFAULT_INDEX_NAME_SUFFIX, - BaseIndex, - DistanceStrategy, - ExactNearestNeighbor, - QueryOptions, -) +from .hybrid_search_config import HybridSearchConfig +from .indexes import (DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, DistanceStrategy, ExactNearestNeighbor, + QueryOptions) COMPARISONS_TO_NATIVE = { "$eq": "=", @@ -77,6 +73,8 @@ def __init__( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, + hybrid_search_column_exists: bool = False, ): """AsyncPGVectorStore constructor. Args: @@ -95,6 +93,8 @@ def __init__( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: @@ -119,6 +119,8 @@ def __init__( self.fetch_k = fetch_k self.lambda_mult = lambda_mult self.index_query_options = index_query_options + self.hybrid_search_config = hybrid_search_config + self.hybrid_search_column_exists = hybrid_search_column_exists @classmethod async def create( @@ -139,6 +141,7 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance. @@ -158,6 +161,7 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: AsyncPGVectorStore @@ -193,6 +197,17 @@ async def create( raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) + hybrid_search_column_exists = False + if hybrid_search_config: + tsv_column_name = ( + hybrid_search_config.tsv_column + if hybrid_search_config.tsv_column + else content_column + "_tsv" + ) + hybrid_search_config.tsv_column = tsv_column_name + hybrid_search_column_exists = ( + tsv_column_name in columns and columns[tsv_column_name] == "tsvector" + ) if embedding_column not in columns: raise ValueError(f"Embedding column, {embedding_column}, does not exist.") if columns[embedding_column] != "USER-DEFINED": @@ -236,6 +251,8 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, + hybrid_search_column_exists=hybrid_search_column_exists, ) @property @@ -273,7 +290,12 @@ async def aadd_embeddings( if len(self.metadata_columns) > 0 else "" ) - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' + hybrid_search_column = ( + f', "{self.hybrid_search_config.tsv_column}"' + if self.hybrid_search_config and self.hybrid_search_column_exists + else "" + ) + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}' values = { "id": id, "content": content, @@ -284,6 +306,14 @@ async def aadd_embeddings( if not embedding and can_inline_embed: values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore + if self.hybrid_search_config and self.hybrid_search_column_exists: + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + values_stmt += f", to_tsvector({lang} :tsv_content)" + values["tsv_content"] = content # Add metadata extra = copy.deepcopy(metadata) for metadata_column in self.metadata_columns: @@ -308,6 +338,9 @@ async def aadd_embeddings( upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' + if self.hybrid_search_config and self.hybrid_search_column_exists: + upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"' + if self.metadata_json_column: upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' @@ -408,6 +441,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from texts. @@ -430,6 +464,7 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -453,6 +488,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) return vs @@ -478,6 +514,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from documents. @@ -500,6 +537,7 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -524,6 +562,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -538,16 +577,30 @@ async def __query_collection( filter: Optional[dict] = None, **kwargs: Any, ) -> Sequence[RowMapping]: - """Perform similarity search query on database.""" - k = k if k else self.k + """ + Perform similarity search (or hybrid search) query on database. + Queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column + and adding GIN index. + """ + if not k: + k = ( + max( + self.k, + self.hybrid_search_config.primary_top_k, + self.hybrid_search_config.secondary_top_k, + ) + if self.hybrid_search_config + else self.k + ) operator = self.distance_strategy.operator search_function = self.distance_strategy.search_function - columns = self.metadata_columns + [ + columns = [ self.id_column, self.content_column, self.embedding_column, - ] + ] + self.metadata_columns if self.metadata_json_column: columns.append(self.metadata_json_column) @@ -557,7 +610,9 @@ async def __query_collection( filter_dict = None if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) - param_filter = f"WHERE {safe_filter}" if safe_filter else "" + where_filters = f"WHERE {safe_filter}" if safe_filter else "" + and_filters = f"AND ({safe_filter})" if safe_filter else "" + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore @@ -565,8 +620,8 @@ async def __query_collection( else: query_embedding = f"{[float(dimension) for dimension in embedding]}" embedding_data_string = ":query_embedding" - stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance - FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; + dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance + FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; """ param_dict = {"query_embedding": query_embedding, "k": k} if filter_dict: @@ -577,15 +632,50 @@ async def __query_collection( for query_option in self.index_query_options.to_parameter(): query_options_stmt = f"SET LOCAL {query_option};" await conn.execute(text(query_options_stmt)) - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) result_map = result.mappings() - results = result_map.fetchall() + dense_results = result_map.fetchall() else: async with self.engine.connect() as conn: - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) result_map = result.mappings() - results = result_map.fetchall() - return results + dense_results = result_map.fetchall() + + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + fts_query = ( + hybrid_search_config.fts_query + if hybrid_search_config and hybrid_search_config.fts_query + else kwargs.get("fts_query", "") + ) + if hybrid_search_config and fts_query: + hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k + # do the sparse query + lang = ( + f"'{hybrid_search_config.tsv_lang}'," + if hybrid_search_config.tsv_lang + else "" + ) + query_tsv = f"plainto_tsquery({lang} :fts_query)" + param_dict["fts_query"] = fts_query + if self.hybrid_search_column_exists: + content_tsv = f'"{hybrid_search_config.tsv_column}"' + else: + content_tsv = f'to_tsvector({lang} "{self.content_column}")' + sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};' + async with self.engine.connect() as conn: + result = await conn.execute(text(sparse_query_stmt), param_dict) + result_map = result.mappings() + sparse_results = result_map.fetchall() + + combined_results = hybrid_search_config.fusion_function( + dense_results, + sparse_results, + **hybrid_search_config.fusion_function_parameters, + ) + return combined_results + return dense_results async def asimilarity_search( self, @@ -603,6 +693,14 @@ async def asimilarity_search( ) kwargs["query"] = query + # add fts_query to hybrid_search_config + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + if hybrid_search_config and not hybrid_search_config.fts_query: + hybrid_search_config.fts_query = query + kwargs["hybrid_search_config"] = hybrid_search_config + return await self.asimilarity_search_by_vector( embedding=embedding, k=k, filter=filter, **kwargs ) @@ -634,6 +732,14 @@ async def asimilarity_search_with_score( ) kwargs["query"] = query + # add fts_query to hybrid_search_config + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + if hybrid_search_config and not hybrid_search_config.fts_query: + hybrid_search_config.fts_query = query + kwargs["hybrid_search_config"] = hybrid_search_config + docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter, **kwargs ) @@ -806,15 +912,38 @@ async def aapply_vector_index( index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX name = index.name stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' + + if self.hybrid_search_config: + if self.hybrid_search_column_exists: + tsv_column_name = ( + self.hybrid_search_config.tsv_column + if self.hybrid_search_config.tsv_column + else self.content_column + "_tsv" + ) + tsv_column_name = f'"{tsv_column_name}"' + else: + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + tsv_column_name = f"to_tsvector({lang} {self.content_column})" + tsv_index_name = self.table_name + self.hybrid_search_config.index_name + tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {tsv_index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});' + else: + tsv_index_query = "" + if concurrently: async with self.engine.connect() as conn: autocommit_conn = await conn.execution_options( isolation_level="AUTOCOMMIT" ) await autocommit_conn.execute(text(stmt)) + await conn.execute(text(tsv_index_query)) else: async with self.engine.connect() as conn: await conn.execute(text(stmt)) + await conn.execute(text(tsv_index_query)) await conn.commit() async def areindex(self, index_name: Optional[str] = None) -> None: diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py index 3796ef5a..dc8768d1 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py @@ -10,15 +10,14 @@ from langchain_postgres import PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore -from langchain_postgres.v2.indexes import ( - DistanceStrategy, - HNSWIndex, - IVFFlatIndex, -) +from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig +from langchain_postgres.v2.indexes import (DistanceStrategy, HNSWIndex, + IVFFlatIndex) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING uuid_str = str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE = "default" + uuid_str +DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str DEFAULT_INDEX_NAME = "index" + uuid_str VECTOR_SIZE = 768 SIMPLE_TABLE = "default_table" @@ -55,8 +54,10 @@ class TestIndex: async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") - await aexecute(engine, f"DROP TABLE IF EXISTS {SIMPLE_TABLE}") + + engine._adrop_table(DEFAULT_TABLE) + engine._adrop_table(DEFAULT_HYBRID_TABLE) + engine._adrop_table(SIMPLE_TABLE) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -92,6 +93,68 @@ async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None: assert await vs.is_valid_index(DEFAULT_INDEX_NAME) await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + async def test_aapply_vector_index_non_hybrid_search_vs_without_tsv_column( + self, vs + ) -> None: + index = HNSWIndex(name="test_index_hybrid" + uuid_str) + + tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_vector_index(index) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + + async def test_aapply_vector_index_hybrid_search_vs_without_tsv_column( + self, engine, vs + ) -> None: + # overwriting vs to get a hybrid vs + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + hybrid_search_config=HybridSearchConfig(), + ) + index = HNSWIndex(name="test_index_hybrid" + uuid_str) + + tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(index.name) + await vs.aapply_vector_index(index) + assert await vs.is_valid_index(tsv_index_name) + await vs.areindex(tsv_index_name) + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(index.name) + + async def test_aapply_vector_index_hybrid_search_with_tsv_column( + self, engine + ) -> None: + await engine._ainit_vectorstore_table( + DEFAULT_HYBRID_TABLE, VECTOR_SIZE, hybrid_search_config=HybridSearchConfig() + ) + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_HYBRID_TABLE, + hybrid_search_config=HybridSearchConfig(), + ) + tsv_index_name = DEFAULT_HYBRID_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + index = HNSWIndex(name=DEFAULT_INDEX_NAME) + await vs.aapply_vector_index(index) + await vs.adrop_vector_index(tsv_index_name) + await vs.adrop_vector_index(index.name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + async def test_areindex(self, vs: AsyncPGVectorStore) -> None: if not await vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex(name=DEFAULT_INDEX_NAME) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index 72f91d80..eb65da24 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -10,15 +10,18 @@ from langchain_postgres import Column, PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore +from langchain_postgres.v2.hybrid_search_config import (HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( - FILTERING_TEST_CASES, - METADATAS, -) + FILTERING_TEST_CASES, METADATAS) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE1 = "test_table_hybrid1" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE2 = "test_table_hybrid2" + str(uuid.uuid4()).replace("-", "_") CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 sync_method_exception_str = "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead." @@ -41,6 +44,18 @@ filter_docs = [ Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] +# Documents designed for hybrid search testing +hybrid_docs_content = { + "hs_doc_apple_fruit": "An apple is a sweet and edible fruit produced by an apple tree. Apples are very common.", + "hs_doc_apple_tech": "Apple Inc. is a multinational technology company. Their latest tech is amazing.", + "hs_doc_orange_fruit": "The orange is the fruit of various citrus species. Oranges are tasty.", + "hs_doc_generic_tech": "Technology drives innovation in the modern world. Tech is evolving.", + "hs_doc_unrelated_cat": "A fluffy cat sat on a mat quietly observing a mouse.", +} +hybrid_docs = [ + Document(page_content=content, metadata={"doc_id_key": key}) + for key, content in hybrid_docs_content.items() +] def get_env_var(key: str, desc: str) -> str: @@ -69,6 +84,8 @@ async def engine(self) -> AsyncIterator[PGEngine]: await engine.adrop_table(DEFAULT_TABLE) await engine.adrop_table(CUSTOM_TABLE) await engine.adrop_table(CUSTOM_FILTER_TABLE) + await engine.adrop_table(HYBRID_SEARCH_TABLE1) + await engine.adrop_table(HYBRID_SEARCH_TABLE2) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -111,6 +128,79 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore] await vs_custom.aadd_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_without_tsv_column(self, engine): + hybrid_search_config = HybridSearchConfig( + tsv_column="my_tsv_col", + tsv_lang="pg_catalog.english", + fts_query="my_fts_query", + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 10, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE1, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + metadata_json_column="mymetadata", # ignored + store_metadata=False, + ) + + vs_custom = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE1, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_json_column="mymetadata", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=hybrid_search_config, + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_with_tsv_column(self, engine): + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=HybridSearchConfig(), + ) + + vs_custom = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=HybridSearchConfig(), + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + @pytest_asyncio.fixture(scope="class") async def vs_custom_filter( self, engine: PGEngine @@ -303,3 +393,271 @@ async def test_vectorstore_with_metadata_filters( "meow", k=5, filter=test_filter ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + async def test_asimilarity_hybrid_search(self, vs): + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + + async def test_asimilarity_hybrid_search_rrk(self, vs): + results = await vs.asimilarity_search( + "foo", + k=1, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 100, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + async def test_hybrid_search_weighted_sum_default( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" + query = "apple" # Should match "apple" in FTS and vector + + # The vs_hybrid_search_with_tsv_column instance is already configured for hybrid search. + # Default fusion is weighted_sum_ranking with 0.5/0.5 weights. + # fts_query will default to the main query. + results_with_scores = ( + await vs_hybrid_search_with_tsv_column.asimilarity_search_with_score( + query, k=3 + ) + ) + + assert len(results_with_scores) > 1 + result_ids = [doc.metadata["doc_id_key"] for doc, score in results_with_scores] + + # Expect "hs_doc_apple_fruit" and "hs_doc_apple_tech" to be highly ranked. + assert "hs_doc_apple_fruit" in result_ids + + # Scores should be floats (fused scores) + for doc, score in results_with_scores: + assert isinstance(score, float) + + # Check if sorted by score (descending for weighted_sum_ranking with positive scores) + assert results_with_scores[0][1] >= results_with_scores[1][1] + + async def test_hybrid_search_weighted_sum_vector_bias( + self, + vs_hybrid_search_without_tsv_column, + ): + """Test weighted sum with higher weight for vector results.""" + query = "Apple Inc technology" # More specific for vector similarity + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", # Must match table setup + fusion_function_parameters={ + "primary_results_weight": 0.8, # Vector bias + "secondary_results_weight": 0.2, + }, + # fts_query will default to main query + ) + results = await vs_hybrid_search_without_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) > 0 + assert result_ids[0] == "hs_doc_orange_fruit" + + async def test_hybrid_search_weighted_sum_fts_bias( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test weighted sum with higher weight for FTS results.""" + query = "fruit common tasty" # Strong FTS signal for fruit docs + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.01, + "secondary_results_weight": 0.99, # FTS bias + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + assert "hs_doc_apple_fruit" in result_ids + + async def test_hybrid_search_reciprocal_rank_fusion( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test hybrid search with Reciprocal Rank Fusion.""" + query = "technology company" + + # Configure RRF. primary_top_k and secondary_top_k control inputs to fusion. + # fusion_function_parameters.fetch_top_k controls output count from RRF. + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=reciprocal_rank_fusion, + primary_top_k=3, # How many dense results to consider + secondary_top_k=3, # How many sparse results to consider + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 2, + }, # RRF specific params + ) + # The `k` in asimilarity_search here is the final desired number of results, + # which should align with fusion_function_parameters.fetch_top_k for RRF. + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + # "hs_doc_apple_tech" (FTS: technology, company; Vector: Apple Inc technology) + # "hs_doc_generic_tech" (FTS: technology; Vector: Technology drives innovation) + # RRF should combine these ranks. "hs_doc_apple_tech" is likely higher. + assert "hs_doc_apple_tech" in result_ids + assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal + + async def test_hybrid_search_explicit_fts_query( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" + main_vector_query = "Apple Inc." # For vector search + fts_specific_query = "fruit" # For FTS + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_specific_query, # Override FTS query + fusion_function_parameters={ # Using default weighted_sum_ranking + "primary_results_weight": 0.5, + "secondary_results_weight": 0.5, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + main_vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Vector search for "Apple Inc.": hs_doc_apple_tech + # FTS search for "fruit": hs_doc_apple_fruit, hs_doc_orange_fruit + # Combined: hs_doc_apple_fruit (strong FTS) and hs_doc_apple_tech (strong vector) are candidates. + # "hs_doc_apple_fruit" might get a boost if "Apple Inc." vector has some similarity to "apple fruit" doc. + assert len(result_ids) > 0 + assert ( + "hs_doc_apple_fruit" in result_ids + or "hs_doc_apple_tech" in result_ids + or "hs_doc_orange_fruit" in result_ids + ) + + async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column): + """Test hybrid search with a metadata filter applied.""" + query = "apple" + # Filter to only include "tech" related apple docs using metadata + # Assuming metadata_columns=["doc_id_key"] was set up for vs_hybrid_search_with_tsv_column + doc_filter = {"doc_id_key": {"$eq": "hs_doc_apple_tech"}} + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, filter=doc_filter, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(results) == 1 + assert result_ids[0] == "hs_doc_apple_tech" + + async def test_hybrid_search_fts_empty_results( + self, vs_hybrid_search_with_tsv_column + ): + """Test when FTS query yields no results, should fall back to vector search.""" + vector_query = "apple" + no_match_fts_query = "zzyyxx_gibberish_term_for_fts_nomatch" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=no_match_fts_query, + fusion_function_parameters={ + "primary_results_weight": 0.6, + "secondary_results_weight": 0.4, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on vector search for "apple" + assert len(result_ids) > 0 + assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids + # The top result should be one of the apple documents based on vector search + assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat") + + async def test_hybrid_search_vector_empty_results_effectively( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test when vector query is very dissimilar to docs, should rely on FTS.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. + vector_query_far_off = "supercalifragilisticexpialidocious_vector_nomatch" + fts_query_match = "orange fruit" # Should match hs_doc_orange_fruit + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.4, + "secondary_results_weight": 0.6, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids) == 1 + assert result_ids[0] == "hs_doc_generic_tech" From 2a0bf0dd00c4fc81ea14e09b08c53ca55c83f911 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 19 May 2025 22:02:37 +0000 Subject: [PATCH 04/15] feat: adds hybrid search for sync VS interface [4/N] --- langchain_postgres/v2/vectorstores.py | 19 ++++++ .../v2/test_pg_vectorstore_search.py | 60 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/langchain_postgres/v2/vectorstores.py b/langchain_postgres/v2/vectorstores.py index 1dc1be97..52224dbe 100644 --- a/langchain_postgres/v2/vectorstores.py +++ b/langchain_postgres/v2/vectorstores.py @@ -9,6 +9,7 @@ from .async_vectorstore import AsyncPGVectorStore from .engine import PGEngine +from .hybrid_search_config import HybridSearchConfig from .indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, @@ -59,6 +60,7 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> PGVectorStore: """Create an PGVectorStore instance. @@ -78,6 +80,7 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: PGVectorStore @@ -98,6 +101,7 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = await engine._run_as_async(coro) return cls(cls.__create_key, engine, vs) @@ -120,6 +124,7 @@ def create_sync( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> PGVectorStore: """Create an PGVectorStore instance. @@ -140,6 +145,7 @@ def create_sync( fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20. lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: PGVectorStore @@ -160,6 +166,7 @@ def create_sync( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = engine._run_as_sync(coro) return cls(cls.__create_key, engine, vs) @@ -301,6 +308,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. @@ -324,6 +332,7 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -347,6 +356,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids) return vs @@ -371,6 +381,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -393,6 +404,7 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -417,6 +429,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_documents(documents, ids=ids) return vs @@ -442,6 +455,7 @@ def from_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. @@ -465,6 +479,7 @@ def from_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -488,6 +503,7 @@ def from_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_texts(texts, metadatas=metadatas, ids=ids) @@ -513,6 +529,7 @@ def from_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -535,6 +552,7 @@ def from_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -558,6 +576,7 @@ def from_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_documents(documents, ids=ids) diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 379f5295..d783a856 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -9,6 +9,11 @@ from sqlalchemy import text from langchain_postgres import Column, PGEngine, PGVectorStore +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( FILTERING_TEST_CASES, @@ -261,6 +266,37 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + async def test_asimilarity_hybrid_search(self, vs: PGVectorStore): + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + @pytest.mark.enable_socket class TestVectorStoreSearchSync: @@ -401,3 +437,27 @@ def test_metadata_filter_negative_tests( docs = vs_custom_filter_sync.similarity_search( "meow", k=5, filter=test_filter ) + + def test_similarity_hybrid_search(self, vs_custom): + results = vs_custom.similarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = vs_custom.similarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + results = vs_custom.similarity_search( + "foo", + k=1, + filter={"mycontent": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert results == [Document(page_content="baz", id=ids[2])] From 70ee3001cbdaeb30285b571678771d6c313105f3 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 21:27:53 +0000 Subject: [PATCH 05/15] fix: tests --- tests/unit_tests/v2/test_hybrid_search_config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/v2/test_hybrid_search_config.py b/tests/unit_tests/v2/test_hybrid_search_config.py index 366f81d4..780b0068 100644 --- a/tests/unit_tests/v2/test_hybrid_search_config.py +++ b/tests/unit_tests/v2/test_hybrid_search_config.py @@ -1,8 +1,9 @@ import pytest -from langchain_postgres.v2.hybrid_search_config import (reciprocal_rank_fusion, - weighted_sum_ranking) - +from langchain_postgres.v2.hybrid_search_config import ( + reciprocal_rank_fusion, + weighted_sum_ranking, +) # Helper to create mock input items that mimic RowMapping for the fusion functions def get_row(doc_id: str, score: float, content: str = "content") -> dict: From 523464836833351562150e3a5e50f31161489066 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 21:28:35 +0000 Subject: [PATCH 06/15] fix: pr comments --- langchain_postgres/v2/hybrid_search_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain_postgres/v2/hybrid_search_config.py b/langchain_postgres/v2/hybrid_search_config.py index 9e7286e0..3e1a61b5 100644 --- a/langchain_postgres/v2/hybrid_search_config.py +++ b/langchain_postgres/v2/hybrid_search_config.py @@ -128,7 +128,7 @@ def reciprocal_rank_fusion( @dataclass class HybridSearchConfig(ABC): - """Google AlloyDB Vector Store Hybrid Search Config.""" + """AlloyDB Vector Store Hybrid Search Config.""" tsv_column: Optional[str] = "" tsv_lang: Optional[str] = "pg_catalog.english" From 73d4400dd2ab3723fe3422af0839d9517f8159df Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 21:30:00 +0000 Subject: [PATCH 07/15] fix: lint --- tests/unit_tests/v2/test_hybrid_search_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/v2/test_hybrid_search_config.py b/tests/unit_tests/v2/test_hybrid_search_config.py index 780b0068..9c53662e 100644 --- a/tests/unit_tests/v2/test_hybrid_search_config.py +++ b/tests/unit_tests/v2/test_hybrid_search_config.py @@ -5,6 +5,7 @@ weighted_sum_ranking, ) + # Helper to create mock input items that mimic RowMapping for the fusion functions def get_row(doc_id: str, score: float, content: str = "content") -> dict: """ From 57ceb2c26d333a7dece8eb0efdc22829035d38a0 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 21:46:12 +0000 Subject: [PATCH 08/15] fix: lint --- .../v2/test_hybrid_search_config.py | 67 ++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/tests/unit_tests/v2/test_hybrid_search_config.py b/tests/unit_tests/v2/test_hybrid_search_config.py index 9c53662e..7ea000ef 100644 --- a/tests/unit_tests/v2/test_hybrid_search_config.py +++ b/tests/unit_tests/v2/test_hybrid_search_config.py @@ -21,15 +21,18 @@ def get_row(doc_id: str, score: float, content: str = "content") -> dict: class TestWeightedSumRanking: - def test_empty_inputs(self): + def test_empty_inputs(self) -> None: results = weighted_sum_ranking([], []) assert results == [] - def test_primary_only(self): + def test_primary_only(self) -> None: primary = [get_row("p1", 0.8), get_row("p2", 0.6)] # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3 - results = weighted_sum_ranking( - primary, [], primary_results_weight=0.5, secondary_results_weight=0.5 + results = weighted_sum_ranking( # type: ignore + primary, # type: ignore + [], + primary_results_weight=0.5, + secondary_results_weight=0.5, ) assert len(results) == 2 assert results[0]["id_val"] == "p1" @@ -37,11 +40,14 @@ def test_primary_only(self): assert results[1]["id_val"] == "p2" assert results[1]["distance"] == pytest.approx(0.3) - def test_secondary_only(self): + def test_secondary_only(self) -> None: secondary = [get_row("s1", 0.9), get_row("s2", 0.7)] # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35 results = weighted_sum_ranking( - [], secondary, primary_results_weight=0.5, secondary_results_weight=0.5 + [], + secondary, # type: ignore + primary_results_weight=0.5, + secondary_results_weight=0.5, ) assert len(results) == 2 assert results[0]["id_val"] == "s1" @@ -49,7 +55,7 @@ def test_secondary_only(self): assert results[1]["id_val"] == "s2" assert results[1]["distance"] == pytest.approx(0.35) - def test_mixed_results_default_weights(self): + def test_mixed_results_default_weights(self) -> None: primary = [get_row("common", 0.8), get_row("p_only", 0.7)] secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] # Weights are 0.5, 0.5 @@ -58,7 +64,7 @@ def test_mixed_results_default_weights(self): # s_only_score = (0.6 * 0.5) = 0.30 # Order: common (0.85), p_only (0.35), s_only (0.30) - results = weighted_sum_ranking(primary, secondary) + results = weighted_sum_ranking(primary, secondary) # type: ignore assert len(results) == 3 assert results[0]["id_val"] == "common" assert results[0]["distance"] == pytest.approx(0.85) @@ -67,24 +73,26 @@ def test_mixed_results_default_weights(self): assert results[2]["id_val"] == "s_only" assert results[2]["distance"] == pytest.approx(0.30) - def test_mixed_results_custom_weights(self): + def test_mixed_results_custom_weights(self) -> None: primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2 secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4 # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6 results = weighted_sum_ranking( - primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8 + primary, # type: ignore + secondary, # type: ignore + primary_results_weight=0.2, + secondary_results_weight=0.8, ) assert len(results) == 1 assert results[0]["id_val"] == "d1" assert results[0]["distance"] == pytest.approx(0.6) - def test_fetch_top_k(self): + def test_fetch_top_k(self) -> None: primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] # Scores: 1.0, 0.9, 0.8, 0.7, 0.6 # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3 - secondary = [] - results = weighted_sum_ranking(primary, secondary, fetch_top_k=2) + results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore assert len(results) == 2 assert results[0]["id_val"] == "p0" assert results[0]["distance"] == pytest.approx(0.5) @@ -93,11 +101,11 @@ def test_fetch_top_k(self): class TestReciprocalRankFusion: - def test_empty_inputs(self): + def test_empty_inputs(self) -> None: results = reciprocal_rank_fusion([], []) assert results == [] - def test_primary_only(self): + def test_primary_only(self) -> None: primary = [ get_row("p1", 0.8), get_row("p2", 0.6), @@ -105,34 +113,34 @@ def test_primary_only(self): rrf_k = 60 # p1_score = 1 / (0 + 60) # p2_score = 1 / (1 + 60) - results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) + results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore assert len(results) == 2 assert results[0]["id_val"] == "p1" assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) assert results[1]["id_val"] == "p2" assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) - def test_secondary_only(self): + def test_secondary_only(self) -> None: secondary = [ get_row("s1", 0.9), get_row("s2", 0.7), ] # s1 rank 0, s2 rank 1 rrf_k = 60 - results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) + results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore assert len(results) == 2 assert results[0]["id_val"] == "s1" assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) assert results[1]["id_val"] == "s2" assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) - def test_mixed_results_default_k(self): + def test_mixed_results_default_k(self) -> None: primary = [get_row("common", 0.8), get_row("p_only", 0.7)] secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] rrf_k = 60 # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k # p_only_score = (1/(1+k))_prim = 1/(k+1) # s_only_score = (1/(1+k))_sec = 1/(k+1) - results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore assert len(results) == 3 assert results[0]["id_val"] == "common" assert results[0]["distance"] == pytest.approx(2.0 / rrf_k) @@ -143,32 +151,31 @@ def test_mixed_results_default_k(self): for score in next_scores: assert score == pytest.approx(1.0 / (1 + rrf_k)) - def test_fetch_top_k_rrf(self): + def test_fetch_top_k_rrf(self) -> None: primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] - secondary = [] rrf_k = 1 - results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2) + results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore assert len(results) == 2 assert results[0]["id_val"] == "p0" assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) assert results[1]["id_val"] == "p1" assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) - def test_rrf_content_preservation(self): + def test_rrf_content_preservation(self) -> None: primary = [get_row("doc1", 0.9, content="Primary Content")] secondary = [get_row("doc1", 0.8, content="Secondary Content")] # RRF processes primary then secondary. If a doc is in both, # the content from the secondary list will overwrite primary's. - results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) + results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore assert len(results) == 1 assert results[0]["id_val"] == "doc1" assert results[0]["content_field"] == "Secondary Content" # If only in primary - results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) + results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore assert results_prim_only[0]["content_field"] == "Primary Content" - def test_reordering_from_inputs_rrf(self): + def test_reordering_from_inputs_rrf(self) -> None: """ Tests that RRF fused ranking can be different from both primary and secondary input rankings. @@ -190,7 +197,7 @@ def test_reordering_from_inputs_rrf(self): # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3 # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1 # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3 - results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) + results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore assert len(results) == 3 assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"} assert results[0]["distance"] == pytest.approx(4.0 / 3.0) @@ -198,7 +205,7 @@ def test_reordering_from_inputs_rrf(self): assert results[2]["id_val"] == "docB" assert results[2]["distance"] == pytest.approx(1.0) - def test_reordering_from_inputs_weighted_sum(self): + def test_reordering_from_inputs_weighted_sum(self) -> None: """ Tests that the fused ranking can be different from both primary and secondary input rankings. @@ -214,7 +221,7 @@ def test_reordering_from_inputs_weighted_sum(self): primary = [get_row("docA", 0.9), get_row("docB", 0.7)] secondary = [get_row("docB", 0.8), get_row("docA", 0.2)] - results = weighted_sum_ranking(primary, secondary) + results = weighted_sum_ranking(primary, secondary) # type: ignore assert len(results) == 2 assert results[0]["id_val"] == "docB" assert results[0]["distance"] == pytest.approx(0.75) From ef349a3f0c1281ffeba48332fe1c28a5caea2565 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 21:57:45 +0000 Subject: [PATCH 09/15] pr comment: add disclaimer on slow query on config docstring --- langchain_postgres/v2/hybrid_search_config.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/langchain_postgres/v2/hybrid_search_config.py b/langchain_postgres/v2/hybrid_search_config.py index 3e1a61b5..7f6c2778 100644 --- a/langchain_postgres/v2/hybrid_search_config.py +++ b/langchain_postgres/v2/hybrid_search_config.py @@ -128,7 +128,13 @@ def reciprocal_rank_fusion( @dataclass class HybridSearchConfig(ABC): - """AlloyDB Vector Store Hybrid Search Config.""" + """ + AlloyDB Vector Store Hybrid Search Config. + + Queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column + and adding GIN index. + """ tsv_column: Optional[str] = "" tsv_lang: Optional[str] = "pg_catalog.english" From ceabf10d05ac91a4e08d08cff54522f9fa8e1d7f Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 30 May 2025 22:00:53 +0000 Subject: [PATCH 10/15] pr comment: add disclaimer in engine table create --- langchain_postgres/v2/engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/langchain_postgres/v2/engine.py b/langchain_postgres/v2/engine.py index 755b5ff0..09e5076f 100644 --- a/langchain_postgres/v2/engine.py +++ b/langchain_postgres/v2/engine.py @@ -300,6 +300,8 @@ async def ainit_vectorstore_table( store_metadata (bool): Whether to store metadata in the table. Default: True. hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Note that queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column and adding GIN index. Default: None. """ await self._run_as_async( @@ -355,6 +357,8 @@ def init_vectorstore_table( store_metadata (bool): Whether to store metadata in the table. Default: True. hybrid_search_config (HybridSearchConfig): Hybrid search configuration. + Note that queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column and adding GIN index. Default: None. """ self._run_as_sync( From 8a39e61ab50f77196777f21c30ec83063401b199 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 2 Jun 2025 17:14:31 +0000 Subject: [PATCH 11/15] feat: address pr comments --- langchain_postgres/v2/async_vectorstore.py | 94 +++++----- .../v2/test_async_pg_vectorstore_index.py | 70 ++++--- .../v2/test_async_pg_vectorstore_search.py | 174 ++++++++++++------ 3 files changed, 203 insertions(+), 135 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index d1451dc1..264fd346 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -15,9 +15,14 @@ from .engine import PGEngine from .hybrid_search_config import HybridSearchConfig -from .indexes import (DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, - BaseIndex, DistanceStrategy, ExactNearestNeighbor, - QueryOptions) +from .indexes import ( + DEFAULT_DISTANCE_STRATEGY, + DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, + DistanceStrategy, + ExactNearestNeighbor, + QueryOptions, +) COMPARISONS_TO_NATIVE = { "$eq": "=", @@ -74,7 +79,6 @@ def __init__( lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, hybrid_search_config: Optional[HybridSearchConfig] = None, - hybrid_search_column_exists: bool = False, ): """AsyncPGVectorStore constructor. Args: @@ -94,7 +98,6 @@ def __init__( lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. - hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: @@ -120,7 +123,6 @@ def __init__( self.lambda_mult = lambda_mult self.index_query_options = index_query_options self.hybrid_search_config = hybrid_search_config - self.hybrid_search_column_exists = hybrid_search_column_exists @classmethod async def create( @@ -197,17 +199,15 @@ async def create( raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) - hybrid_search_column_exists = False if hybrid_search_config: tsv_column_name = ( hybrid_search_config.tsv_column if hybrid_search_config.tsv_column else content_column + "_tsv" ) - hybrid_search_config.tsv_column = tsv_column_name - hybrid_search_column_exists = ( - tsv_column_name in columns and columns[tsv_column_name] == "tsvector" - ) + if tsv_column_name not in columns or columns[tsv_column_name] != "tsvector": + # mark tsv_column as empty because there is no TSV column in table + hybrid_search_config.tsv_column = "" if embedding_column not in columns: raise ValueError(f"Embedding column, {embedding_column}, does not exist.") if columns[embedding_column] != "USER-DEFINED": @@ -252,7 +252,6 @@ async def create( lambda_mult=lambda_mult, index_query_options=index_query_options, hybrid_search_config=hybrid_search_config, - hybrid_search_column_exists=hybrid_search_column_exists, ) @property @@ -292,7 +291,7 @@ async def aadd_embeddings( ) hybrid_search_column = ( f', "{self.hybrid_search_config.tsv_column}"' - if self.hybrid_search_config and self.hybrid_search_column_exists + if self.hybrid_search_config and self.hybrid_search_config.tsv_column else "" ) insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}' @@ -306,7 +305,7 @@ async def aadd_embeddings( if not embedding and can_inline_embed: values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore - if self.hybrid_search_config and self.hybrid_search_column_exists: + if self.hybrid_search_config and self.hybrid_search_config.tsv_column: lang = ( f"'{self.hybrid_search_config.tsv_lang}'," if self.hybrid_search_config.tsv_lang @@ -338,7 +337,7 @@ async def aadd_embeddings( upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' - if self.hybrid_search_config and self.hybrid_search_column_exists: + if self.hybrid_search_config and self.hybrid_search_config.tsv_column: upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"' if self.metadata_json_column: @@ -464,7 +463,6 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. - hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -537,7 +535,6 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. - hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -610,8 +607,6 @@ async def __query_collection( filter_dict = None if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) - where_filters = f"WHERE {safe_filter}" if safe_filter else "" - and_filters = f"AND ({safe_filter})" if safe_filter else "" inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: @@ -620,6 +615,7 @@ async def __query_collection( else: query_embedding = f"{[float(dimension) for dimension in embedding]}" embedding_data_string = ":query_embedding" + where_filters = f"WHERE {safe_filter}" if safe_filter else "" dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; """ @@ -659,10 +655,11 @@ async def __query_collection( ) query_tsv = f"plainto_tsquery({lang} :fts_query)" param_dict["fts_query"] = fts_query - if self.hybrid_search_column_exists: + if hybrid_search_config.tsv_column: content_tsv = f'"{hybrid_search_config.tsv_column}"' else: content_tsv = f'to_tsvector({lang} "{self.content_column}")' + and_filters = f"AND ({safe_filter})" if safe_filter else "" sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};' async with self.engine.connect() as conn: result = await conn.execute(text(sparse_query_stmt), param_dict) @@ -884,6 +881,41 @@ async def amax_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] + async def aapply_hybrid_search_index( + self, + concurrently: bool = False, + ) -> None: + """Creates a TSV index in the vector store table if possible.""" + if ( + not self.hybrid_search_config + or not self.hybrid_search_config.index_type + or not self.hybrid_search_config.index_name + ): + # no index needs to be created + raise ValueError("Hybrid Search Config cannot create index.") + + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + tsv_column_name = ( + self.hybrid_search_config.tsv_column + if self.hybrid_search_config.tsv_column + else f"to_tsvector({lang} {self.content_column})" + ) + tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {self.hybrid_search_config.index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});' + if concurrently: + async with self.engine.connect() as conn: + autocommit_conn = await conn.execution_options( + isolation_level="AUTOCOMMIT" + ) + await autocommit_conn.execute(text(tsv_index_query)) + else: + async with self.engine.connect() as conn: + await conn.execute(text(tsv_index_query)) + await conn.commit() + async def aapply_vector_index( self, index: BaseIndex, @@ -913,37 +945,15 @@ async def aapply_vector_index( name = index.name stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' - if self.hybrid_search_config: - if self.hybrid_search_column_exists: - tsv_column_name = ( - self.hybrid_search_config.tsv_column - if self.hybrid_search_config.tsv_column - else self.content_column + "_tsv" - ) - tsv_column_name = f'"{tsv_column_name}"' - else: - lang = ( - f"'{self.hybrid_search_config.tsv_lang}'," - if self.hybrid_search_config.tsv_lang - else "" - ) - tsv_column_name = f"to_tsvector({lang} {self.content_column})" - tsv_index_name = self.table_name + self.hybrid_search_config.index_name - tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {tsv_index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});' - else: - tsv_index_query = "" - if concurrently: async with self.engine.connect() as conn: autocommit_conn = await conn.execution_options( isolation_level="AUTOCOMMIT" ) await autocommit_conn.execute(text(stmt)) - await conn.execute(text(tsv_index_query)) else: async with self.engine.connect() as conn: await conn.execute(text(stmt)) - await conn.execute(text(tsv_index_query)) await conn.commit() async def areindex(self, index_name: Optional[str] = None) -> None: diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py index 052aebe1..8585bcd0 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py @@ -11,8 +11,7 @@ from langchain_postgres import PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig -from langchain_postgres.v2.indexes import (DistanceStrategy, HNSWIndex, - IVFFlatIndex) +from langchain_postgres.v2.indexes import DistanceStrategy, HNSWIndex, IVFFlatIndex from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING uuid_str = str(uuid.uuid4()).replace("-", "_") @@ -55,9 +54,9 @@ async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - engine._adrop_table(DEFAULT_TABLE) - engine._adrop_table(DEFAULT_HYBRID_TABLE) - engine._adrop_table(SIMPLE_TABLE) + await engine._adrop_table(DEFAULT_TABLE) + await engine._adrop_table(DEFAULT_HYBRID_TABLE) + await engine._adrop_table(SIMPLE_TABLE) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -74,7 +73,9 @@ async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: yield vs async def test_apply_default_name_vector_index(self, engine: PGEngine) -> None: - await engine._ainit_vectorstore_table(SIMPLE_TABLE, VECTOR_SIZE) + await engine._ainit_vectorstore_table( + SIMPLE_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = await AsyncPGVectorStore.create( engine, embedding_service=embeddings_service, @@ -93,65 +94,58 @@ async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None: assert await vs.is_valid_index(DEFAULT_INDEX_NAME) await vs.adrop_vector_index(DEFAULT_INDEX_NAME) - async def test_aapply_vector_index_non_hybrid_search_vs_without_tsv_column( - self, vs + async def test_aapply_vector_index_non_hybrid_search_vs( + self, vs: AsyncPGVectorStore ) -> None: - index = HNSWIndex(name="test_index_hybrid" + uuid_str) - - tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" - is_valid_index = await vs.is_valid_index(tsv_index_name) - assert is_valid_index == False - await vs.aapply_vector_index(index) - is_valid_index = await vs.is_valid_index(tsv_index_name) - assert is_valid_index == False - await vs.adrop_vector_index(tsv_index_name) - is_valid_index = await vs.is_valid_index(tsv_index_name) - assert is_valid_index == False + with pytest.raises(ValueError): + await vs.aapply_hybrid_search_index() - async def test_aapply_vector_index_hybrid_search_vs_without_tsv_column( - self, engine, vs + async def test_aapply_hybrid_search_index_table_without_tsv_column( + self, engine: PGEngine, vs: AsyncPGVectorStore ) -> None: # overwriting vs to get a hybrid vs + tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str vs = await AsyncPGVectorStore.create( engine, embedding_service=embeddings_service, table_name=DEFAULT_TABLE, - hybrid_search_config=HybridSearchConfig(), + hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), ) - index = HNSWIndex(name="test_index_hybrid" + uuid_str) - - tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False - await vs.adrop_vector_index(index.name) - await vs.aapply_vector_index(index) - assert await vs.is_valid_index(tsv_index_name) - await vs.areindex(tsv_index_name) + await vs.aapply_hybrid_search_index() assert await vs.is_valid_index(tsv_index_name) await vs.adrop_vector_index(tsv_index_name) is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False - await vs.adrop_vector_index(index.name) - async def test_aapply_vector_index_hybrid_search_with_tsv_column( - self, engine + async def test_aapply_hybrid_search_index_table_with_tsv_column( + self, engine: PGEngine ) -> None: + tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str + config = HybridSearchConfig( + tsv_column="tsv_column", + tsv_lang="pg_catalog.english", + index_name=tsv_index_name, + ) await engine._ainit_vectorstore_table( - DEFAULT_HYBRID_TABLE, VECTOR_SIZE, hybrid_search_config=HybridSearchConfig() + DEFAULT_HYBRID_TABLE, + VECTOR_SIZE, + hybrid_search_config=config, ) vs = await AsyncPGVectorStore.create( engine, embedding_service=embeddings_service, table_name=DEFAULT_HYBRID_TABLE, - hybrid_search_config=HybridSearchConfig(), + hybrid_search_config=config, ) - tsv_index_name = DEFAULT_HYBRID_TABLE + "langchain_tsv_index" is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False - index = HNSWIndex(name=DEFAULT_INDEX_NAME) - await vs.aapply_vector_index(index) + await vs.aapply_hybrid_search_index() + assert await vs.is_valid_index(tsv_index_name) + await vs.areindex(tsv_index_name) + assert await vs.is_valid_index(tsv_index_name) await vs.adrop_vector_index(tsv_index_name) - await vs.adrop_vector_index(index.name) is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index eb65da24..c37be1a0 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -10,12 +10,16 @@ from langchain_postgres import Column, PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore -from langchain_postgres.v2.hybrid_search_config import (HybridSearchConfig, - reciprocal_rank_fusion, - weighted_sum_ranking) +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( - FILTERING_TEST_CASES, METADATAS) + FILTERING_TEST_CASES, + METADATAS, +) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") @@ -129,9 +133,11 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore] yield vs_custom @pytest_asyncio.fixture(scope="class") - async def vs_hybrid_search_without_tsv_column(self, engine): + async def vs_hybrid_search_without_tsv_column( + self, engine: PGEngine + ) -> AsyncIterator[AsyncPGVectorStore]: hybrid_search_config = HybridSearchConfig( - tsv_column="my_tsv_col", + tsv_column="", tsv_lang="pg_catalog.english", fts_query="my_fts_query", fusion_function=reciprocal_rank_fusion, @@ -170,37 +176,6 @@ async def vs_hybrid_search_without_tsv_column(self, engine): await vs_custom.aadd_documents(hybrid_docs) yield vs_custom - @pytest_asyncio.fixture(scope="class") - async def vs_hybrid_search_with_tsv_column(self, engine): - await engine._ainit_vectorstore_table( - HYBRID_SEARCH_TABLE2, - VECTOR_SIZE, - id_column=Column("myid", "TEXT"), - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[ - Column("page", "TEXT"), - Column("source", "TEXT"), - Column("doc_id_key", "TEXT"), - ], - store_metadata=False, - hybrid_search_config=HybridSearchConfig(), - ) - - vs_custom = await AsyncPGVectorStore.create( - engine, - embedding_service=embeddings_service, - table_name=HYBRID_SEARCH_TABLE2, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["doc_id_key"], - index_query_options=HNSWQueryOptions(ef_search=1), - hybrid_search_config=HybridSearchConfig(), - ) - await vs_custom.aadd_documents(hybrid_docs) - yield vs_custom - @pytest_asyncio.fixture(scope="class") async def vs_custom_filter( self, engine: PGEngine @@ -394,7 +369,7 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter - async def test_asimilarity_hybrid_search(self, vs): + async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None: results = await vs.asimilarity_search( "foo", k=1, hybrid_search_config=HybridSearchConfig() ) @@ -425,7 +400,7 @@ async def test_asimilarity_hybrid_search(self, vs): ) assert results == [Document(page_content="foo", id=ids[0])] - async def test_asimilarity_hybrid_search_rrk(self, vs): + async def test_asimilarity_hybrid_search_rrk(self, vs: AsyncPGVectorStore) -> None: results = await vs.asimilarity_search( "foo", k=1, @@ -453,8 +428,9 @@ async def test_asimilarity_hybrid_search_rrk(self, vs): assert results == [Document(page_content="bar", id=ids[1])] async def test_hybrid_search_weighted_sum_default( - self, vs_hybrid_search_with_tsv_column - ): + self, + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" query = "apple" # Should match "apple" in FTS and vector @@ -482,8 +458,8 @@ async def test_hybrid_search_weighted_sum_default( async def test_hybrid_search_weighted_sum_vector_bias( self, - vs_hybrid_search_without_tsv_column, - ): + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: """Test weighted sum with higher weight for vector results.""" query = "Apple Inc technology" # More specific for vector similarity @@ -495,7 +471,7 @@ async def test_hybrid_search_weighted_sum_vector_bias( }, # fts_query will default to main query ) - results = await vs_hybrid_search_without_tsv_column.asimilarity_search( + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( query, k=2, hybrid_search_config=config ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -505,8 +481,8 @@ async def test_hybrid_search_weighted_sum_vector_bias( async def test_hybrid_search_weighted_sum_fts_bias( self, - vs_hybrid_search_with_tsv_column, - ): + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: """Test weighted sum with higher weight for FTS results.""" query = "fruit common tasty" # Strong FTS signal for fruit docs @@ -528,8 +504,8 @@ async def test_hybrid_search_weighted_sum_fts_bias( async def test_hybrid_search_reciprocal_rank_fusion( self, - vs_hybrid_search_with_tsv_column, - ): + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: """Test hybrid search with Reciprocal Rank Fusion.""" query = "technology company" @@ -560,8 +536,8 @@ async def test_hybrid_search_reciprocal_rank_fusion( assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal async def test_hybrid_search_explicit_fts_query( - self, vs_hybrid_search_with_tsv_column - ): + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" main_vector_query = "Apple Inc." # For vector search fts_specific_query = "fruit" # For FTS @@ -590,7 +566,9 @@ async def test_hybrid_search_explicit_fts_query( or "hs_doc_orange_fruit" in result_ids ) - async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column): + async def test_hybrid_search_with_filter( + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: """Test hybrid search with a metadata filter applied.""" query = "apple" # Filter to only include "tech" related apple docs using metadata @@ -609,8 +587,8 @@ async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column) assert result_ids[0] == "hs_doc_apple_tech" async def test_hybrid_search_fts_empty_results( - self, vs_hybrid_search_with_tsv_column - ): + self, vs_hybrid_search_with_tsv_column: AsyncPGVectorStore + ) -> None: """Test when FTS query yields no results, should fall back to vector search.""" vector_query = "apple" no_match_fts_query = "zzyyxx_gibberish_term_for_fts_nomatch" @@ -636,8 +614,8 @@ async def test_hybrid_search_fts_empty_results( async def test_hybrid_search_vector_empty_results_effectively( self, - vs_hybrid_search_with_tsv_column, - ): + vs_hybrid_search_with_tsv_column: AsyncPGVectorStore, + ) -> None: """Test when vector query is very dissimilar to docs, should rely on FTS.""" # This is hard to guarantee with fake embeddings, but we try. # A better way might be to use a filter that excludes all docs for the vector part, @@ -661,3 +639,89 @@ async def test_hybrid_search_vector_empty_results_effectively( # Expect results based purely on FTS search for "orange fruit" assert len(result_ids) == 1 assert result_ids[0] == "hs_doc_generic_tech" + + async def test_hybrid_search_without_tsv_column( + self, + engine: PGEngine, + ) -> None: + """Test hybrid search without a TSV column.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. + vector_query_far_off = "apple iphone tech is better designed than macs" + fts_query_match = "apple fruit" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=config, + ) + + vs_with_tsv_column = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + await vs_with_tsv_column.aadd_documents(hybrid_docs) + + config = HybridSearchConfig( + tsv_column="", # no TSV column + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.9, + "secondary_results_weight": 0.1, + }, + ) + vs_without_tsv_column = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + + results_with_tsv_column = await vs_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + results_without_tsv_column = await vs_without_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids_with_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_with_tsv_column + ] + result_ids_without_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_without_tsv_column + ] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids_with_tsv_column) == 1 + assert len(result_ids_without_tsv_column) == 1 + assert result_ids_with_tsv_column[0] == "hs_doc_apple_tech" + assert result_ids_without_tsv_column[0] == "hs_doc_apple_tech" From 6854ee0fe08ffee827e593051750402c5db676b4 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 2 Jun 2025 17:24:24 +0000 Subject: [PATCH 12/15] fix: tsv column name in tests --- tests/unit_tests/v2/test_async_pg_vectorstore_search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index c37be1a0..16c70fdd 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -133,11 +133,11 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore] yield vs_custom @pytest_asyncio.fixture(scope="class") - async def vs_hybrid_search_without_tsv_column( + async def vs_hybrid_search_with_tsv_column( self, engine: PGEngine ) -> AsyncIterator[AsyncPGVectorStore]: hybrid_search_config = HybridSearchConfig( - tsv_column="", + tsv_column="mycontent_tsv", tsv_lang="pg_catalog.english", fts_query="my_fts_query", fusion_function=reciprocal_rank_fusion, @@ -159,6 +159,7 @@ async def vs_hybrid_search_without_tsv_column( ], metadata_json_column="mymetadata", # ignored store_metadata=False, + hybrid_search_config=hybrid_search_config, ) vs_custom = await AsyncPGVectorStore.create( From 5bf1a4bf9da0d7c025d08889f1e8a096f7f1695a Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 2 Jun 2025 17:24:40 +0000 Subject: [PATCH 13/15] fix: add if exists in drop to avoid failures --- langchain_postgres/v2/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain_postgres/v2/engine.py b/langchain_postgres/v2/engine.py index 09e5076f..6067ba23 100644 --- a/langchain_postgres/v2/engine.py +++ b/langchain_postgres/v2/engine.py @@ -384,7 +384,7 @@ async def _adrop_table( schema_name: str = "public", ) -> None: """Drop the vector store table""" - query = f'DROP TABLE "{schema_name}"."{table_name}";' + query = f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}";' async with self._pool.connect() as conn: await conn.execute(text(query)) await conn.commit() From e092c826133c1cdde1f9d02598796a8777952d91 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 2 Jun 2025 17:30:26 +0000 Subject: [PATCH 14/15] fix: tests --- tests/unit_tests/v2/test_pg_vectorstore_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index d783a856..53382074 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -460,4 +460,4 @@ def test_similarity_hybrid_search(self, vs_custom): fusion_function=reciprocal_rank_fusion ), ) - assert results == [Document(page_content="baz", id=ids[2])] + assert results == [Document(page_content="foo", id=ids[0])] From 0d223fd2858d986670cbb48a5fa0c588ab11bff2 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 3 Jun 2025 06:47:35 +0000 Subject: [PATCH 15/15] chore: fix lint --- tests/unit_tests/v2/test_pg_vectorstore_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 53382074..7815a25a 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -266,7 +266,7 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter - async def test_asimilarity_hybrid_search(self, vs: PGVectorStore): + async def test_asimilarity_hybrid_search(self, vs: PGVectorStore) -> None: results = await vs.asimilarity_search( "foo", k=1, hybrid_search_config=HybridSearchConfig() ) @@ -438,7 +438,7 @@ def test_metadata_filter_negative_tests( "meow", k=5, filter=test_filter ) - def test_similarity_hybrid_search(self, vs_custom): + def test_similarity_hybrid_search(self, vs_custom: PGVectorStore) -> None: results = vs_custom.similarity_search( "foo", k=1, hybrid_search_config=HybridSearchConfig() )