@@ -111,9 +111,7 @@ def __init__(
         self.schema_name = schema_name
         self.content_column = content_column
         self.embedding_column = embedding_column
-        self.metadata_columns = (
-            metadata_columns if metadata_columns is not None else []
-        )
+        self.metadata_columns = metadata_columns if metadata_columns is not None else []
         self.id_column = id_column
         self.metadata_json_column = metadata_json_column
         self.distance_strategy = distance_strategy
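
A minimal standalone sketch of the default-argument pattern this hunk compacts: accepting None and falling back to a fresh list keeps instances from sharing one mutable default. The class name below is illustrative, not from this module.

class ColumnConfig:
    def __init__(self, metadata_columns=None):
        # fall back to a new list per instance instead of a shared mutable default
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

a, b = ColumnConfig(), ColumnConfig()
a.metadata_columns.append("source")
assert b.metadata_columns == []  # b is unaffected by mutations on a
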
@@ -189,27 +187,21 @@ async def create(
         if id_column not in columns:
             raise ValueError(f"Id column, {id_column}, does not exist.")
         if content_column not in columns:
-            raise ValueError(
-                f"Content column, {content_column}, does not exist."
-            )
+            raise ValueError(f"Content column, {content_column}, does not exist.")
         content_type = columns[content_column]
         if content_type != "text" and "char" not in content_type:
             raise ValueError(
                 f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
             )
         if embedding_column not in columns:
-            raise ValueError(
-                f"Embedding column, {embedding_column}, does not exist."
-            )
+            raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
         if columns[embedding_column] != "USER-DEFINED":
             raise ValueError(
                 f"Embedding column, {embedding_column}, is not type Vector."
             )

         metadata_json_column = (
-            None
-            if metadata_json_column not in columns
-            else metadata_json_column
+            None if metadata_json_column not in columns else metadata_json_column
         )

         # If using metadata_columns check to make sure column exists
@@ -272,14 +264,10 @@ async def aadd_embeddings(
             metadatas = [{} for _ in texts]

         # Check for inline embedding capability
-        inline_embed_func = getattr(
-            self.embedding_service, "embed_query_inline", None
-        )
+        inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         can_inline_embed = callable(inline_embed_func)
         # Insert embeddings
-        for id, content, embedding, metadata in zip(
-            ids, texts, embeddings, metadatas
-        ):
+        for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas):
             metadata_col_names = (
                 ", " + ", ".join(f'"{col}"' for col in self.metadata_columns)
                 if len(self.metadata_columns) > 0
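
For reference, a hedged standalone sketch of the capability probe this hunk reflows: getattr with a None default plus callable() detects an optional embed_query_inline method without try/except. The FakeEmbeddings class is an assumption used only for illustration.

class FakeEmbeddings:
    def embed_documents(self, texts):
        # fixed-size dummy vectors; a real service would call a model
        return [[0.0, 0.0, 0.0] for _ in texts]

service = FakeEmbeddings()
inline_embed_func = getattr(service, "embed_query_inline", None)
can_inline_embed = callable(inline_embed_func)
print(can_inline_embed)  # False: this service has no inline-embedding support
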
@@ -348,15 +336,11 @@ async def aadd_texts(
             :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
         """
         # Check for inline embedding query
-        inline_embed_func = getattr(
-            self.embedding_service, "embed_query_inline", None
-        )
+        inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         if callable(inline_embed_func):
             embeddings: list[list[float]] = [[] for _ in list(texts)]
         else:
-            embeddings = await self.embedding_service.aembed_documents(
-                list(texts)
-            )
+            embeddings = await self.embedding_service.aembed_documents(list(texts))

         ids = await self.aadd_embeddings(
             texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
@@ -378,9 +362,7 @@ async def aadd_documents(
         metadatas = [doc.metadata for doc in documents]
         if not ids:
             ids = [doc.id for doc in documents]
-        ids = await self.aadd_texts(
-            texts, metadatas=metadatas, ids=ids, **kwargs
-        )
+        ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
         return ids

     async def adelete(
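
A possible call into the aadd_documents path shown above, assuming an already-initialized store instance; the variable name `store` and the sample documents are placeholders, not part of this change.

from langchain_core.documents import Document

async def load_notes(store):
    docs = [
        Document(id="doc-1", page_content="pgvector stores embeddings", metadata={"topic": "db"}),
        Document(id="doc-2", page_content="similarity search over columns", metadata={"topic": "search"}),
    ]
    # with no explicit ids argument, each Document's own id is used
    return await store.aadd_documents(docs)
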
@@ -576,9 +558,7 @@ async def __query_collection(
         if filter and isinstance(filter, dict):
             safe_filter, filter_dict = self._create_filter_clause(filter)
         filter = f"WHERE {safe_filter}" if safe_filter else ""
-        inline_embed_func = getattr(
-            self.embedding_service, "embed_query_inline", None
-        )
+        inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         if not embedding and callable(inline_embed_func) and "query" in kwargs:
             query_embedding = self.embedding_service.embed_query_inline(kwargs["query"])  # type: ignore
         else:
@@ -613,9 +593,7 @@ async def asimilarity_search(
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected by similarity search on query."""
-        inline_embed_func = getattr(
-            self.embedding_service, "embed_query_inline", None
-        )
+        inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         embedding = (
             []
             if callable(inline_embed_func)
@@ -646,9 +624,7 @@ async def asimilarity_search_with_score(
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected by similarity search on query."""
-        inline_embed_func = getattr(
-            self.embedding_service, "embed_query_inline", None
-        )
+        inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         embedding = (
             []
             if callable(inline_embed_func)
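
An illustrative async call to asimilarity_search_with_score as touched above; `store` stands in for a configured vector store and the query text is arbitrary.

async def show_matches(store, query: str, k: int = 4):
    results = await store.asimilarity_search_with_score(query, k=k)
    for doc, score in results:
        # for distance-based strategies a smaller score means a closer match
        print(f"{score:.4f}  {doc.page_content[:60]}")
    return results
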
@@ -770,9 +746,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         k = k if k else self.k
         fetch_k = fetch_k if fetch_k else self.fetch_k
         lambda_mult = lambda_mult if lambda_mult else self.lambda_mult
-        embedding_list = [
-            json.loads(row[self.embedding_column]) for row in results
-        ]
+        embedding_list = [json.loads(row[self.embedding_column]) for row in results]
         mmr_selected = utils.maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
             embedding_list,
@@ -800,9 +774,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
                 )
             )

-        return [
-            r for i, r in enumerate(documents_with_scores) if i in mmr_selected
-        ]
+        return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected]

     async def aapply_vector_index(
         self,
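
A self-contained sketch of the MMR step in the two hunks above: embeddings come back from the vector column as JSON text, get decoded, and are re-ranked for diversity. It assumes maximal_marginal_relevance is importable from langchain_core.vectorstores.utils; the row data and column name are made up.

import json
import numpy as np
from langchain_core.vectorstores.utils import maximal_marginal_relevance

rows = [
    {"embedding": "[0.1, 0.9]"},
    {"embedding": "[0.9, 0.1]"},
    {"embedding": "[0.2, 0.8]"},
]
embedding_list = [json.loads(row["embedding"]) for row in rows]
query_vec = np.array([0.15, 0.85], dtype=np.float32)
# indices of rows kept after trading off relevance against diversity
mmr_selected = maximal_marginal_relevance(query_vec, embedding_list, lambda_mult=0.5, k=2)
print(mmr_selected)
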
@@ -820,16 +792,12 @@ async def aapply_vector_index(
         if index.extension_name:
             async with self.engine.connect() as conn:
                 await conn.execute(
-                    text(
-                        f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}"
-                    )
+                    text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}")
                 )
                 await conn.commit()
         function = index.get_index_function()

-        filter = (
-            f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
-        )
+        filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
         params = "WITH " + index.index_options()
         if name is None:
             if index.name == None:
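
Roughly, the pieces assembled above (extension, operator-class function, WITH options, optional partial-index filter) end up in DDL like the sketch below; every identifier and option here is illustrative rather than taken from the module.

extension_name = "vector"
function = "vector_cosine_ops"   # index operator class
params = "WITH (lists = 100)"    # e.g. index_options() output for an ivfflat index
partial_indexes = "topic = 'db'"
filter = f"WHERE ({partial_indexes})" if partial_indexes else ""

print(f"CREATE EXTENSION IF NOT EXISTS {extension_name};")
print(
    f"CREATE INDEX my_table_embedding_idx ON my_schema.my_table "
    f"USING ivfflat (embedding {function}) {params} {filter};"
)
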
@@ -993,9 +961,7 @@ def _handle_field_filter(
             # filter_value = f"'{filter_value}'"
             native = COMPARISONS_TO_NATIVE[operator]
             id = str(uuid.uuid4()).split("-")[0]
-            return f"{field} {native} :{field}_{id}", {
-                f"{field}_{id}": filter_value
-            }
+            return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
         elif operator == "$between":
             # Use AND with two comparisons
             low, high = filter_value
@@ -1019,17 +985,11 @@ def _handle_field_filter(
             )

         if operator in {"$in"}:
-            return f"{field} = ANY(:{field}_in)", {
-                f"{field}_in": filter_value
-            }
+            return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
         elif operator in {"$nin"}:
-            return f"{field} <> ALL (:{field}_nin)", {
-                f"{field}_nin": filter_value
-            }
+            return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value}
         elif operator in {"$like"}:
-            return f"({field} LIKE :{field}_like)", {
-                f"{field}_like": filter_value
-            }
+            return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value}
         elif operator in {"$ilike"}:
             return f"({field} ILIKE :{field}_ilike)", {
                 f"{field}_ilike": filter_value
@@ -1108,9 +1068,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
                 params = {}
                 for clause in not_conditions:
                     params.update(clause[1])
-                not_stmts = [
-                    f"NOT {condition}" for condition in all_clauses
-                ]
+                not_stmts = [f"NOT {condition}" for condition in all_clauses]
                 return f"({' AND '.join(not_stmts)})", params
             elif isinstance(value, dict):
                 not_, params = self._create_filter_clause(value)
@@ -1134,8 +1092,7 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
             )
         # These should all be fields and combined using an $and operator
         and_ = [
-            self._handle_field_filter(field=k, value=v)
-            for k, v in filters.items()
+            self._handle_field_filter(field=k, value=v) for k, v in filters.items()
         ]
         if len(and_) > 1:
             all_clauses = [clause[0] for clause in and_]
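
An illustrative continuation of the $and combination above: per-field clause/param pairs are merged into one parenthesized conjunction with a single parameter dict. The clause texts and parameter names below are hypothetical.

and_ = [
    ("topic = :topic_ab12", {"topic_ab12": "db"}),
    ("(source LIKE :source_like)", {"source_like": "notes%"}),
]
all_clauses = [clause[0] for clause in and_]
params: dict = {}
for _, clause_params in and_:
    params.update(clause_params)
print(f"({' AND '.join(all_clauses)})", params)
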