Skip to content

feat: Add the PGVectorStore class #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a21a8e2
feat: Add the PGVectorStore class
dishaprakash Mar 19, 2025
926f417
Linter and format fix
dishaprakash Mar 19, 2025
a34ddbe
update poetry lock
dishaprakash Mar 19, 2025
bdd2bf6
minor variable name change
dishaprakash Mar 19, 2025
239b1c3
Fix import test
dishaprakash Mar 19, 2025
544cade
enabled socket in one test file
dishaprakash Mar 19, 2025
03dcac1
enabled socket in all test files
dishaprakash Mar 19, 2025
7b4fa7f
Debug tests being skipped
dishaprakash Mar 19, 2025
1d42314
Debug tests being skipped
dishaprakash Mar 19, 2025
dc9a5b8
Debug tests being skipped
dishaprakash Mar 19, 2025
4c3f93f
Debug tests being failed
dishaprakash Mar 19, 2025
b3a12b7
revert debug lines
dishaprakash Mar 19, 2025
8b30833
Remove IVFIndex
dishaprakash Mar 19, 2025
1496033
Minor change
dishaprakash Mar 19, 2025
cbd0889
Review changes
dishaprakash Apr 1, 2025
b436df3
Refactor vectorstore packaging in import
dishaprakash Apr 1, 2025
eb6954d
Change test table names
dishaprakash Apr 1, 2025
3e52c56
Linter fix
dishaprakash Apr 1, 2025
a24fe73
Minor fix
dishaprakash Apr 1, 2025
c74858e
Fix test
dishaprakash Apr 1, 2025
e52e609
Fix tests
dishaprakash Apr 1, 2025
1f6a70e
Remove chat message history format
dishaprakash Apr 1, 2025
8029731
Fix test
dishaprakash Apr 1, 2025
cf58c2a
Fix indexing tests
dishaprakash Apr 1, 2025
b9526c6
Make escape sql string function private
dishaprakash Apr 1, 2025
1d6563a
Rename namespaces
dishaprakash Apr 2, 2025
c9ad8f3
Enable support for TypedDict along with Column
dishaprakash Apr 2, 2025
a913b5a
Fix import test
dishaprakash Apr 2, 2025
9e539e0
Linter fix
dishaprakash Apr 2, 2025
1daac17
Linter fix
dishaprakash Apr 2, 2025
5062185
Add validation and quotes for indexes
dishaprakash Apr 3, 2025
fe62c35
Merge branch 'pg-vectorstore' into upstream-langchain
averikitsch Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langchain_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
from langchain_postgres.engine import Column, PGEngine
from langchain_postgres.translator import PGVectorTranslator
from langchain_postgres.vectorstore import PGVectorStore
from langchain_postgres.vectorstore.v2 import PGVectorStore
from langchain_postgres.vectorstores import PGVector

try:
Expand Down
1 change: 0 additions & 1 deletion langchain_postgres/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

This client provides support for both sync and async via psycopg 3.
"""

from __future__ import annotations

import json
Expand Down
117 changes: 49 additions & 68 deletions langchain_postgres/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import asyncio
from dataclasses import dataclass
import re
from threading import Thread
from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union

from sqlalchemy import MetaData, Table, text
from sqlalchemy import text
from sqlalchemy.engine import URL
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(

Args:
key (object): Prevent direct constructor usage.
engine (AsyncEngine): Async engine connection pool.
pool (AsyncEngine): Async engine connection pool.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread]): Thread used to create the engine async.

Expand Down Expand Up @@ -136,14 +136,18 @@ async def close(self) -> None:
"""Dispose of connection pool"""
await self._run_as_async(self._pool.dispose())

def escape_postgres_identifier(self, name: str) -> str:
return name.replace('"', '""')

async def _ainit_vectorstore_table(
self,
table_name: str,
vector_size: int,
*,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_columns: Optional[list[Column]] = None,
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
Expand All @@ -161,8 +165,8 @@ async def _ainit_vectorstore_table(
Default: "page_content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_columns (Optional[list[Column]]): A list of Columns to create for custom
metadata. Default: None. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Expand All @@ -175,6 +179,21 @@ async def _ainit_vectorstore_table(
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
:class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a postgreSQL data type.
"""

schema_name = self.escape_postgres_identifier(schema_name)
table_name = self.escape_postgres_identifier(table_name)
content_column = self.escape_postgres_identifier(content_column)
embedding_column = self.escape_postgres_identifier(embedding_column)
if metadata_columns is None:
metadata_columns = []
else:
for col in metadata_columns:
col.name = self.escape_postgres_identifier(col.name)
if isinstance(id_column, str):
id_column = self.escape_postgres_identifier(id_column)
else:
id_column.name = self.escape_postgres_identifier(id_column.name)

