Skip to content

fix: remove string filters and parameterize filters #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 62 additions & 42 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import copy
import json
import re
import uuid
from typing import Any, Callable, Iterable, Optional, Sequence

Expand Down Expand Up @@ -175,7 +174,8 @@ async def create(
stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name"
async with engine._pool.connect() as conn:
result = await conn.execute(
text(stmt), {"table_name": table_name, "schema_name": schema_name}
text(stmt),
{"table_name": table_name, "schema_name": schema_name},
)
result_map = result.mappings()
results = result_map.fetchall()
Expand Down Expand Up @@ -535,7 +535,7 @@ async def __query_collection(
embedding: list[float],
*,
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> Sequence[RowMapping]:
"""Perform similarity search query on database."""
Expand All @@ -553,16 +553,22 @@ async def __query_collection(

column_names = ", ".join(f'"{col}"' for col in columns)

safe_filter = None
filter_dict = None
if filter and isinstance(filter, dict):
filter = self._create_filter_clause(filter)
filter = f"WHERE {filter}" if filter else ""
safe_filter, filter_dict = self._create_filter_clause(filter)
param_filter = f"WHERE {safe_filter}" if safe_filter else ""
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
if not embedding and callable(inline_embed_func) and "query" in kwargs:
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
else:
query_embedding = f"{[float(dimension) for dimension in embedding]}"
stmt = f'SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;'
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;
"""
param_dict = {"query_embedding": query_embedding, "k": k}
if filter_dict:
param_dict.update(filter_dict)
if self.index_query_options:
async with self.engine.connect() as conn:
# Set each query option individually
Expand All @@ -583,7 +589,7 @@ async def asimilarity_search(
self,
query: str,
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected by similarity search on query."""
Expand Down Expand Up @@ -614,7 +620,7 @@ async def asimilarity_search_with_score(
self,
query: str,
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected by similarity search on query."""
Expand All @@ -635,7 +641,7 @@ async def asimilarity_search_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected by vector similarity search."""
Expand All @@ -649,7 +655,7 @@ async def asimilarity_search_with_score_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected by vector similarity search."""
Expand Down Expand Up @@ -685,7 +691,7 @@ async def amax_marginal_relevance_search(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance."""
Expand All @@ -706,7 +712,7 @@ async def amax_marginal_relevance_search_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance."""
Expand All @@ -729,7 +735,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected using the maximal marginal relevance."""
Expand Down Expand Up @@ -834,7 +840,7 @@ async def is_valid_index(
) -> bool:
"""Check if index exists in the table."""
index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX
query = f"""
query = """
SELECT tablename, indexname
FROM pg_indexes
WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name;
Expand Down Expand Up @@ -898,7 +904,7 @@ def _handle_field_filter(
*,
field: str,
value: Any,
) -> str:
) -> tuple[str, dict]:
"""Create a filter for a specific field.

Args:
Expand Down Expand Up @@ -951,15 +957,17 @@ def _handle_field_filter(
if operator in COMPARISONS_TO_NATIVE:
# Then we implement an equality filter
# native is trusted input
if isinstance(filter_value, str):
filter_value = f"'{filter_value}'"
native = COMPARISONS_TO_NATIVE[operator]
return f"({field} {native} {filter_value})"
id = str(uuid.uuid4()).split("-")[0]
return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
elif operator == "$between":
# Inclusive range check: field BETWEEN low AND high
low, high = filter_value

return f"({field} BETWEEN {low} AND {high})"
return f"({field} BETWEEN :{field}_low AND :{field}_high)", {
f"{field}_low": low,
f"{field}_high": high,
}
elif operator in {"$in", "$nin", "$like", "$ilike"}:
# We'll force coercion to text
if operator in {"$in", "$nin"}:
Expand All @@ -975,15 +983,15 @@ def _handle_field_filter(
)

if operator in {"$in"}:
values = str(tuple(val for val in filter_value))
return f"({field} IN {values})"
return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
elif operator in {"$nin"}:
values = str(tuple(val for val in filter_value))
return f"({field} NOT IN {values})"
return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value}
elif operator in {"$like"}:
return f"({field} LIKE '{filter_value}')"
return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value}
elif operator in {"$ilike"}:
return f"({field} ILIKE '{filter_value}')"
return f"({field} ILIKE :{field}_ilike)", {
f"{field}_ilike": filter_value
}
else:
raise NotImplementedError()
elif operator == "$exists":
Expand All @@ -994,13 +1002,13 @@ def _handle_field_filter(
)
else:
if filter_value:
return f"({field} IS NOT NULL)"
return f"({field} IS NOT NULL)", {}
else:
return f"({field} IS NULL)"
return f"({field} IS NULL)", {}
else:
raise NotImplementedError()

def _create_filter_clause(self, filters: Any) -> str:
def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
"""Create LangChain filter representation to matching SQL where clauses

Args:
Expand Down Expand Up @@ -1037,7 +1045,11 @@ def _create_filter_clause(self, filters: Any) -> str:
op = key[1:].upper() # Extract the operator
filter_clause = [self._create_filter_clause(el) for el in value]
if len(filter_clause) > 1:
return f"({f' {op} '.join(filter_clause)})"
all_clauses = [clause[0] for clause in filter_clause]
params = {}
for clause in filter_clause:
params.update(clause[1])
return f"({f' {op} '.join(all_clauses)})", params
elif len(filter_clause) == 1:
return filter_clause[0]
else:
Expand All @@ -1050,11 +1062,15 @@ def _create_filter_clause(self, filters: Any) -> str:
not_conditions = [
self._create_filter_clause(item) for item in value
]
not_stmts = [f"NOT {condition}" for condition in not_conditions]
return f"({' AND '.join(not_stmts)})"
all_clauses = [clause[0] for clause in not_conditions]
params = {}
for clause in not_conditions:
params.update(clause[1])
not_stmts = [f"NOT {condition}" for condition in all_clauses]
return f"({' AND '.join(not_stmts)})", params
elif isinstance(value, dict):
not_ = self._create_filter_clause(value)
return f"(NOT {not_})"
not_, params = self._create_filter_clause(value)
return f"(NOT {not_})", params
else:
raise ValueError(
f"Invalid filter condition. Expected a dictionary "
Expand All @@ -1077,7 +1093,11 @@ def _create_filter_clause(self, filters: Any) -> str:
self._handle_field_filter(field=k, value=v) for k, v in filters.items()
]
if len(and_) > 1:
return f"({' AND '.join(and_)})"
all_clauses = [clause[0] for clause in and_]
params = {}
for clause in and_:
params.update(clause[1])
return f"({' AND '.join(all_clauses)})", params
elif len(and_) == 1:
return and_[0]
else:
Expand All @@ -1086,7 +1106,7 @@ def _create_filter_clause(self, filters: Any) -> str:
"but got an empty dictionary"
)
else:
return ""
return "", {}

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
raise NotImplementedError(
Expand Down Expand Up @@ -1168,7 +1188,7 @@ def similarity_search(
self,
query: str,
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -1179,7 +1199,7 @@ def similarity_search_with_score(
self,
query: str,
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand All @@ -1190,7 +1210,7 @@ def similarity_search_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -1201,7 +1221,7 @@ def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand All @@ -1214,7 +1234,7 @@ def max_marginal_relevance_search(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -1227,7 +1247,7 @@ def max_marginal_relevance_search_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -1240,7 +1260,7 @@ def max_marginal_relevance_search_with_score_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[dict] | Optional[str] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand Down
Loading