Skip to content

Commit

Permalink
Add support for google drive input (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp authored Feb 23, 2024
1 parent 8163aed commit 8264094
Show file tree
Hide file tree
Showing 8 changed files with 660 additions and 38 deletions.
12 changes: 7 additions & 5 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from models.ingest import RequestPayload
from service.embedding import EmbeddingService, get_encoder
from service.ingest import handle_urls, handle_google_drive
from utils.summarise import SUMMARY_SUFFIX

router = APIRouter()
Expand All @@ -15,15 +16,16 @@
async def ingest(payload: RequestPayload) -> Dict:
encoder = get_encoder(encoder_config=payload.encoder)
embedding_service = EmbeddingService(
files=payload.files,
index_name=payload.index_name,
vector_credentials=payload.vector_database,
dimensions=payload.encoder.dimensions,
)
chunks = await embedding_service.generate_chunks()
summary_documents = await embedding_service.generate_summary_documents(
documents=chunks
)
if payload.files:
chunks, summary_documents = await handle_urls(embedding_service, payload.files)
elif payload.google_drive:
chunks, summary_documents = await handle_google_drive(
embedding_service, payload.google_drive
)

await asyncio.gather(
embedding_service.generate_and_upsert_embeddings(
Expand Down
8 changes: 8 additions & 0 deletions models/google_drive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field


class GoogleDrive(BaseModel):
    """Google Drive source settings: credentials plus the target item ID."""

    # Service account key, supplied as an already-parsed mapping.
    service_account_key: dict = Field(
        ...,
        description="The service account key for Google Drive API",
    )
    # Per the field description, this may identify either a file or a folder.
    drive_id: str = Field(
        ...,
        description="The ID of a File or Folder",
    )
4 changes: 3 additions & 1 deletion models/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from models.file import File
from models.vector_database import VectorDatabase
from models.google_drive import GoogleDrive


class EncoderEnum(str, Enum):
Expand All @@ -19,7 +20,8 @@ class Encoder(BaseModel):


class RequestPayload(BaseModel):
files: List[File]
files: Optional[List[File]] = None
google_drive: Optional[GoogleDrive] = None
encoder: Encoder
vector_database: VectorDatabase
index_name: str
Expand Down
607 changes: 597 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
packages = [{include = "main.py"}]

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
python = ">=3.9,<3.12"
fastapi = "^0.109.2"
uvicorn = "^0.27.1"
weaviate-client = "^4.1.2"
Expand All @@ -31,6 +31,7 @@ python-dotenv = "^1.0.1"
e2b = "^0.14.4"
gunicorn = "^21.2.0"
unstructured-client = "^0.18.0"
unstructured = {extras = ["google-drive"], version = "^0.12.4"}

[tool.poetry.group.dev.dependencies]
termcolor = "^2.4.0"
Expand Down
30 changes: 9 additions & 21 deletions service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@

from models.document import BaseDocument, BaseDocumentChunk
from models.file import File
from models.google_drive import GoogleDrive
from models.ingest import Encoder, EncoderEnum
from utils.logger import logger
from utils.summarise import completion
from utils.file import get_file_extension_from_url
from vectordbs import get_vector_service


class EmbeddingService:
def __init__(
self,
files: List[File],
index_name: str,
vector_credentials: dict,
dimensions: Optional[int],
files: Optional[List[File]] = None,
google_drive: Optional[GoogleDrive] = None,
):
self.files = files
self.google_drive = google_drive
self.index_name = index_name
self.vector_credentials = vector_credentials
self.dimensions = dimensions
Expand All @@ -42,20 +46,6 @@ def __init__(
server_url=config("UNSTRUCTURED_IO_SERVER_URL"),
)

def _get_datasource_suffix(self, type: str) -> dict:
suffixes = {
"TXT": ".txt",
"PDF": ".pdf",
"MARKDOWN": ".md",
"DOCX": ".docx",
"CSV": ".csv",
"XLSX": ".xlsx",
}
try:
return suffixes[type]
except KeyError:
raise ValueError("Unsupported datasource type")

def _get_strategy(self, type: str) -> dict:
strategies = {
"PDF": "auto",
Expand All @@ -66,7 +56,7 @@ def _get_strategy(self, type: str) -> dict:
return None

async def _download_and_extract_elements(
self, file, strategy: Optional[str] = "hi_res"
self, file: File, strategy: Optional[str] = "hi_res"
) -> List[Any]:
"""
Downloads the file and extracts elements using the partition function.
Expand All @@ -76,7 +66,7 @@ async def _download_and_extract_elements(
f"Downloading and extracting elements from {file.url},"
f"using `{strategy}` strategy"
)
suffix = self._get_datasource_suffix(file.type.value)
suffix = get_file_extension_from_url(url=file.url)
strategy = self._get_strategy(type=file.type.value)
with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
with requests.get(url=file.url) as response:
Expand Down Expand Up @@ -115,7 +105,7 @@ async def generate_document(
doc_metadata = {
"source": file.url,
"source_type": "document",
"document_type": self._get_datasource_suffix(file.type.value),
"document_type": get_file_extension_from_url(url=file.url),
}
return BaseDocument(
id=f"doc_{uuid.uuid4()}",
Expand Down Expand Up @@ -159,9 +149,7 @@ async def generate_chunks(
"document_id": document.id,
"source": file.url,
"source_type": "document",
"document_type": self._get_datasource_suffix(
file.type.value
),
"document_type": get_file_extension_from_url(file.url),
"content": chunk_text,
**sanitized_metadata,
},
Expand Down
23 changes: 23 additions & 0 deletions service/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import List

from models.file import File
from models.google_drive import GoogleDrive
from service.embedding import EmbeddingService


async def handle_urls(
    embedding_service: EmbeddingService,
    files: List[File],
):
    """Chunk the given files and build summary documents from the chunks.

    Args:
        embedding_service: Service used to generate chunks and summaries.
        files: Files to ingest; stored on ``embedding_service.files``.

    Returns:
        A ``(chunks, summary_documents)`` tuple.
    """
    # generate_chunks() is called without arguments, so the files are handed
    # to the service through its ``files`` attribute first.
    embedding_service.files = files
    document_chunks = await embedding_service.generate_chunks()
    summaries = await embedding_service.generate_summary_documents(
        documents=document_chunks
    )
    return document_chunks, summaries


async def handle_google_drive(
    _embedding_service: EmbeddingService, _google_drive: GoogleDrive
):
    """Placeholder for Google Drive ingestion, which is not implemented yet.

    Raising explicitly avoids silently returning ``None``: a caller that
    unpacks the result as ``chunks, summary_documents`` would otherwise fail
    with an opaque ``TypeError``.

    Args:
        _embedding_service: Service that would generate chunks and summaries.
        _google_drive: Google Drive credentials and target item ID.

    Raises:
        NotImplementedError: Always, until Google Drive ingestion exists.
    """
    raise NotImplementedError("Google Drive ingestion is not implemented yet")
11 changes: 11 additions & 0 deletions utils/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from urllib.parse import urlparse
import os


def get_file_extension_from_url(url: str) -> str:
    """Extract the lower-cased file extension from a URL.

    Only the path component is inspected, so query strings and fragments
    never leak into the result.

    Args:
        url: The URL to inspect.

    Returns:
        The extension of the last path segment including the leading dot
        (e.g. ``".pdf"``), or ``""`` when the path has no extension.
    """
    path = urlparse(url).path
    ext = os.path.splitext(path)[1]
    # Normalise case so "report.PDF" and "report.pdf" yield the same suffix.
    return ext.lower()

0 comments on commit 8264094

Please sign in to comment.