Skip to content

Commit 637d80e

Browse files
committed
lint
1 parent 9be659b commit 637d80e

File tree

3 files changed

+51
-144
lines changed

3 files changed

+51
-144
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 22 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,7 @@ def __init__(
111111
self.schema_name = schema_name
112112
self.content_column = content_column
113113
self.embedding_column = embedding_column
114-
self.metadata_columns = (
115-
metadata_columns if metadata_columns is not None else []
116-
)
114+
self.metadata_columns = metadata_columns if metadata_columns is not None else []
117115
self.id_column = id_column
118116
self.metadata_json_column = metadata_json_column
119117
self.distance_strategy = distance_strategy
@@ -189,27 +187,21 @@ async def create(
189187
if id_column not in columns:
190188
raise ValueError(f"Id column, {id_column}, does not exist.")
191189
if content_column not in columns:
192-
raise ValueError(
193-
f"Content column, {content_column}, does not exist."
194-
)
190+
raise ValueError(f"Content column, {content_column}, does not exist.")
195191
content_type = columns[content_column]
196192
if content_type != "text" and "char" not in content_type:
197193
raise ValueError(
198194
f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
199195
)
200196
if embedding_column not in columns:
201-
raise ValueError(
202-
f"Embedding column, {embedding_column}, does not exist."
203-
)
197+
raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
204198
if columns[embedding_column] != "USER-DEFINED":
205199
raise ValueError(
206200
f"Embedding column, {embedding_column}, is not type Vector."
207201
)
208202

209203
metadata_json_column = (
210-
None
211-
if metadata_json_column not in columns
212-
else metadata_json_column
204+
None if metadata_json_column not in columns else metadata_json_column
213205
)
214206

215207
# If using metadata_columns check to make sure column exists
@@ -272,14 +264,10 @@ async def aadd_embeddings(
272264
metadatas = [{} for _ in texts]
273265

274266
# Check for inline embedding capability
275-
inline_embed_func = getattr(
276-
self.embedding_service, "embed_query_inline", None
277-
)
267+
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
278268
can_inline_embed = callable(inline_embed_func)
279269
# Insert embeddings
280-
for id, content, embedding, metadata in zip(
281-
ids, texts, embeddings, metadatas
282-
):
270+
for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas):
283271
metadata_col_names = (
284272
", " + ", ".join(f'"{col}"' for col in self.metadata_columns)
285273
if len(self.metadata_columns) > 0
@@ -348,15 +336,11 @@ async def aadd_texts(
348336
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
349337
"""
350338
# Check for inline embedding query
351-
inline_embed_func = getattr(
352-
self.embedding_service, "embed_query_inline", None
353-
)
339+
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
354340
if callable(inline_embed_func):
355341
embeddings: list[list[float]] = [[] for _ in list(texts)]
356342
else:
357-
embeddings = await self.embedding_service.aembed_documents(
358-
list(texts)
359-
)
343+
embeddings = await self.embedding_service.aembed_documents(list(texts))
360344

361345
ids = await self.aadd_embeddings(
362346
texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
@@ -378,9 +362,7 @@ async def aadd_documents(
378362
metadatas = [doc.metadata for doc in documents]
379363
if not ids:
380364
ids = [doc.id for doc in documents]
381-
ids = await self.aadd_texts(
382-
texts, metadatas=metadatas, ids=ids, **kwargs
383-
)
365+
ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
384366
return ids
385367

386368
async def adelete(
@@ -576,9 +558,7 @@ async def __query_collection(
576558
if filter and isinstance(filter, dict):
577559
safe_filter, filter_dict = self._create_filter_clause(filter)
578560
filter = f"WHERE {safe_filter}" if safe_filter else ""
579-
inline_embed_func = getattr(
580-
self.embedding_service, "embed_query_inline", None
581-
)
561+
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
582562
if not embedding and callable(inline_embed_func) and "query" in kwargs:
583563
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
584564
else:
@@ -613,9 +593,7 @@ async def asimilarity_search(
613593
**kwargs: Any,
614594
) -> list[Document]:
615595
"""Return docs selected by similarity search on query."""
616-
inline_embed_func = getattr(
617-
self.embedding_service, "embed_query_inline", None
618-
)
596+
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
619597
embedding = (
620598
[]
621599
if callable(inline_embed_func)
@@ -646,9 +624,7 @@ async def asimilarity_search_with_score(
646624
**kwargs: Any,
647625
) -> list[tuple[Document, float]]:
648626
"""Return docs and distance scores selected by similarity search on query."""
649-
inline_embed_func = getattr(
650-
self.embedding_service, "embed_query_inline", None
651-
)
627+
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
652628
embedding = (
653629
[]
654630
if callable(inline_embed_func)
@@ -770,9 +746,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
770746
k = k if k else self.k
771747
fetch_k = fetch_k if fetch_k else self.fetch_k
772748
lambda_mult = lambda_mult if lambda_mult else self.lambda_mult
773-
embedding_list = [
774-
json.loads(row[self.embedding_column]) for row in results
775-
]
749+
embedding_list = [json.loads(row[self.embedding_column]) for row in results]
776750
mmr_selected = utils.maximal_marginal_relevance(
777751
np.array(embedding, dtype=np.float32),
778752
embedding_list,
@@ -800,9 +774,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
800774
)
801775
)
802776

803-
return [
804-
r for i, r in enumerate(documents_with_scores) if i in mmr_selected
805-
]
777+
return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected]
806778

807779
async def aapply_vector_index(
808780
self,
@@ -820,16 +792,12 @@ async def aapply_vector_index(
820792
if index.extension_name:
821793
async with self.engine.connect() as conn:
822794
await conn.execute(
823-
text(
824-
f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}"
825-
)
795+
text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}")
826796
)
827797
await conn.commit()
828798
function = index.get_index_function()
829799

830-
filter = (
831-
f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
832-
)
800+
filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
833801
params = "WITH " + index.index_options()
834802
if name is None:
835803
if index.name == None:
@@ -993,9 +961,7 @@ def _handle_field_filter(
993961
# filter_value = f"'{filter_value}'"
994962
native = COMPARISONS_TO_NATIVE[operator]
995963
id = str(uuid.uuid4()).split("-")[0]
996-
return f"{field} {native} :{field}_{id}", {
997-
f"{field}_{id}": filter_value
998-
}
964+
return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
999965
elif operator == "$between":
1000966
# Use AND with two comparisons
1001967
low, high = filter_value
@@ -1019,17 +985,11 @@ def _handle_field_filter(
1019985
)
1020986

1021987
if operator in {"$in"}:
1022-
return f"{field} = ANY(:{field}_in)", {
1023-
f"{field}_in": filter_value
1024-
}
988+
return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
1025989
elif operator in {"$nin"}:
1026-
return f"{field} <> ALL (:{field}_nin)", {
1027-
f"{field}_nin": filter_value
1028-
}
990+
return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value}
1029991
elif operator in {"$like"}:
1030-
return f"({field} LIKE :{field}_like)", {
1031-
f"{field}_like": filter_value
1032-
}
992+
return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value}
1033993
elif operator in {"$ilike"}:
1034994
return f"({field} ILIKE :{field}_ilike)", {
1035995
f"{field}_ilike": filter_value
@@ -1108,9 +1068,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
11081068
params = {}
11091069
for clause in not_conditions:
11101070
params.update(clause[1])
1111-
not_stmts = [
1112-
f"NOT {condition}" for condition in all_clauses
1113-
]
1071+
not_stmts = [f"NOT {condition}" for condition in all_clauses]
11141072
return f"({' AND '.join(not_stmts)})", params
11151073
elif isinstance(value, dict):
11161074
not_, params = self._create_filter_clause(value)
@@ -1134,8 +1092,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
11341092
)
11351093
# These should all be fields and combined using an $and operator
11361094
and_ = [
1137-
self._handle_field_filter(field=k, value=v)
1138-
for k, v in filters.items()
1095+
self._handle_field_filter(field=k, value=v) for k, v in filters.items()
11391096
]
11401097
if len(and_) > 1:
11411098
all_clauses = [clause[0] for clause in and_]

tests/unit_tests/v2/test_async_pg_vectorstore_search.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@
3333
ids = [str(uuid.uuid4()) for i in range(len(texts))]
3434
metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
3535
docs = [
36-
Document(page_content=texts[i], metadata=metadatas[i])
37-
for i in range(len(texts))
36+
Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
3837
]
3938

4039
embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]
4140

4241
filter_docs = [
43-
Document(page_content=texts[i], metadata=METADATAS[i])
44-
for i in range(len(texts))
42+
Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts))
4543
]
4644

4745

@@ -87,9 +85,7 @@ async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
8785
yield vs
8886

8987
@pytest_asyncio.fixture(scope="class")
90-
async def vs_custom(
91-
self, engine: PGEngine
92-
) -> AsyncIterator[AsyncPGVectorStore]:
88+
async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
9389
await engine._ainit_vectorstore_table(
9490
CUSTOM_TABLE,
9591
VECTOR_SIZE,
@@ -153,24 +149,18 @@ async def vs_custom_filter(
153149
await vs_custom_filter.aadd_documents(filter_docs, ids=ids)
154150
yield vs_custom_filter
155151

156-
async def test_asimilarity_search_score(
157-
self, vs: AsyncPGVectorStore
158-
) -> None:
152+
async def test_asimilarity_search_score(self, vs: AsyncPGVectorStore) -> None:
159153
results = await vs.asimilarity_search_with_score("foo")
160154
assert len(results) == 4
161155
assert results[0][0] == Document(page_content="foo", id=ids[0])
162156
assert results[0][1] == 0
163157

164-
async def test_asimilarity_search_by_vector(
165-
self, vs: AsyncPGVectorStore
166-
) -> None:
158+
async def test_asimilarity_search_by_vector(self, vs: AsyncPGVectorStore) -> None:
167159
embedding = embeddings_service.embed_query("foo")
168160
results = await vs.asimilarity_search_by_vector(embedding)
169161
assert len(results) == 4
170162
assert results[0] == Document(page_content="foo", id=ids[0])
171-
result = await vs.asimilarity_search_with_score_by_vector(
172-
embedding=embedding
173-
)
163+
result = await vs.asimilarity_search_with_score_by_vector(embedding=embedding)
174164
assert result[0][0] == Document(page_content="foo", id=ids[0])
175165
assert result[0][1] == 0
176166

@@ -244,9 +234,7 @@ async def test_amax_marginal_relevance_search_vector_score(
244234
)
245235
assert results[0][0] == Document(page_content="bar", id=ids[1])
246236

247-
async def test_similarity_search_score(
248-
self, vs_custom: AsyncPGVectorStore
249-
) -> None:
237+
async def test_similarity_search_score(self, vs_custom: AsyncPGVectorStore) -> None:
250238
results = await vs_custom.asimilarity_search_with_score("foo")
251239
assert len(results) == 4
252240
assert results[0][0] == Document(page_content="foo", id=ids[0])
@@ -269,26 +257,20 @@ async def test_max_marginal_relevance_search_vector(
269257
self, vs_custom: AsyncPGVectorStore
270258
) -> None:
271259
embedding = embeddings_service.embed_query("bar")
272-
results = await vs_custom.amax_marginal_relevance_search_by_vector(
273-
embedding
274-
)
260+
results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding)
275261
assert results[0] == Document(page_content="bar", id=ids[1])
276262

277263
async def test_max_marginal_relevance_search_vector_score(
278264
self, vs_custom: AsyncPGVectorStore
279265
) -> None:
280266
embedding = embeddings_service.embed_query("bar")
281-
results = (
282-
await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
283-
embedding
284-
)
267+
results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
268+
embedding
285269
)
286270
assert results[0][0] == Document(page_content="bar", id=ids[1])
287271

288-
results = (
289-
await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
290-
embedding, lambda_mult=0.75, fetch_k=10
291-
)
272+
results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
273+
embedding, lambda_mult=0.75, fetch_k=10
292274
)
293275
assert results[0][0] == Document(page_content="bar", id=ids[1])
294276

@@ -298,9 +280,7 @@ async def test_aget_by_ids(self, vs: AsyncPGVectorStore) -> None:
298280

299281
assert results[0] == Document(page_content="foo", id=ids[0])
300282

301-
async def test_aget_by_ids_custom_vs(
302-
self, vs_custom: AsyncPGVectorStore
303-
) -> None:
283+
async def test_aget_by_ids_custom_vs(self, vs_custom: AsyncPGVectorStore) -> None:
304284
test_ids = [ids[0]]
305285
results = await vs_custom.aget_by_ids(ids=test_ids)
306286

@@ -322,6 +302,4 @@ async def test_vectorstore_with_metadata_filters(
322302
docs = await vs_custom_filter.asimilarity_search(
323303
"meow", k=5, filter=test_filter
324304
)
325-
assert [
326-
doc.metadata["code"] for doc in docs
327-
] == expected_ids, test_filter
305+
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

0 commit comments

Comments
 (0)