async with self._pool.connect() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.commit()
Expand Down Expand Up @@ -208,10 +227,11 @@ async def ainit_vectorstore_table(
self,
table_name: str,
vector_size: int,
*,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_columns: Optional[list[Column]] = None,
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
Expand All @@ -229,8 +249,8 @@ async def ainit_vectorstore_table(
Default: "page_content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_columns (Optional[list[Column]]): A list of Columns to create for custom
metadata. Default: None. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Expand All @@ -243,25 +263,26 @@ async def ainit_vectorstore_table(
self._ainit_vectorstore_table(
table_name,
vector_size,
schema_name,
content_column,
embedding_column,
metadata_columns,
metadata_json_column,
id_column,
overwrite_existing,
store_metadata,
schema_name=schema_name,
content_column=content_column,
embedding_column=embedding_column,
metadata_columns=metadata_columns,
metadata_json_column=metadata_json_column,
id_column=id_column,
overwrite_existing=overwrite_existing,
store_metadata=store_metadata,
)
)

def init_vectorstore_table(
self,
table_name: str,
vector_size: int,
*,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_columns: Optional[list[Column]] = None,
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
Expand All @@ -279,8 +300,8 @@ def init_vectorstore_table(
Default: "page_content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_columns (Optional[list[Column]]): A list of Columns to create for custom
metadata. Default: None. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Expand All @@ -293,53 +314,13 @@ def init_vectorstore_table(
self._ainit_vectorstore_table(
table_name,
vector_size,
schema_name,
content_column,
embedding_column,
metadata_columns,
metadata_json_column,
id_column,
overwrite_existing,
store_metadata,
schema_name=schema_name,
content_column=content_column,
embedding_column=embedding_column,
metadata_columns=metadata_columns,
metadata_json_column=metadata_json_column,
id_column=id_column,
overwrite_existing=overwrite_existing,
store_metadata=store_metadata,
)
)

async def _aload_table_schema(
self, table_name: str, schema_name: str = "public"
) -> Table:
"""
Load table schema from an existing table in a PgSQL database, potentially from a specific database schema.

Args:
table_name: The name of the table to load the table schema from.
schema_name: The name of the database schema where the table resides.
Default: "public".

Returns:
(sqlalchemy.Table): The loaded table, including its table schema information.
"""
metadata = MetaData()
async with self._pool.connect() as conn:
try:
await conn.run_sync(
metadata.reflect, schema=schema_name, only=[table_name]
)
except InvalidRequestError as e:
raise ValueError(
f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e)
)

table = Table(table_name, metadata, schema=schema_name)
# Extract the schema information
schema = []
for column in table.columns:
schema.append(
{
"name": column.name,
"type": column.type.python_type,
"max_length": getattr(column.type, "length", None),
"nullable": not column.nullable,
}
)

return metadata.tables[f"{schema_name}.{table_name}"]
32 changes: 32 additions & 0 deletions langchain_postgres/indexes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Index class to add vector indexes on the PGVectorStore.

Learn more about vector indexes at https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
"""

import enum
import re
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand Down Expand Up @@ -26,6 +32,18 @@ class DistanceStrategy(StrategyMixin, enum.Enum):

@dataclass
class BaseIndex(ABC):
"""
Abstract base class for defining vector indexes.

Attributes:
name (Optional[str]): A human-readable name for the index. Defaults to None.
index_type (str): A string identifying the type of index. Defaults to "base".
distance_strategy (DistanceStrategy): The strategy used to calculate distances
between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE.
partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None.
extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None.
"""

name: Optional[str] = None
index_type: str = "base"
distance_strategy: DistanceStrategy = field(
Expand All @@ -44,6 +62,20 @@ def index_options(self) -> str:
def get_index_function(self) -> str:
return self.distance_strategy.index_function

def __post_init__(self) -> None:
"""Check if initialization parameters are valid.

Raises:
ValueError: extension_name is a valid postgreSQL identifier
"""
if (
self.extension_name
and re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", self.extension_name) is None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If we are doing validation at the application layer, this should probably be in a standalone function and used in other places as well (e.g., any of the index classes has the same injection issue)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is added as a post init to the BaseIndex, which is extended by all Index classes, so all the indexes run this check after init.
Is there a different way you would like this to be implemented?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is only for the extension_name if I'm understanding correctly. The other index types seem to also generate SQL (e.g., https://github.com/langchain-ai/langchain-postgres/pull/168/files/cf58c2ab9bedf43df0989a7cc707d70e7e5a66a0#diff-5dbed1276479f8f27084b783f123c32e6acc95a6662c9838f3e126f603925809R106)

We can also handle this in a follow up PR if easier?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right! I missed that. I've seperated that function and now it validates for both extension_name and index_type.

I've also wrapped the index_name in double quotes to allow the same flexibility as tables.

):
raise ValueError(
f"Invalid identifier: {self.extension_name}. Identifiers must start with a letter or underscore, and subsequent characters can be letters, digits, or underscores."
)


@dataclass
class ExactNearestNeighbor(BaseIndex):
Expand Down
2 changes: 1 addition & 1 deletion langchain_postgres/utils/pgvector_migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError

from ..engine import PGEngine
from ..vectorstore import PGVectorStore
from ..vectorstore.v2 import PGVectorStore

COLLECTIONS_TABLE = "langchain_pg_collection"
EMBEDDINGS_TABLE = "langchain_pg_embedding"
Expand Down
Loading
Loading