1
1
from __future__ import annotations
2
2
3
- import contextlib
4
3
import enum
5
4
import logging
6
5
import uuid
7
6
from typing import (
8
7
Any ,
9
8
Callable ,
10
9
Dict ,
11
- Generator ,
12
10
Iterable ,
13
11
List ,
14
12
Optional ,
21
19
import sqlalchemy
22
20
from sqlalchemy import SQLColumnExpression , cast , delete , func
23
21
from sqlalchemy .dialects .postgresql import JSON , JSONB , JSONPATH , UUID , insert
24
- from sqlalchemy .orm import Session , relationship
22
+ from sqlalchemy .orm import Session , relationship , sessionmaker
25
23
26
24
try :
27
25
from sqlalchemy .orm import declarative_base
@@ -288,15 +286,19 @@ def __init__(
288
286
self .override_relevance_score_fn = relevance_score_fn
289
287
290
288
if isinstance (connection , str ):
291
- self ._bind = sqlalchemy .create_engine (url = connection , ** (engine_args or {}))
289
+ self ._engine = sqlalchemy .create_engine (
290
+ url = connection , ** (engine_args or {})
291
+ )
292
292
elif isinstance (connection , sqlalchemy .engine .Engine ):
293
- self ._bind = connection
293
+ self ._engine = connection
294
294
else :
295
295
raise ValueError (
296
296
"connection should be a connection string or an instance of "
297
297
"sqlalchemy.engine.Engine"
298
298
)
299
299
300
+ self ._session_maker = sessionmaker (bind = self ._engine )
301
+
300
302
self .use_jsonb = use_jsonb
301
303
self .create_extension = create_extension
302
304
@@ -321,16 +323,16 @@ def __post_init__(
321
323
self .create_collection ()
322
324
323
325
def __del__ (self ) -> None :
324
- if isinstance (self ._bind , sqlalchemy .engine .Connection ):
325
- self ._bind .close ()
326
+ if isinstance (self ._engine , sqlalchemy .engine .Connection ):
327
+ self ._engine .close ()
326
328
327
329
@property
328
330
def embeddings (self ) -> Embeddings :
329
331
return self .embedding_function
330
332
331
333
def create_vector_extension (self ) -> None :
332
334
try :
333
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
335
+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
334
336
# The advisor lock fixes issue arising from concurrent
335
337
# creation of the vector extension.
336
338
# https://github.com/langchain-ai/langchain/issues/12933
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
348
350
raise Exception (f"Failed to create vector extension: { e } " ) from e
349
351
350
352
def create_tables_if_not_exists (self ) -> None :
351
- with Session ( self ._bind ) as session , session . begin (): # type: ignore[arg-type]
353
+ with self ._session_maker ( ) as session :
352
354
Base .metadata .create_all (session .get_bind ())
353
355
354
356
def drop_tables (self ) -> None :
355
- with Session ( self ._bind ) as session , session . begin (): # type: ignore[arg-type]
357
+ with self ._session_maker ( ) as session :
356
358
Base .metadata .drop_all (session .get_bind ())
357
359
358
360
def create_collection (self ) -> None :
359
361
if self .pre_delete_collection :
360
362
self .delete_collection ()
361
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
363
+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
362
364
self .CollectionStore .get_or_create (
363
365
session , self .collection_name , cmetadata = self .collection_metadata
364
366
)
365
367
366
368
def delete_collection (self ) -> None :
367
369
self .logger .debug ("Trying to delete collection" )
368
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
370
+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
369
371
collection = self .get_collection (session )
370
372
if not collection :
371
373
self .logger .warning ("Collection not found" )
372
374
return
373
375
session .delete (collection )
374
376
session .commit ()
375
377
376
- @contextlib .contextmanager
377
- def _make_session (self ) -> Generator [Session , None , None ]:
378
- """Create a context manager for the session, bind to _conn string."""
379
- yield Session (self ._bind ) # type: ignore[arg-type]
380
-
381
378
def delete (
382
379
self ,
383
380
ids : Optional [List [str ]] = None ,
@@ -390,7 +387,7 @@ def delete(
390
387
ids: List of ids to delete.
391
388
collection_only: Only delete ids in the collection.
392
389
"""
393
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
390
+ with self ._session_maker ( ) as session :
394
391
if ids is not None :
395
392
self .logger .debug (
396
393
"Trying to delete vectors by ids (represented by the model "
@@ -476,7 +473,7 @@ def add_embeddings(
476
473
if not metadatas :
477
474
metadatas = [{} for _ in texts ]
478
475
479
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
476
+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
480
477
collection = self .get_collection (session )
481
478
if not collection :
482
479
raise ValueError ("Collection not found" )
@@ -901,7 +898,7 @@ def __query_collection(
901
898
filter : Optional [Dict [str , str ]] = None ,
902
899
) -> List [Any ]:
903
900
"""Query the collection."""
904
- with Session ( self ._bind ) as session : # type: ignore[arg-type]
901
+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
905
902
collection = self .get_collection (session )
906
903
if not collection :
907
904
raise ValueError ("Collection not found" )
@@ -1066,6 +1063,7 @@ def from_existing_index(
1066
1063
embeddings = embedding ,
1067
1064
distance_strategy = distance_strategy ,
1068
1065
pre_delete_collection = pre_delete_collection ,
1066
+ ** kwargs ,
1069
1067
)
1070
1068
1071
1069
return store
0 commit comments