Skip to content

feat: adds hybrid search for sync VS interface [4/N] #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/langchain_google_alloydb_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .async_vectorstore import AsyncAlloyDBVectorStore
from .engine import AlloyDBEngine
from .hybrid_search_config import HybridSearchConfig
from .indexes import (
DEFAULT_DISTANCE_STRATEGY,
BaseIndex,
Expand Down Expand Up @@ -73,6 +74,7 @@ async def create(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
) -> AlloyDBVectorStore:
"""Create an AlloyDBVectorStore instance.

Expand All @@ -92,6 +94,7 @@ async def create(
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Returns:
AlloyDBVectorStore
Expand All @@ -112,6 +115,7 @@ async def create(
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
vs = await engine._run_as_async(coro)
return cls(cls.__create_key, engine, vs)
Expand All @@ -134,6 +138,7 @@ def create_sync(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
) -> AlloyDBVectorStore:
"""Create an AlloyDBVectorStore instance.

Expand All @@ -154,6 +159,7 @@ def create_sync(
fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20.
lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None.
hybrid_search_config (Optional[HybridSearchConfig], optional): Hybrid search configuration. Defaults to None.

Returns:
AlloyDBVectorStore
Expand All @@ -174,6 +180,7 @@ def create_sync(
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
vs = engine._run_as_sync(coro)
return cls(cls.__create_key, engine, vs)
Expand Down
59 changes: 59 additions & 0 deletions tests/test_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlalchemy import RowMapping, Sequence, text

from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column
from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig
from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
Expand Down Expand Up @@ -352,6 +353,35 @@ async def test_aget_by_ids_custom_vs(self, vs_custom):

assert results[0] == Document(page_content="foo", id=ids[0])

async def test_asimilarity_hybrid_search(self, vs):
    """Exercise ``asimilarity_search`` with a ``HybridSearchConfig`` supplied.

    Three queries for "foo" with ``k=1`` are issued against ``vs``:
    a default config without a filter, a default config with a content
    filter, and a config carrying explicit fusion-function parameters
    together with a different content filter.
    """
    # Default hybrid config, no filter: a single hit, the "foo" document.
    unfiltered = await vs.asimilarity_search(
        "foo", k=1, hybrid_search_config=HybridSearchConfig()
    )
    assert len(unfiltered) == 1
    assert unfiltered == [Document(page_content="foo", id=ids[0])]

    # Default hybrid config plus a filter on the content column: the
    # "bar" document is expected even though the query text is "foo".
    bar_hits = await vs.asimilarity_search(
        "foo",
        k=1,
        filter="content = 'bar'",
        hybrid_search_config=HybridSearchConfig(),
    )
    assert bar_hits == [Document(page_content="bar", id=ids[1])]

    # Explicit fusion-function parameters combined with a filter on "baz".
    fusion_params = {
        "primary_results_weight": 0.1,
        "secondary_results_weight": 0.9,
        "fetch_top_k": 10,
    }
    weighted_config = HybridSearchConfig(
        fusion_function_parameters=fusion_params,
    )
    baz_hits = await vs.asimilarity_search(
        "foo",
        k=1,
        filter="content = 'baz'",
        hybrid_search_config=weighted_config,
    )
    assert baz_hits == [Document(page_content="baz", id=ids[2])]

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
async def test_vectorstore_with_metadata_filters(
self,
Expand Down Expand Up @@ -549,6 +579,35 @@ def test_max_marginal_relevance_search_vector_score(self, vs_custom):
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

def test_similarity_hybrid_search(self, vs_custom):
    """Exercise sync ``similarity_search`` with a ``HybridSearchConfig``.

    Three queries for "foo" with ``k=1`` are issued against ``vs_custom``:
    a default config without a filter, a default config with a filter on
    ``mycontent``, and a config carrying explicit fusion-function
    parameters together with a different ``mycontent`` filter.
    """
    # Default hybrid config, no filter: a single hit, the "foo" document.
    unfiltered = vs_custom.similarity_search(
        "foo", k=1, hybrid_search_config=HybridSearchConfig()
    )
    assert len(unfiltered) == 1
    assert unfiltered == [Document(page_content="foo", id=ids[0])]

    # Default hybrid config plus a filter on the custom content column:
    # the "bar" document is expected even though the query text is "foo".
    bar_hits = vs_custom.similarity_search(
        "foo",
        k=1,
        filter="mycontent = 'bar'",
        hybrid_search_config=HybridSearchConfig(),
    )
    assert bar_hits == [Document(page_content="bar", id=ids[1])]

    # Explicit fusion-function parameters combined with a filter on "baz".
    fusion_params = {
        "primary_results_weight": 0.1,
        "secondary_results_weight": 0.9,
        "fetch_top_k": 10,
    }
    weighted_config = HybridSearchConfig(
        fusion_function_parameters=fusion_params,
    )
    baz_hits = vs_custom.similarity_search(
        "foo",
        k=1,
        filter="mycontent = 'baz'",
        hybrid_search_config=weighted_config,
    )
    assert baz_hits == [Document(page_content="baz", id=ids[2])]

def test_get_by_ids_custom_vs(self, vs_custom):
test_ids = [ids[0]]
results = vs_custom.get_by_ids(ids=test_ids)
Expand Down