|
4 | 4 | import base64
|
5 | 5 | import copy
|
6 | 6 | import json
|
7 |
| -import re |
8 | 7 | import uuid
|
9 | 8 | from typing import Any, Callable, Iterable, Optional, Sequence
|
| 9 | +from urllib.parse import urlparse |
10 | 10 |
|
11 | 11 | import numpy as np
|
12 | 12 | import requests
|
13 |
| -from google.cloud import storage # type: ignore |
14 | 13 | from langchain_core.documents import Document
|
15 | 14 | from langchain_core.embeddings import Embeddings
|
16 | 15 | from langchain_core.vectorstores import VectorStore, utils
|
@@ -371,16 +370,22 @@ async def aadd_documents(
|
371 | 370 |
|
372 | 371 | def _encode_image(self, uri: str) -> str:
|
373 | 372 | """Get base64 string from a image URI."""
|
374 |
| - gcs_uri = re.match("gs://(.*?)/(.*)", uri) |
375 |
| - if gcs_uri: |
376 |
| - bucket_name, object_name = gcs_uri.groups() |
| 373 | + if uri.startswith("gs://"): |
| 374 | + from google.cloud import storage # type: ignore |
| 375 | + |
| 376 | + path_without_prefix = uri[len("gs://") :] |
| 377 | + parts = path_without_prefix.split("/", 1) |
| 378 | + bucket_name = parts[0] |
| 379 | + object_name = "" # Default for bucket root if no object specified |
| 380 | + if len(parts) == 2: |
| 381 | + object_name = parts[1] |
377 | 382 | storage_client = storage.Client()
|
378 | 383 | bucket = storage_client.bucket(bucket_name)
|
379 | 384 | blob = bucket.blob(object_name)
|
380 | 385 | return base64.b64encode(blob.download_as_bytes()).decode("utf-8")
|
381 | 386 |
|
382 |
| - web_uri = re.match(r"^(https?://).*", uri) |
383 |
| - if web_uri: |
| 387 | + parsed_uri = urlparse(uri) |
| 388 | + if parsed_uri.scheme in ["http", "https"]: |
384 | 389 | response = requests.get(uri, stream=True)
|
385 | 390 | response.raise_for_status()
|
386 | 391 | return base64.b64encode(response.content).decode("utf-8")
|
|
0 commit comments