Skip to content

Commit 43f89a5

Browse files
committed
feat: Adding multi modal support for PGVectorStore
1 parent 4c86319 commit 43f89a5

File tree

7 files changed

+465
-1
lines changed

7 files changed

+465
-1
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# TODO: Remove below import when minimum supported Python version is 3.10
22
from __future__ import annotations
33

4+
import base64
45
import copy
56
import json
7+
import re
68
import uuid
79
from typing import Any, Callable, Iterable, Optional, Sequence
810

911
import numpy as np
12+
import requests
13+
from google.cloud import storage # type: ignore
1014
from langchain_core.documents import Document
1115
from langchain_core.embeddings import Embeddings
1216
from langchain_core.vectorstores import VectorStore, utils
@@ -365,6 +369,92 @@ async def aadd_documents(
365369
ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
366370
return ids
367371

372+
def _encode_image(self, uri: str) -> str:
    """Return the bytes of the image at ``uri`` as a base64 string.

    The URI is interpreted by the first rule that matches:

    * ``gs://<bucket>/<object>`` -- downloaded with the Google Cloud
      Storage client.
    * ``http://`` or ``https://`` -- fetched with ``requests``.
    * anything else -- opened as a path on the local filesystem.

    Args:
        uri (str): Location of the image.

    Returns:
        str: Base64 encoding of the raw image bytes, decoded as UTF-8 text.

    Raises:
        requests.HTTPError: If a web URI answers with an error status.
        requests.Timeout: If a web server does not respond in time.
        OSError: If a local path cannot be opened for reading.
    """
    gcs_uri = re.match("gs://(.*?)/(.*)", uri)
    if gcs_uri:
        bucket_name, object_name = gcs_uri.groups()
        blob = storage.Client().bucket(bucket_name).blob(object_name)
        image_bytes = blob.download_as_bytes()
    elif re.match(r"^(https?://).*", uri):
        # requests waits indefinitely unless told otherwise; bound both the
        # connect and the read phase so a stalled server cannot hang the
        # caller. The whole body is needed, so streaming is not requested.
        response = requests.get(uri, timeout=(10, 60))
        response.raise_for_status()
        image_bytes = response.content
    else:
        with open(uri, "rb") as image_file:
            image_bytes = image_file.read()

    return base64.b64encode(image_bytes).decode("utf-8")
390+
391+
async def aadd_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed images and add them to the table.

    Each image is stored as its base64-encoded content alongside the
    embedding computed from its URI.

    Args:
        uris (list[str]): Image locations. Each may be a ``gs://`` URI, an
            ``http(s)://`` URL, or a local file path.
        metadatas (Optional[list[dict]]): List of metadatas to add to table
            records. When omitted, every record gets ``{"image_uri": uri}``.
        ids: (Optional[list[str]]): List of IDs to add to table records.
        **kwargs (Any): Forwarded unchanged to ``aadd_embeddings``.

    Returns:
        List of record IDs added.
    """
    # Imported locally so this method does not depend on the module's
    # top-level import list.
    import asyncio

    if metadatas is None:
        metadatas = [{"image_uri": uri} for uri in uris]

    # Reading/downloading an image and calling the embedding service are
    # blocking operations; run them in a worker thread so the event loop
    # stays responsive while they complete.
    encoded_images = [
        await asyncio.to_thread(self._encode_image, uri) for uri in uris
    ]
    embeddings = await asyncio.to_thread(self._images_embedding_helper, uris)

    return await self.aadd_embeddings(
        encoded_images, embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
421+
422+
def _images_embedding_helper(self, image_uris: list[str]) -> list[list[float]]:
    """Embed images through whichever image API the embedding service offers.

    ``embed_images()`` is preferred; ``embed_image()`` is used only when the
    service lacks ``embed_images()``. The chosen method is called once with
    the whole list of URIs.

    Args:
        image_uris (list[str]): URIs of the images to embed.

    Returns:
        list[list[float]]: Whatever the embedding service returns for the
        given URIs.

    Raises:
        ValueError: If the embedding service has neither ``embed_images``
            nor ``embed_image``.
        Exception: If the selected method fails; the original error is
            chained as ``__cause__``.
    """
    service = self.embedding_service
    for method_name in ("embed_images", "embed_image"):
        if not hasattr(service, method_name):
            continue
        try:
            return getattr(service, method_name)(image_uris)
        except Exception as e:
            # Chain explicitly so the underlying failure stays visible as
            # the direct cause of the wrapper error.
            raise Exception(
                f"Make sure your selected embedding model supports list of image URIs as input. {str(e)}"
            ) from e

    raise ValueError("Please use an embedding model that supports image embedding.")
443+
444+
async def asimilarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by similarity search on an image.

    The image is embedded through the embedding service's image API and the
    resulting vector is passed to ``asimilarity_search_by_vector``.

    Args:
        image_uri (str): URI of the query image.
        k (Optional[int]): Forwarded to ``asimilarity_search_by_vector``.
        filter (Optional[dict]): Forwarded to ``asimilarity_search_by_vector``.
        **kwargs (Any): Forwarded to ``asimilarity_search_by_vector``.

    Returns:
        list[Document]: The result of the vector similarity search.
    """
    # Imported locally so this method does not depend on the module's
    # top-level import list.
    import asyncio

    # The embedding call is synchronous; run it in a worker thread so it
    # does not block the event loop.
    embeddings = await asyncio.to_thread(self._images_embedding_helper, [image_uri])
    embedding = embeddings[0]

    return await self.asimilarity_search_by_vector(
        embedding=embedding, k=k, filter=filter, **kwargs
    )
457+
368458
async def adelete(
369459
self,
370460
ids: Optional[list] = None,
@@ -1268,3 +1358,25 @@ def max_marginal_relevance_search_with_score_by_vector(
12681358
raise NotImplementedError(
12691359
"Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
12701360
)
1361+
1362+
def add_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Synchronous image insertion is unavailable on the async store.

    Raises:
        NotImplementedError: Always.
    """
    # The message names this package's classes, matching the other sync
    # stubs of this store.
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )
1372+
1373+
def similarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Synchronous image search is unavailable on the async store.

    Raises:
        NotImplementedError: Always.
    """
    # The message names this package's classes, matching the other sync
    # stubs of this store.
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )

langchain_postgres/v2/vectorstores.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,3 +840,55 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
840840

841841
def get_table_name(self) -> str:
    """Return the ``table_name`` attribute of the underlying ``__vs`` object."""
    table_name = self.__vs.table_name
    return table_name
843+
844+
async def aadd_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed images and add to the table.

    Builds the underlying store's ``aadd_images`` coroutine and awaits it
    through the engine's ``_run_as_async``.
    """
    run_async = self._engine._run_as_async
    pending = self._PGVectorStore__vs.aadd_images(  # type: ignore
        uris, metadatas, ids, **kwargs
    )
    added_ids = await run_async(pending)
    return added_ids
855+
856+
def add_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed images and add to the table.

    Builds the underlying store's ``aadd_images`` coroutine and runs it to
    completion through the engine's ``_run_as_sync``.
    """
    run_sync = self._engine._run_as_sync
    pending = self._PGVectorStore__vs.aadd_images(  # type: ignore
        uris, metadatas, ids, **kwargs
    )
    added_ids = run_sync(pending)
    return added_ids
867+
868+
def similarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by similarity search on image.

    Builds the underlying store's ``asimilarity_search_image`` coroutine and
    runs it to completion through the engine's ``_run_as_sync``.
    """
    run_sync = self._engine._run_as_sync
    pending = self._PGVectorStore__vs.asimilarity_search_image(  # type: ignore
        image_uri, k, filter, **kwargs
    )
    documents = run_sync(pending)
    return documents
881+
882+
async def asimilarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by similarity search on image_uri.

    Builds the underlying store's ``asimilarity_search_image`` coroutine and
    awaits it through the engine's ``_run_as_async``.
    """
    run_async = self._engine._run_as_async
    pending = self._PGVectorStore__vs.asimilarity_search_image(  # type: ignore
        image_uri, k, filter, **kwargs
    )
    documents = await run_async(pending)
    return documents

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ sqlalchemy = "^2"
1919
pgvector = ">=0.2.5,<0.4"
2020
numpy = "^1.21"
2121
asyncpg = "^0.30.0"
22+
google-cloud-storage = ">=2.18.2, <4.0.0"
2223

2324
[tool.poetry.group.docs.dependencies]
2425

@@ -37,6 +38,7 @@ pytest-socket = "^0.7.0"
3738
pytest-cov = "^5.0.0"
3839
pytest-timeout = "^2.3.1"
3940
langchain-tests = "0.3.7"
41+
pillow = "11.1.0"
4042

4143
[tool.poetry.group.codespell]
4244
optional = true

tests/unit_tests/v2/test_async_pg_vectorstore.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import os
12
import uuid
23
from typing import AsyncIterator, Sequence
34

45
import pytest
56
import pytest_asyncio
67
from langchain_core.documents import Document
78
from langchain_core.embeddings import DeterministicFakeEmbedding
9+
from PIL import Image
810
from sqlalchemy import text
911
from sqlalchemy.engine.row import RowMapping
1012

@@ -15,6 +17,7 @@
1517
DEFAULT_TABLE = "default" + str(uuid.uuid4())
1618
DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4())
1719
CUSTOM_TABLE = "custom" + str(uuid.uuid4())
20+
IMAGE_TABLE = "image_table" + str(uuid.uuid4())
1821
VECTOR_SIZE = 768
1922

2023
embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
@@ -28,6 +31,14 @@
2831
embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
2932

3033

34+
class FakeImageEmbedding(DeterministicFakeEmbedding):
    """Fake embedding service that additionally exposes ``embed_image`` for tests."""

    def embed_image(self, image_paths: list[str]) -> list[list[float]]:
        """Return one ``embed_query`` vector per path string, in input order."""
        return list(map(self.embed_query, image_paths))
37+
38+
39+
image_embedding_service = FakeImageEmbedding(size=VECTOR_SIZE)
40+
41+
3142
async def aexecute(engine: PGEngine, query: str) -> None:
3243
async with engine._pool.connect() as conn:
3344
await conn.execute(text(query))
@@ -52,6 +63,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
5263
yield engine
5364
await engine.adrop_table(DEFAULT_TABLE)
5465
await engine.adrop_table(CUSTOM_TABLE)
66+
await engine.adrop_table(IMAGE_TABLE)
5567
await engine.close()
5668

5769
@pytest_asyncio.fixture(scope="class")
@@ -87,6 +99,45 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]
8799
)
88100
yield vs
89101

