Skip to content

Commit 8264094

Browse files
authored
Add support for google drive input (#61)
1 parent 8163aed commit 8264094

File tree

8 files changed

+660
-38
lines changed

8 files changed

+660
-38
lines changed

api/ingest.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
from models.ingest import RequestPayload
88
from service.embedding import EmbeddingService, get_encoder
9+
from service.ingest import handle_urls, handle_google_drive
910
from utils.summarise import SUMMARY_SUFFIX
1011

1112
router = APIRouter()
@@ -15,15 +16,16 @@
1516
async def ingest(payload: RequestPayload) -> Dict:
1617
encoder = get_encoder(encoder_config=payload.encoder)
1718
embedding_service = EmbeddingService(
18-
files=payload.files,
1919
index_name=payload.index_name,
2020
vector_credentials=payload.vector_database,
2121
dimensions=payload.encoder.dimensions,
2222
)
23-
chunks = await embedding_service.generate_chunks()
24-
summary_documents = await embedding_service.generate_summary_documents(
25-
documents=chunks
26-
)
23+
if payload.files:
24+
chunks, summary_documents = await handle_urls(embedding_service, payload.files)
25+
elif payload.google_drive:
26+
chunks, summary_documents = await handle_google_drive(
27+
embedding_service, payload.google_drive
28+
)
2729

2830
await asyncio.gather(
2931
embedding_service.generate_and_upsert_embeddings(

models/google_drive.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
from pydantic import BaseModel, Field
2+
3+
4+
# Request model describing a Google Drive source to ingest.
# NOTE: kept as `#` comments rather than a class docstring, since pydantic
# exposes a model docstring as the schema description.
class GoogleDrive(BaseModel):
    # Parsed service-account key (the JSON credentials as a dict).
    service_account_key: dict = Field(
        default=...,
        description="The service account key for Google Drive API",
    )
    # Identifier of the Drive object to read; may be a file or a folder.
    drive_id: str = Field(
        default=...,
        description="The ID of a File or Folder",
    )

models/ingest.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
from models.file import File
77
from models.vector_database import VectorDatabase
8+
from models.google_drive import GoogleDrive
89

910

1011
class EncoderEnum(str, Enum):
@@ -19,7 +20,8 @@ class Encoder(BaseModel):
1920

2021

2122
class RequestPayload(BaseModel):
22-
files: List[File]
23+
files: Optional[List[File]] = None
24+
google_drive: Optional[GoogleDrive] = None
2325
encoder: Encoder
2426
vector_database: VectorDatabase
2527
index_name: str

poetry.lock

Lines changed: 597 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77
packages = [{include = "main.py"}]
88

99
[tool.poetry.dependencies]
10-
python = ">=3.9,<3.13"
10+
python = ">=3.9,<3.12"
1111
fastapi = "^0.109.2"
1212
uvicorn = "^0.27.1"
1313
weaviate-client = "^4.1.2"
@@ -31,6 +31,7 @@ python-dotenv = "^1.0.1"
3131
e2b = "^0.14.4"
3232
gunicorn = "^21.2.0"
3333
unstructured-client = "^0.18.0"
34+
unstructured = {extras = ["google-drive"], version = "^0.12.4"}
3435

3536
[tool.poetry.group.dev.dependencies]
3637
termcolor = "^2.4.0"

service/embedding.py

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -19,21 +19,25 @@
1919

2020
from models.document import BaseDocument, BaseDocumentChunk
2121
from models.file import File
22+
from models.google_drive import GoogleDrive
2223
from models.ingest import Encoder, EncoderEnum
2324
from utils.logger import logger
2425
from utils.summarise import completion
26+
from utils.file import get_file_extension_from_url
2527
from vectordbs import get_vector_service
2628

2729

2830
class EmbeddingService:
2931
def __init__(
3032
self,
31-
files: List[File],
3233
index_name: str,
3334
vector_credentials: dict,
3435
dimensions: Optional[int],
36+
files: Optional[List[File]] = None,
37+
google_drive: Optional[GoogleDrive] = None,
3538
):
3639
self.files = files
40+
self.google_drive = google_drive
3741
self.index_name = index_name
3842
self.vector_credentials = vector_credentials
3943
self.dimensions = dimensions
@@ -42,20 +46,6 @@ def __init__(
4246
server_url=config("UNSTRUCTURED_IO_SERVER_URL"),
4347
)
4448

45-
def _get_datasource_suffix(self, type: str) -> dict:
46-
suffixes = {
47-
"TXT": ".txt",
48-
"PDF": ".pdf",
49-
"MARKDOWN": ".md",
50-
"DOCX": ".docx",
51-
"CSV": ".csv",
52-
"XLSX": ".xlsx",
53-
}
54-
try:
55-
return suffixes[type]
56-
except KeyError:
57-
raise ValueError("Unsupported datasource type")
58-
5949
def _get_strategy(self, type: str) -> dict:
6050
strategies = {
6151
"PDF": "auto",
@@ -66,7 +56,7 @@ def _get_strategy(self, type: str) -> dict:
6656
return None
6757

6858
async def _download_and_extract_elements(
69-
self, file, strategy: Optional[str] = "hi_res"
59+
self, file: File, strategy: Optional[str] = "hi_res"
7060
) -> List[Any]:
7161
"""
7262
Downloads the file and extracts elements using the partition function.
@@ -76,7 +66,7 @@ async def _download_and_extract_elements(
7666
f"Downloading and extracting elements from {file.url},"
7767
f"using `{strategy}` strategy"
7868
)
79-
suffix = self._get_datasource_suffix(file.type.value)
69+
suffix = get_file_extension_from_url(url=file.url)
8070
strategy = self._get_strategy(type=file.type.value)
8171
with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
8272
with requests.get(url=file.url) as response:
@@ -115,7 +105,7 @@ async def generate_document(
115105
doc_metadata = {
116106
"source": file.url,
117107
"source_type": "document",
118-
"document_type": self._get_datasource_suffix(file.type.value),
108+
"document_type": get_file_extension_from_url(url=file.url),
119109
}
120110
return BaseDocument(
121111
id=f"doc_{uuid.uuid4()}",
@@ -159,9 +149,7 @@ async def generate_chunks(
159149
"document_id": document.id,
160150
"source": file.url,
161151
"source_type": "document",
162-
"document_type": self._get_datasource_suffix(
163-
file.type.value
164-
),
152+
"document_type": get_file_extension_from_url(file.url),
165153
"content": chunk_text,
166154
**sanitized_metadata,
167155
},

service/ingest.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
from typing import List
2+
3+
from models.file import File
4+
from models.google_drive import GoogleDrive
5+
from service.embedding import EmbeddingService
6+
7+
8+
async def handle_urls(
    embedding_service: EmbeddingService,
    files: List[File],
):
    """
    Chunk and summarise a list of URL-backed files.

    The files are attached to ``embedding_service`` (its ``files`` attribute
    is overwritten), after which the service generates the chunks and then
    the summary documents built from those chunks.

    Returns:
        A ``(chunks, summary_documents)`` tuple.
    """
    # The service reads its inputs from `self.files`, so assign them first.
    embedding_service.files = files

    document_chunks = await embedding_service.generate_chunks()
    summaries = await embedding_service.generate_summary_documents(
        documents=document_chunks
    )
    return document_chunks, summaries
18+
19+
20+
async def handle_google_drive(
    _embedding_service: EmbeddingService, _google_drive: GoogleDrive
):
    """
    Placeholder for ingesting documents from a Google Drive file or folder.

    Callers unpack the result as ``chunks, summary_documents``; a stub that
    silently returns ``None`` makes that unpacking fail with an opaque
    ``TypeError``. Until the Google Drive pipeline exists, fail loudly with a
    message that names the missing feature.

    Args:
        _embedding_service: Service that will produce chunks and summaries.
        _google_drive: Credentials and ID of the Drive file or folder.

    Raises:
        NotImplementedError: Always, because ingestion is not implemented yet.
    """
    raise NotImplementedError("Google Drive ingestion is not implemented yet")

utils/file.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
from urllib.parse import urlparse
2+
import os
3+
4+
5+
def get_file_extension_from_url(url: str) -> str:
    """
    Return the file extension found in the path component of ``url``.

    Only the URL path is inspected, so query strings and fragments are
    ignored. The extension includes its leading dot (e.g. ``".pdf"``); an
    empty string is returned when the path has no extension.
    """
    _, extension = os.path.splitext(urlparse(url).path)
    return extension

0 commit comments

Comments
 (0)