Skip to content

Commit c587e4c

Browse files
authored
Update session management in vectorstore (#25)
Update session management in the vectorstore
1 parent 8d09a2b commit c587e4c

File tree

3 files changed

+38
-35
lines changed

3 files changed

+38
-35
lines changed

.github/workflows/_test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ jobs:
2020
strategy:
2121
matrix:
2222
python-version:
23-
# - "3.8"
24-
# - "3.9"
25-
# - "3.10"
23+
- "3.8"
24+
- "3.9"
25+
- "3.10"
2626
- "3.11"
2727
name: Python ${{ matrix.python-version }}
2828
steps:

langchain_postgres/vectorstores.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from __future__ import annotations
22

3-
import contextlib
43
import enum
54
import logging
65
import uuid
76
from typing import (
87
Any,
98
Callable,
109
Dict,
11-
Generator,
1210
Iterable,
1311
List,
1412
Optional,
@@ -21,7 +19,7 @@
2119
import sqlalchemy
2220
from sqlalchemy import SQLColumnExpression, cast, delete, func
2321
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
24-
from sqlalchemy.orm import Session, relationship
22+
from sqlalchemy.orm import Session, relationship, sessionmaker
2523

2624
try:
2725
from sqlalchemy.orm import declarative_base
@@ -288,15 +286,19 @@ def __init__(
288286
self.override_relevance_score_fn = relevance_score_fn
289287

290288
if isinstance(connection, str):
291-
self._bind = sqlalchemy.create_engine(url=connection, **(engine_args or {}))
289+
self._engine = sqlalchemy.create_engine(
290+
url=connection, **(engine_args or {})
291+
)
292292
elif isinstance(connection, sqlalchemy.engine.Engine):
293-
self._bind = connection
293+
self._engine = connection
294294
else:
295295
raise ValueError(
296296
"connection should be a connection string or an instance of "
297297
"sqlalchemy.engine.Engine"
298298
)
299299

300+
self._session_maker = sessionmaker(bind=self._engine)
301+
300302
self.use_jsonb = use_jsonb
301303
self.create_extension = create_extension
302304

@@ -321,16 +323,16 @@ def __post_init__(
321323
self.create_collection()
322324

323325
def __del__(self) -> None:
324-
if isinstance(self._bind, sqlalchemy.engine.Connection):
325-
self._bind.close()
326+
if isinstance(self._engine, sqlalchemy.engine.Connection):
327+
self._engine.close()
326328

327329
@property
328330
def embeddings(self) -> Embeddings:
329331
return self.embedding_function
330332

331333
def create_vector_extension(self) -> None:
332334
try:
333-
with Session(self._bind) as session: # type: ignore[arg-type]
335+
with self._session_maker() as session: # type: ignore[arg-type]
334336
# The advisory lock fixes an issue arising from concurrent
335337
# creation of the vector extension.
336338
# https://github.com/langchain-ai/langchain/issues/12933
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
348350
raise Exception(f"Failed to create vector extension: {e}") from e
349351

350352
def create_tables_if_not_exists(self) -> None:
351-
with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
353+
with self._session_maker() as session:
352354
Base.metadata.create_all(session.get_bind())
353355

354356
def drop_tables(self) -> None:
355-
with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
357+
with self._session_maker() as session:
356358
Base.metadata.drop_all(session.get_bind())
357359

358360
def create_collection(self) -> None:
359361
if self.pre_delete_collection:
360362
self.delete_collection()
361-
with Session(self._bind) as session: # type: ignore[arg-type]
363+
with self._session_maker() as session: # type: ignore[arg-type]
362364
self.CollectionStore.get_or_create(
363365
session, self.collection_name, cmetadata=self.collection_metadata
364366
)
365367

366368
def delete_collection(self) -> None:
367369
self.logger.debug("Trying to delete collection")
368-
with Session(self._bind) as session: # type: ignore[arg-type]
370+
with self._session_maker() as session: # type: ignore[arg-type]
369371
collection = self.get_collection(session)
370372
if not collection:
371373
self.logger.warning("Collection not found")
372374
return
373375
session.delete(collection)
374376
session.commit()
375377

376-
@contextlib.contextmanager
377-
def _make_session(self) -> Generator[Session, None, None]:
378-
"""Create a context manager for the session, bind to _conn string."""
379-
yield Session(self._bind) # type: ignore[arg-type]
380-
381378
def delete(
382379
self,
383380
ids: Optional[List[str]] = None,
@@ -390,7 +387,7 @@ def delete(
390387
ids: List of ids to delete.
391388
collection_only: Only delete ids in the collection.
392389
"""
393-
with Session(self._bind) as session: # type: ignore[arg-type]
390+
with self._session_maker() as session:
394391
if ids is not None:
395392
self.logger.debug(
396393
"Trying to delete vectors by ids (represented by the model "
@@ -476,7 +473,7 @@ def add_embeddings(
476473
if not metadatas:
477474
metadatas = [{} for _ in texts]
478475

479-
with Session(self._bind) as session: # type: ignore[arg-type]
476+
with self._session_maker() as session: # type: ignore[arg-type]
480477
collection = self.get_collection(session)
481478
if not collection:
482479
raise ValueError("Collection not found")
@@ -901,7 +898,7 @@ def __query_collection(
901898
filter: Optional[Dict[str, str]] = None,
902899
) -> List[Any]:
903900
"""Query the collection."""
904-
with Session(self._bind) as session: # type: ignore[arg-type]
901+
with self._session_maker() as session: # type: ignore[arg-type]
905902
collection = self.get_collection(session)
906903
if not collection:
907904
raise ValueError("Collection not found")
@@ -1066,6 +1063,7 @@ def from_existing_index(
10661063
embeddings=embedding,
10671064
distance_strategy=distance_strategy,
10681065
pre_delete_collection=pre_delete_collection,
1066+
**kwargs,
10691067
)
10701068

10711069
return store

tests/unit_tests/test_vectorstore.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Test PGVector functionality."""
2-
2+
import contextlib
33
from typing import Any, Dict, Generator, List
44

55
import pytest
@@ -18,9 +18,8 @@
1818
TYPE_4_FILTERING_TEST_CASES,
1919
TYPE_5_FILTERING_TEST_CASES,
2020
)
21-
from tests.utils import VECTORSTORE_CONNECTION_STRING
21+
from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
2222

23-
CONNECTION_STRING = VECTORSTORE_CONNECTION_STRING
2423
ADA_TOKEN_COUNT = 1536
2524

2625

@@ -159,7 +158,7 @@ def test_pgvector_collection_with_metadata() -> None:
159158
connection=CONNECTION_STRING,
160159
pre_delete_collection=True,
161160
)
162-
with pgvector._make_session() as session:
161+
with pgvector._session_maker() as session:
163162
collection = pgvector.get_collection(session)
164163
if collection is None:
165164
assert False, "Expected a CollectionStore object but received None"
@@ -182,14 +181,14 @@ def test_pgvector_delete_docs() -> None:
182181
pre_delete_collection=True,
183182
)
184183
vectorstore.delete(["1", "2"])
185-
with vectorstore._make_session() as session:
184+
with vectorstore._session_maker() as session:
186185
records = list(session.query(vectorstore.EmbeddingStore).all())
187186
# ignoring type error since mypy cannot determine whether
188187
# the list is sortable
189188
assert sorted(record.id for record in records) == ["3"] # type: ignore
190189

191190
vectorstore.delete(["2", "3"]) # Should not raise on missing ids
192-
with vectorstore._make_session() as session:
191+
with vectorstore._session_maker() as session:
193192
records = list(session.query(vectorstore.EmbeddingStore).all())
194193
# ignoring type error since mypy cannot determine whether
195194
# the list is sortable
@@ -229,7 +228,7 @@ def test_pgvector_index_documents() -> None:
229228
connection=CONNECTION_STRING,
230229
pre_delete_collection=True,
231230
)
232-
with vectorstore._make_session() as session:
231+
with vectorstore._session_maker() as session:
233232
records = list(session.query(vectorstore.EmbeddingStore).all())
234233
# ignoring type error since mypy cannot determine whether
235234
# the list is sortable
@@ -251,7 +250,7 @@ def test_pgvector_index_documents() -> None:
251250

252251
vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents])
253252