102+
@pytest_asyncio.fixture(scope="class")
async def image_vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
    """Create the image table and yield an ``AsyncPGVectorStore`` bound to it.

    The table gets two TEXT metadata columns, ``image_id`` and ``source``,
    and the store is built with the fake image embedding service.
    """
    metadata_column_names = ["image_id", "source"]
    await engine._ainit_vectorstore_table(
        IMAGE_TABLE,
        VECTOR_SIZE,
        metadata_columns=[Column(name, "TEXT") for name in metadata_column_names],
    )
    image_store = await AsyncPGVectorStore.create(
        engine,
        embedding_service=image_embedding_service,
        table_name=IMAGE_TABLE,
        metadata_columns=metadata_column_names,
        metadata_json_column="mymeta",
    )
    yield image_store
120+
121+
@pytest_asyncio.fixture(scope="class")
async def image_uris(self) -> AsyncIterator[list[str]]:
    """Yield three freshly written local JPEG paths plus one ``gs://`` URI.

    The local files are 100x100 solid-colour images written to the current
    working directory. On teardown every yielded URI is passed to
    ``os.remove``; entries that are not existing files are ignored.
    """
    suffix_by_color = {
        "red": "test_image_red.jpg",
        "green": "test_image_green.jpg",
        "blue": "test_image_blue.jpg",
    }
    # Unique prefixes keep repeated or parallel runs from colliding on disk.
    path_by_color = {
        color: str(uuid.uuid4()).replace("-", "_") + suffix
        for color, suffix in suffix_by_color.items()
    }
    gcs_uri = "gs://github-repo/img/vision/google-cloud-next.jpeg"

    for color, path in path_by_color.items():
        Image.new("RGB", (100, 100), color=color).save(path)

    image_uris = [*path_by_color.values(), gcs_uri]
    yield image_uris

    for uri in image_uris:
        try:
            os.remove(uri)
        except FileNotFoundError:
            pass
