Skip to content

Commit 28d21e7

Browse files
feat: adds hybrid search for async VS interface
1 parent bb352ee commit 28d21e7

File tree

2 files changed

+477
-12
lines changed

2 files changed

+477
-12
lines changed

src/langchain_google_alloydb_pg/async_vectorstore.py

Lines changed: 122 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sqlalchemy.ext.asyncio import AsyncEngine
3333

3434
from .engine import AlloyDBEngine
35+
from .hybrid_search_config import HybridSearchConfig
3536
from .indexes import (
3637
DEFAULT_DISTANCE_STRATEGY,
3738
DEFAULT_INDEX_NAME_SUFFIX,
@@ -95,6 +96,8 @@ def __init__(
9596
fetch_k: int = 20,
9697
lambda_mult: float = 0.5,
9798
index_query_options: Optional[QueryOptions] = None,
99+
hybrid_search_config: Optional[HybridSearchConfig] = None,
100+
hybrid_search_column_exists: bool = False,
98101
):
99102
"""AsyncAlloyDBVectorStore constructor.
100103
Args:
@@ -113,6 +116,8 @@ def __init__(
113116
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
114117
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
115118
index_query_options (QueryOptions): Index query option.
119+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
120+
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column.
116121
117122
118123
Raises:
@@ -137,6 +142,8 @@ def __init__(
137142
self.fetch_k = fetch_k
138143
self.lambda_mult = lambda_mult
139144
self.index_query_options = index_query_options
145+
self.hybrid_search_config = hybrid_search_config
146+
self.hybrid_search_column_exists = hybrid_search_column_exists
140147

141148
@classmethod
142149
async def create(
@@ -156,6 +163,7 @@ async def create(
156163
fetch_k: int = 20,
157164
lambda_mult: float = 0.5,
158165
index_query_options: Optional[QueryOptions] = None,
166+
hybrid_search_config: Optional[HybridSearchConfig] = None,
159167
) -> AsyncAlloyDBVectorStore:
160168
"""Create an AsyncAlloyDBVectorStore instance.
161169
@@ -175,6 +183,8 @@ async def create(
175183
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
176184
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
177185
index_query_options (QueryOptions): Index query option.
186+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
187+
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column.
178188
179189
Returns:
180190
AsyncAlloyDBVectorStore
@@ -203,6 +213,15 @@ async def create(
203213
raise ValueError(
204214
f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
205215
)
216+
hybrid_search_column_exists = False
217+
if hybrid_search_config:
218+
hybrid_search_config.tsv_column = (
219+
hybrid_search_config.tsv_column
220+
if hybrid_search_config.tsv_column
221+
else content_column + "_tsv"
222+
)
223+
hybrid_search_column_exists = hybrid_search_config.tsv_column in columns
224+
206225
if embedding_column not in columns:
207226
raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
208227
if columns[embedding_column] != "USER-DEFINED":
@@ -246,6 +265,8 @@ async def create(
246265
fetch_k=fetch_k,
247266
lambda_mult=lambda_mult,
248267
index_query_options=index_query_options,
268+
hybrid_search_config=hybrid_search_config,
269+
hybrid_search_column_exists=hybrid_search_column_exists,
249270
)
250271

251272
@property
@@ -279,7 +300,12 @@ async def aadd_embeddings(
279300
if len(self.metadata_columns) > 0
280301
else ""
281302
)
282-
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
303+
hybrid_search_column = (
304+
f', "{self.hybrid_search_config.tsv_column}"'
305+
if self.hybrid_search_config and self.hybrid_search_column_exists
306+
else ""
307+
)
308+
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}'
283309
values = {
284310
"id": id,
285311
"content": content,
@@ -292,6 +318,10 @@ async def aadd_embeddings(
292318
if not embedding and callable(inline_embed_func):
293319
values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore
294320

321+
if self.hybrid_search_config and self.hybrid_search_column_exists:
322+
values_stmt += f", to_tsvector('{self.hybrid_search_config.tsv_lang}', :tsv_content)"
323+
values["tsv_content"] = content
324+
295325
# Add metadata
296326
extra = copy.deepcopy(metadata)
297327
for metadata_column in self.metadata_columns:
@@ -316,6 +346,9 @@ async def aadd_embeddings(
316346

317347
upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"'
318348

349+
if self.hybrid_search_config and self.hybrid_search_column_exists:
350+
upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"'
351+
319352
if self.metadata_json_column:
320353
upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"'
321354

@@ -464,6 +497,7 @@ async def afrom_texts( # type: ignore[override]
464497
fetch_k: int = 20,
465498
lambda_mult: float = 0.5,
466499
index_query_options: Optional[QueryOptions] = None,
500+
hybrid_search_config: Optional[HybridSearchConfig] = None,
467501
**kwargs: Any,
468502
) -> AsyncAlloyDBVectorStore:
469503
"""Create an AsyncAlloyDBVectorStore instance from texts.
@@ -486,6 +520,7 @@ async def afrom_texts( # type: ignore[override]
486520
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
487521
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
488522
index_query_options (QueryOptions): Index query option.
523+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
489524
490525
Raises:
491526
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -509,6 +544,7 @@ async def afrom_texts( # type: ignore[override]
509544
fetch_k=fetch_k,
510545
lambda_mult=lambda_mult,
511546
index_query_options=index_query_options,
547+
hybrid_search_config=hybrid_search_config,
512548
)
513549
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
514550
return vs
@@ -533,6 +569,7 @@ async def afrom_documents( # type: ignore[override]
533569
fetch_k: int = 20,
534570
lambda_mult: float = 0.5,
535571
index_query_options: Optional[QueryOptions] = None,
572+
hybrid_search_config: Optional[HybridSearchConfig] = None,
536573
**kwargs: Any,
537574
) -> AsyncAlloyDBVectorStore:
538575
"""Create an AsyncAlloyDBVectorStore instance from documents.
@@ -555,6 +592,7 @@ async def afrom_documents( # type: ignore[override]
555592
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
556593
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
557594
index_query_options (QueryOptions): Index query option.
595+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
558596
559597
Raises:
560598
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -579,6 +617,7 @@ async def afrom_documents( # type: ignore[override]
579617
fetch_k=fetch_k,
580618
lambda_mult=lambda_mult,
581619
index_query_options=index_query_options,
620+
hybrid_search_config=hybrid_search_config,
582621
)
583622
texts = [doc.page_content for doc in documents]
584623
metadatas = [doc.metadata for doc in documents]
@@ -592,45 +631,93 @@ async def __query_collection(
592631
filter: Optional[dict] | Optional[str] = None,
593632
**kwargs: Any,
594633
) -> Sequence[RowMapping]:
595-
"""Perform similarity search query on database."""
596-
k = k if k else self.k
634+
"""
635+
Perform similarity search (and hybrid search if provided) query on database.
636+
If the hybrid search column does not exist, then the queries might be very slow.
637+
Consider creating the TSV column and adding GIN index, if hybrid search is required.
638+
"""
639+
if not k:
640+
k = (
641+
max(
642+
self.k,
643+
self.hybrid_search_config.primary_top_k,
644+
self.hybrid_search_config.secondary_top_k,
645+
)
646+
if self.hybrid_search_config
647+
else self.k
648+
)
597649
operator = self.distance_strategy.operator
598650
search_function = self.distance_strategy.search_function
599651

600-
columns = self.metadata_columns + [
652+
columns = [
601653
self.id_column,
602654
self.content_column,
603655
self.embedding_column,
604-
]
656+
] + self.metadata_columns
605657
if self.metadata_json_column:
606658
columns.append(self.metadata_json_column)
607659

608660
column_names = ", ".join(f'"{col}"' for col in columns)
609661

610662
if filter and isinstance(filter, dict):
611663
filter = self._create_filter_clause(filter)
612-
filter = f"WHERE {filter}" if filter else ""
664+
where_filters = f"WHERE {filter}" if filter else ""
665+
and_filters = f"AND ({filter})" if filter else ""
666+
667+
# dense query
613668
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
614669
if not embedding and callable(inline_embed_func) and "query" in kwargs:
615670
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
616671
else:
617672
query_embedding = f"'{[float(dimension) for dimension in embedding]}'"
618-
stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {query_embedding}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {query_embedding} LIMIT {k};'
673+
dense_query_stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {query_embedding}) as distance FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY {self.embedding_column} {operator} {query_embedding} LIMIT {k};'
619674
if self.index_query_options:
620675
async with self.engine.connect() as conn:
621676
# Set each query option individually
622677
for query_option in self.index_query_options.to_parameter():
623678
query_options_stmt = f"SET LOCAL {query_option};"
624679
await conn.execute(text(query_options_stmt))
625-
result = await conn.execute(text(stmt))
680+
result = await conn.execute(text(dense_query_stmt))
626681
result_map = result.mappings()
627-
results = result_map.fetchall()
682+
dense_results = result_map.fetchall()
628683
else:
629684
async with self.engine.connect() as conn:
630-
result = await conn.execute(text(stmt))
685+
result = await conn.execute(text(dense_query_stmt))
631686
result_map = result.mappings()
632-
results = result_map.fetchall()
633-
return results
687+
dense_results = result_map.fetchall()
688+
689+
hybrid_search_config = kwargs.get(
690+
"hybrid_search_config", self.hybrid_search_config
691+
)
692+
fts_query = hybrid_search_config.fts_query if hybrid_search_config else ""
693+
if hybrid_search_config and fts_query:
694+
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k
695+
# do the sparse query
696+
lang = (
697+
f"'{hybrid_search_config.tsv_lang}',"
698+
if hybrid_search_config.tsv_lang
699+
else ""
700+
)
701+
query_tsv = f"plainto_tsquery({lang} :fts_query)"
702+
values = {"fts_query": fts_query}
703+
if self.hybrid_search_column_exists:
704+
content_tsv = f'"{hybrid_search_config.tsv_column}"'
705+
else:
706+
content_tsv = f'to_tsvector({lang} "{self.content_column}")'
707+
sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};'
708+
async with self.engine.connect() as conn:
709+
result = await conn.execute(text(sparse_query_stmt), values)
710+
result_map = result.mappings()
711+
sparse_results = result_map.fetchall()
712+
713+
combined_results = hybrid_search_config.fusion_function(
714+
dense_results,
715+
sparse_results,
716+
**hybrid_search_config.fusion_function_parameters,
717+
)
718+
return combined_results
719+
720+
return dense_results
634721

635722
async def asimilarity_search(
636723
self,
@@ -648,6 +735,14 @@ async def asimilarity_search(
648735
)
649736
kwargs["query"] = query
650737

738+
# add fts_query to hybrid_search_config
739+
hybrid_search_config = kwargs.get(
740+
"hybrid_search_config", self.hybrid_search_config
741+
)
742+
if hybrid_search_config and not hybrid_search_config.fts_query:
743+
hybrid_search_config.fts_query = query
744+
kwargs["hybrid_search_config"] = hybrid_search_config
745+
651746
return await self.asimilarity_search_by_vector(
652747
embedding=embedding, k=k, filter=filter, **kwargs
653748
)
@@ -715,6 +810,14 @@ async def asimilarity_search_with_score(
715810
)
716811
kwargs["query"] = query
717812

813+
# add fts_query to hybrid_search_config
814+
hybrid_search_config = kwargs.get(
815+
"hybrid_search_config", self.hybrid_search_config
816+
)
817+
if hybrid_search_config and not hybrid_search_config.fts_query:
818+
hybrid_search_config.fts_query = query
819+
kwargs["hybrid_search_config"] = hybrid_search_config
820+
718821
docs = await self.asimilarity_search_with_score_by_vector(
719822
embedding=embedding, k=k, filter=filter, **kwargs
720823
)
@@ -898,13 +1001,20 @@ async def aapply_vector_index(
8981001
index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
8991002
name = index.name
9001003
stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON \"{self.schema_name}\".\"{self.table_name}\" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};"
1004+
tsv_index_query = (
1005+
f'CREATE INDEX IF NOT EXISTS {self.hybrid_search_config.index_name} ON table_name USING {self.hybrid_search_config.index_type}("{self.content_column}");'
1006+
if self.hybrid_search_config
1007+
else ""
1008+
)
9011009
if concurrently:
9021010
async with self.engine.connect() as conn:
9031011
await conn.execute(text("COMMIT"))
9041012
await conn.execute(text(stmt))
1013+
await conn.execute(text(tsv_index_query))
9051014
else:
9061015
async with self.engine.connect() as conn:
9071016
await conn.execute(text(stmt))
1017+
await conn.execute(text(tsv_index_query))
9081018
await conn.commit()
9091019

9101020
async def areindex(self, index_name: Optional[str] = None) -> None:

0 commit comments

Comments
 (0)