254-
with vectorstore._make_session() as session:
253+
with vectorstore._session_maker() as session:
255254
records = list(session.query(vectorstore.EmbeddingStore).all())
256255
ordered_records = sorted(records, key=lambda x: x.id)
257256
# ignoring type error since mypy cannot determine whether
@@ -408,6 +407,13 @@ def test_pgvector_with_custom_engine_args() -> None:
408407
@pytest.fixture
409408
def pgvector() -> Generator[PGVector, None, None]:
410409
"""Create a PGVector instance."""
410+
with get_vectorstore() as vector_store:
411+
yield vector_store
412+
413+
414+
@contextlib.contextmanager
415+
def get_vectorstore() -> Generator[PGVector, None, None]:
416+
"""Get a pre-populated vectorstore."""
411417
store = PGVector.from_documents(
412418
documents=DOCUMENTS,
413419
collection_name="test_collection",
@@ -419,20 +425,19 @@ def pgvector() -> Generator[PGVector, None, None]:
419425
)
420426
try:
421427
yield store
422-
# Do clean up
423428
finally:
424429
store.drop_tables()
425430

426431

427432
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
428433
def test_pgvector_with_with_metadata_filters_1(
429-
pgvector: PGVector,
430434
test_filter: Dict[str, Any],
431435
expected_ids: List[int],
432436
) -> None:
433437
"""Test end to end construction and search."""
434-
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
435-
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
438+
with get_vectorstore() as pgvector:
439+
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
440+
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
436441

437442

438443
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)

0 commit comments

Comments
 (0)