Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for google drive input #61

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from models.ingest import RequestPayload
from service.embedding import EmbeddingService, get_encoder
from service.ingest import handle_urls, handle_google_drive
from utils.summarise import SUMMARY_SUFFIX

router = APIRouter()
Expand All @@ -15,15 +16,16 @@
async def ingest(payload: RequestPayload) -> Dict:
encoder = get_encoder(encoder_config=payload.encoder)
embedding_service = EmbeddingService(
files=payload.files,
index_name=payload.index_name,
vector_credentials=payload.vector_database,
dimensions=payload.encoder.dimensions,
)
chunks = await embedding_service.generate_chunks()
summary_documents = await embedding_service.generate_summary_documents(
documents=chunks
)
if payload.files:
chunks, summary_documents = await handle_urls(embedding_service, payload.files)
elif payload.google_drive:
chunks, summary_documents = await handle_google_drive(
embedding_service, payload.google_drive
)

await asyncio.gather(
embedding_service.generate_and_upsert_embeddings(
Expand Down
8 changes: 8 additions & 0 deletions models/google_drive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field


class GoogleDrive(BaseModel):
service_account_key: dict = Field(
..., description="The service account key for Google Drive API"
)
drive_id: str = Field(..., description="The ID of a File or Folder")
4 changes: 3 additions & 1 deletion models/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from models.file import File
from models.vector_database import VectorDatabase
from models.google_drive import GoogleDrive


class EncoderEnum(str, Enum):
Expand All @@ -19,7 +20,8 @@ class Encoder(BaseModel):


class RequestPayload(BaseModel):
files: List[File]
files: Optional[List[File]] = None
google_drive: Optional[GoogleDrive] = None
encoder: Encoder
vector_database: VectorDatabase
index_name: str
Expand Down
607 changes: 597 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
packages = [{include = "main.py"}]

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
python = ">=3.9,<3.12"
fastapi = "^0.109.2"
uvicorn = "^0.27.1"
weaviate-client = "^4.1.2"
Expand All @@ -31,6 +31,7 @@ python-dotenv = "^1.0.1"
e2b = "^0.14.4"
gunicorn = "^21.2.0"
unstructured-client = "^0.18.0"
unstructured = {extras = ["google-drive"], version = "^0.12.4"}

[tool.poetry.group.dev.dependencies]
termcolor = "^2.4.0"
Expand Down
30 changes: 9 additions & 21 deletions service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@

from models.document import BaseDocument, BaseDocumentChunk
from models.file import File
from models.google_drive import GoogleDrive
from models.ingest import Encoder, EncoderEnum
from utils.logger import logger
from utils.summarise import completion
from utils.file import get_file_extension_from_url
from vectordbs import get_vector_service


class EmbeddingService:
def __init__(
self,
files: List[File],
index_name: str,
vector_credentials: dict,
dimensions: Optional[int],
files: Optional[List[File]] = None,
google_drive: Optional[GoogleDrive] = None,
):
self.files = files
self.google_drive = google_drive
self.index_name = index_name
self.vector_credentials = vector_credentials
self.dimensions = dimensions
Expand All @@ -42,20 +46,6 @@ def __init__(
server_url=config("UNSTRUCTURED_IO_SERVER_URL"),
)

def _get_datasource_suffix(self, type: str) -> dict:
suffixes = {
"TXT": ".txt",
"PDF": ".pdf",
"MARKDOWN": ".md",
"DOCX": ".docx",
"CSV": ".csv",
"XLSX": ".xlsx",
}
try:
return suffixes[type]
except KeyError:
raise ValueError("Unsupported datasource type")

def _get_strategy(self, type: str) -> dict:
strategies = {
"PDF": "auto",
Expand All @@ -66,7 +56,7 @@ def _get_strategy(self, type: str) -> dict:
return None

async def _download_and_extract_elements(
self, file, strategy: Optional[str] = "hi_res"
self, file: File, strategy: Optional[str] = "hi_res"
) -> List[Any]:
"""
Downloads the file and extracts elements using the partition function.
Expand All @@ -76,7 +66,7 @@ async def _download_and_extract_elements(
f"Downloading and extracting elements from {file.url},"
f"using `{strategy}` strategy"
)
suffix = self._get_datasource_suffix(file.type.value)
suffix = get_file_extension_from_url(url=file.url)
strategy = self._get_strategy(type=file.type.value)
with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
with requests.get(url=file.url) as response:
Expand Down Expand Up @@ -115,7 +105,7 @@ async def generate_document(
doc_metadata = {
"source": file.url,
"source_type": "document",
"document_type": self._get_datasource_suffix(file.type.value),
"document_type": get_file_extension_from_url(url=file.url),
}
return BaseDocument(
id=f"doc_{uuid.uuid4()}",
Expand Down Expand Up @@ -159,9 +149,7 @@ async def generate_chunks(
"document_id": document.id,
"source": file.url,
"source_type": "document",
"document_type": self._get_datasource_suffix(
file.type.value
),
"document_type": get_file_extension_from_url(file.url),
"content": chunk_text,
**sanitized_metadata,
},
Expand Down
23 changes: 23 additions & 0 deletions service/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import List

from models.file import File
from models.google_drive import GoogleDrive
from service.embedding import EmbeddingService


async def handle_urls(
embedding_service: EmbeddingService,
files: List[File],
):
embedding_service.files = files
chunks = await embedding_service.generate_chunks()
summary_documents = await embedding_service.generate_summary_documents(
documents=chunks
)
return chunks, summary_documents


async def handle_google_drive(
_embedding_service: EmbeddingService, _google_drive: GoogleDrive
):
pass
11 changes: 11 additions & 0 deletions utils/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from urllib.parse import urlparse
import os


def get_file_extension_from_url(url: str) -> str:
"""
Extracts the file extension from a given URL.
"""
path = urlparse(url).path
ext = os.path.splitext(path)[1]
return ext