140+
90141
async def test_init_with_constructor(self, engine: PGEngine) -> None:
91142
with pytest.raises(Exception):
92143
AsyncPGVectorStore(
@@ -165,6 +216,20 @@ async def test_adelete(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None:
165216
assert result == False
166217
await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
167218

219+
async def test_aadd_images(
    self, engine: PGEngine, image_vs: AsyncPGVectorStore, image_uris: list[str]
) -> None:
    """Adding images writes one row per URI carrying the supplied metadata."""
    ids = [str(uuid.uuid4()) for _ in image_uris]
    metadatas = [
        {"image_id": str(index), "source": "postgres"}
        for index, _ in enumerate(image_uris)
    ]

    await image_vs.aadd_images(image_uris, metadatas, ids)

    rows = await afetch(engine, f'SELECT * FROM "{IMAGE_TABLE}"')
    assert len(rows) == len(image_uris)
    first_row = rows[0]
    assert first_row["image_id"] == "0"
    assert first_row["source"] == "postgres"

    # Leave the table empty for whichever test uses it next.
    await aexecute(engine, f'TRUNCATE TABLE "{IMAGE_TABLE}"')
232+
168233
##### Custom Vector Store #####
169234
async def test_aadd_embeddings(
170235
self, engine: PGEngine, vs_custom: AsyncPGVectorStore

0 commit comments

Comments
 (0)