Skip to content

Commit 7532fdf

Browse files
committed
review changes
1 parent 39dc8f1 commit 7532fdf

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,12 @@
44
import base64
55
import copy
66
import json
7-
import re
87
import uuid
98
from typing import Any, Callable, Iterable, Optional, Sequence
9+
from urllib.parse import urlparse
1010

1111
import numpy as np
1212
import requests
13-
from google.cloud import storage # type: ignore
1413
from langchain_core.documents import Document
1514
from langchain_core.embeddings import Embeddings
1615
from langchain_core.vectorstores import VectorStore, utils
@@ -371,16 +370,22 @@ async def aadd_documents(
371370

372371
def _encode_image(self, uri: str) -> str:
373372
"""Get base64 string from a image URI."""
374-
gcs_uri = re.match("gs://(.*?)/(.*)", uri)
375-
if gcs_uri:
376-
bucket_name, object_name = gcs_uri.groups()
373+
if uri.startswith("gs://"):
374+
from google.cloud import storage # type: ignore
375+
376+
path_without_prefix = uri[len("gs://") :]
377+
parts = path_without_prefix.split("/", 1)
378+
bucket_name = parts[0]
379+
object_name = "" # Default for bucket root if no object specified
380+
if len(parts) == 2:
381+
object_name = parts[1]
377382
storage_client = storage.Client()
378383
bucket = storage_client.bucket(bucket_name)
379384
blob = bucket.blob(object_name)
380385
return base64.b64encode(blob.download_as_bytes()).decode("utf-8")
381386

382-
web_uri = re.match(r"^(https?://).*", uri)
383-
if web_uri:
387+
parsed_uri = urlparse(uri)
388+
if parsed_uri.scheme in ["http", "https"]:
384389
response = requests.get(uri, stream=True)
385390
response.raise_for_status()
386391
return base64.b64encode(response.content).decode("utf-8")

0 commit comments

Comments
 (0)