
Commit

Merge branch 'main' of https://github.com/Querent-ai/querent-ai into ayush

the-non-expert committed Aug 29, 2023
2 parents 66328e2 + a935c81 commit 3f0c2da
Showing 16 changed files with 291 additions and 68 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/pytest.yml
@@ -0,0 +1,30 @@
name: Run Pytest on Branches

on:
  push:
    branches:
      - '*'
    paths-ignore:
      - 'README.md' # Add any paths you want to exclude

jobs:
  pytest:
    if: ${{ github.ref != 'refs/heads/main' }}
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
      - name: Run Pytest
        run: pytest --disable-warnings .
5 changes: 3 additions & 2 deletions querent/collectors/collector_base.py
@@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator

-from querent.collectors.collector_result import CollectorResult
+from querent.common.types.collected_bytes import CollectedBytes


class Collector(ABC):
    @abstractmethod
    async def connect(self):
        pass

    @abstractmethod
-    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
+    async def poll(self) -> AsyncGenerator[CollectedBytes, None]:
        pass

    @abstractmethod
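For orientation, here is a minimal sketch of a concrete collector against the updated base class. The in-memory source is hypothetical, and it assumes the interface's abstract methods are connect, poll, and disconnect, as the filesystem collector later in this commit suggests; the point is only that poll() now yields CollectedBytes instead of CollectorResult.

import asyncio
from typing import AsyncGenerator

from querent.collectors.collector_base import Collector
from querent.common.types.collected_bytes import CollectedBytes


class InMemoryCollector(Collector):
    """Hypothetical collector serving a fixed mapping of paths to bytes."""

    def __init__(self, blobs: dict):
        self.blobs = blobs

    async def connect(self):
        pass  # nothing to set up for an in-memory source

    async def poll(self) -> AsyncGenerator[CollectedBytes, None]:
        for path, data in self.blobs.items():
            yield CollectedBytes(file=path, data=data, error=None)

    async def disconnect(self):
        pass  # nothing to tear down


async def main():
    collector = InMemoryCollector({"docs/a.txt": b"hello"})
    async for item in collector.poll():
        print(item.get_file_path(), item.unwrap())


asyncio.run(main())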
3 changes: 3 additions & 0 deletions querent/collectors/collector_errors.py
@@ -2,12 +2,14 @@
from pathlib import Path
from enum import Enum


class CollectorErrorKind(Enum):
    NotFound = "not_found"
    Unauthorized = "unauthorized"
    incompatible = "incompatible"
    NotSupported = "not_supported"


class CollectorResolverError(Exception):
    def __init__(self, kind: CollectorErrorKind, message: str):
        super().__init__(message)
@@ -24,6 +26,7 @@ class NotFoundError(CollectorError):
    def __init__(self, message: str):
        super().__init__(CollectorErrorKind.NotFound, message)


class UnauthorizedError(CollectorError):
    def __init__(self, message: str):
        super().__init__(CollectorErrorKind.unauthorized, message)
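A brief usage sketch of these error types. It assumes CollectorError, whose definition is collapsed above, derives from Exception, which the super().__init__ calls imply:

from querent.collectors.collector_errors import (
    CollectorErrorKind,
    CollectorResolverError,
    NotFoundError,
)


def open_source(path: str) -> None:
    # Hypothetical helper: signal a missing source with the typed error.
    raise NotFoundError(f"no collector source at {path}")


try:
    open_source("/data/missing")
except NotFoundError as err:
    print("collector error:", err)

# Resolver failures carry a kind plus a message, as the fixed call site below shows.
try:
    raise CollectorResolverError(CollectorErrorKind.NotSupported, "Unknown backend")
except CollectorResolverError as err:
    print("resolver error:", err)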
9 changes: 7 additions & 2 deletions querent/collectors/collector_resolver.py
@@ -5,7 +5,10 @@
from querent.collectors.webscaper.web_scraper_collector import WebScraperFactory
from querent.config.collector_config import CollectConfig, CollectorBackend
from querent.collectors.collector_base import Collector
-from querent.collectors.collector_errors import CollectorResolverError, CollectorErrorKind
+from querent.collectors.collector_errors import (
+    CollectorResolverError,
+    CollectorErrorKind,
+)
from querent.common.uri import Protocol, Uri


@@ -39,7 +42,9 @@ def _determine_backend(self, protocol: Protocol) -> CollectorBackend:
            return CollectorBackend.WebScraper
        elif protocol.is_grpc():
            return CollectorBackend.Gcs
+        if protocol.is_webscrapper():
+            return CollectorBackend.WebScraper
        else:
            raise CollectorResolverError(
-                CollectorErrorKind.NotSupported, "Unknown backend", "Unknown backend"
+                CollectorErrorKind.NotSupported, "Unknown backend"
            )
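A resolution sketch for context. Everything named here is an assumption for illustration: the resolver class name (CollectorResolver), its resolve(uri, config) method mirroring the factory interfaces in this commit, an FSCollectorConfig with a root_path field, and a file:// scheme accepted by Protocol; none of these appear in this hunk.

from querent.collectors.collector_errors import CollectorResolverError
from querent.collectors.collector_resolver import CollectorResolver
from querent.common.uri import Uri
from querent.config.collector_config import FSCollectorConfig

resolver = CollectorResolver()  # assumed class name
uri = Uri("file:///tmp/data")   # assumes Protocol defines a file scheme
try:
    collector = resolver.resolve(uri, FSCollectorConfig(root_path="/tmp/data"))
except CollectorResolverError as err:
    print("could not resolve collector:", err)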
8 changes: 5 additions & 3 deletions querent/collectors/fs/fs_collector.py
@@ -3,7 +3,7 @@
from typing import AsyncGenerator
from querent.collectors.collector_base import Collector
from querent.collectors.collector_factory import CollectorFactory
-from querent.collectors.collector_result import CollectorResult
+from querent.common.types.collected_bytes import CollectedBytes
from querent.common.uri import Uri
from querent.config.collector_config import CollectorBackend, FSCollectorConfig
import aiofiles
@@ -22,11 +22,11 @@ async def disconnect(self):
        # Add your cleanup logic here if needed
        pass

-    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
+    async def poll(self) -> AsyncGenerator[CollectedBytes, None]:
        async for file_path in self.walk_files(self.root_dir):
            async with aiofiles.open(file_path, "rb") as file:
                async for chunk in self.read_chunks(file):
-                    yield CollectorResult({"file_path": file_path, "chunk": chunk})
+                    yield CollectedBytes(file=file_path, data=chunk, error=None)

    async def read_chunks(self, file):
        while True:
@@ -43,6 +43,8 @@ async def walk_files(self, root: Path) -> AsyncGenerator[Path, None]:
            elif item.is_dir():
                async for file_path in self.walk_files(item):
                    yield file_path
+
+
class FSCollectorFactory(CollectorFactory):
    def __init__(self):
        pass
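A consumption sketch for the filesystem collector. The factory call and the FSCollectorConfig field name (root_path) are assumptions for illustration, since neither the factory's resolve signature nor the config class appears in this diff; the async-iteration pattern over poll() is the part taken from the code above.

import asyncio

from querent.collectors.fs.fs_collector import FSCollectorFactory
from querent.common.uri import Uri
from querent.config.collector_config import FSCollectorConfig


async def dump_directory(path: str) -> None:
    factory = FSCollectorFactory()
    # Assumed construction, mirroring the WebScraperFactory.resolve(uri, config) shape below.
    collector = factory.resolve(Uri(f"file://{path}"), FSCollectorConfig(root_path=path))
    await collector.connect()
    try:
        async for item in collector.poll():
            if item.is_error():
                print("error:", item.error)
            else:
                print(item.get_file_path(), len(item.unwrap()), "bytes")
    finally:
        await collector.disconnect()


asyncio.run(dump_directory("/tmp/sample-docs"))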
15 changes: 9 additions & 6 deletions querent/collectors/webscaper/web_scraper_collector.py
@@ -1,8 +1,10 @@
from querent.collectors.collector_base import Collector
from querent.collectors.collector_factory import CollectorFactory
-from querent.collectors.collector_result import CollectorResult
+from querent.common.types.collected_bytes import CollectedBytes
from querent.config.collector_config import CollectorBackend, WebScraperConfig
from querent.tools.web_page_extractor import WebpageExtractor
+from querent.common.uri import Uri


class WebScraperCollector(Collector):
    def __init__(self, config: WebScraperConfig):
@@ -16,19 +18,20 @@ async def disconnect(self):

    async def poll(self):
        content = await self.scrape_website(self.website_url)
-        yield CollectorResult(content)
+        yield CollectedBytes(file=None, data=content.data, error=None)

    async def scrape_website(self, website_url: str):
        content = WebpageExtractor().extract_with_bs4(website_url)
-        max_length = len(' '.join(content.split(" ")[:600]))
-        return content[:max_length]
+        max_length = len(" ".join(content.split(" ")[:600]))
+        return CollectedBytes(data=content[:max_length], file=None, error=None)


class WebScraperFactory(CollectorFactory):
    def __init__(self):
        pass

    def backend(self) -> CollectorBackend:
        return CollectorBackend.WebScraper

-    def resolve(self, config: WebScraperConfig) -> Collector:
+    def resolve(self, uri: Uri, config: WebScraperConfig) -> Collector:
        return WebScraperCollector(config)
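The truncation in scrape_website is easy to misread: it measures the character length of the first 600 space-separated words and then slices the original string to that length. A standalone sketch of the same idiom (the helper name is illustrative):

def truncate_to_words(content: str, word_limit: int = 600) -> str:
    # Character length of the first `word_limit` space-separated words.
    max_length = len(" ".join(content.split(" ")[:word_limit]))
    # Slice the original text to that character count.
    return content[:max_length]


sample = "alpha beta gamma delta epsilon"
print(truncate_to_words(sample, word_limit=3))  # -> alpha beta gamma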
Empty file.
53 changes: 53 additions & 0 deletions querent/common/types/collected_bytes.py
@@ -0,0 +1,53 @@
from typing import Union


class CollectedBytes:
    def __init__(self, file: str, data: bytes, error: str = None) -> None:
        self.data = data
        self.error = error
        self.file = file
        if self.file:
            file = str(file)
            self.extension = file.split(".")[-1]
            self.file_id = file.split("/")[-1].split(".")[0]

    def __str__(self):
        if self.error:
            return f"Error: {self.error}"
        return f"Data: {self.data}"

    def is_error(self) -> bool:
        return self.error is not None

    def get_file_path(self) -> str:
        return self.file

    def get_extension(self) -> str:
        return self.extension

    def get_file_id(self) -> str:
        return self.file_id

    @classmethod
    def success(cls, data: bytes) -> "CollectedBytes":
        return cls(data)

    @classmethod
    def error(cls, error: str) -> "CollectedBytes":
        return cls(None, error)

    def unwrap(self) -> bytes:
        if self.error:
            raise ValueError(self.error)
        return self.data

    def unwrap_or(self, default: bytes) -> bytes:
        return self.data if not self.error else default

    def __eq__(self, other: Union[bytes, "CollectedBytes"]) -> bool:
        if isinstance(other, CollectedBytes):
            return self.data == other.data and self.error == other.error
        return self.data == other

    def __hash__(self) -> int:
        return hash((self.data, self.error))
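A short usage sketch of the new type, constructed directly the way the collectors in this commit construct it:

from querent.common.types.collected_bytes import CollectedBytes

chunk = CollectedBytes(file="docs/report.pdf", data=b"%PDF-1.7 ...", error=None)
print(chunk.get_extension())  # pdf
print(chunk.get_file_id())    # report
print(chunk.unwrap()[:8])     # b'%PDF-1.7'

failed = CollectedBytes(file="docs/missing.txt", data=None, error="not found")
print(failed.is_error())      # True
print(failed.unwrap_or(b""))  # falls back to b'' instead of raising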
7 changes: 4 additions & 3 deletions querent/common/uri.py
@@ -42,19 +42,20 @@ def is_object_storage(self) -> bool:
    def is_database(self) -> bool:
        return self == Protocol.PostgreSQL

-    def is_webscraper(self) -> bool:
+    def is_webscrapper(self) -> bool:
        return self == Protocol.Webscraper


class Uri:
    PROTOCOL_SEPARATOR = "://"
    DATABASE_URI_PATTERN = re.compile(
-        r"(?P<before>^.*://.*)(?P<password>:.*@)(?P<after>.*)")
+        r"(?P<before>^.*://.*)(?P<password>:.*@)(?P<after>.*)"
+    )

    def __init__(self, uri: str):
        self.uri = uri
        self.protocol_idx = uri.find(self.PROTOCOL_SEPARATOR)
-        self.protocol = Protocol(uri[:self.protocol_idx])
+        self.protocol = Protocol(uri[: self.protocol_idx])

    @classmethod
    def from_well_formed(cls, uri: str) -> "Uri":
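The named groups in DATABASE_URI_PATTERN split a database URI around its password segment, which makes masking straightforward. A standalone sketch of that use (the masking itself is illustrative and not part of this diff):

import re

# The same pattern the Uri class compiles, reproduced here to show the groups.
pattern = re.compile(r"(?P<before>^.*://.*)(?P<password>:.*@)(?P<after>.*)")

uri = "postgresql://reader:s3cret@db.internal:5432/querent"
match = pattern.match(uri)
if match:
    masked = f"{match.group('before')}:***@{match.group('after')}"
    print(masked)  # postgresql://reader:***@db.internal:5432/querent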
3 changes: 2 additions & 1 deletion querent/ingestors/pdf_ingestor.py
@@ -1,4 +1,5 @@
import PyPDF2
+import pypdf


class PDFConnector:
@@ -10,7 +11,7 @@ def __init__(self, file_path):
    def open_pdf(self):
        """Open the PDF file."""
        self.pdf_file = open(self.file_path, 'rb')
-        self.pdf_reader = PyPDF2.PdfReader(self.pdf_file)
+        self.pdf_reader = pypdf.pdf_reader(self.pdf_file)

    def authenticate(self, password):
        """Authenticate the connection if the PDF is encrypted."""
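For reference, a minimal standalone sketch of reading a possibly encrypted PDF with pypdf's PdfReader, independent of the PDFConnector above; the file path and password are illustrative.

from pypdf import PdfReader

reader = PdfReader("sample.pdf")
if reader.is_encrypted:
    reader.decrypt("owner-or-user-password")

for page in reader.pages:
    text = page.extract_text()
    print(text[:80] if text else "<no extractable text>")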
24 changes: 16 additions & 8 deletions querent/storage/local/local_file_storage.py
@@ -13,6 +13,7 @@
from querent.storage.storage_factory import StorageFactory
from querent.storage.storage_result import StorageResult


class AsyncDebouncer:
    def __init__(self):
        self.cache = {}
@@ -43,10 +44,12 @@ async def get_or_create(self, key, build_a_future):
            del self.cache[key]
        return result


class DebouncerEntry:
    def __init__(self, future):
        self.future = future


class DebouncedStorage:
    def __init__(self, underlying):
        self.underlying = underlying
@@ -84,7 +87,7 @@ async def bulk_delete(self, paths):
        await self.underlying.bulk_delete(paths)

    async def get_all(self, path):
-        key = (path, 0, float('inf'))
+        key = (path, 0, float("inf"))
        cached_result = await self.get_slice_cache(key)
        if cached_result is None:
            result = await self.underlying.get_all(path)
@@ -98,6 +101,7 @@ async def file_num_bytes(self, path):
    def get_uri(self):
        return self.underlying.get_uri()


class LocalFileStorage(Storage):
    def __init__(self, uri: Uri, root=None):
        self.uri = uri
@@ -106,6 +110,9 @@ def __init__(self, uri: Uri, root=None):
        self.root = root
        self.cache_lock = Lock()

+    async def initialize_lock(self):
+        self.cache_lock = Lock()

    async def full_path(self, relative_path):
        await self.ensure_valid_relative_path(relative_path)
        return self.root / relative_path
@@ -129,7 +136,7 @@ async def check_connectivity(self):
                f"Failed to create directories at {self.root}: {e}",
            )

-    async def put(self, path: Path, payload: PutPayload)-> StorageResult:
+    async def put(self, path: Path, payload: PutPayload) -> StorageResult:
        full_path = await self.full_path(path)
        parent_dir = full_path.parent
        try:
@@ -153,18 +160,18 @@ async def copy_to(self, path, output) -> StorageResult:
            await asyncio.to_thread(shutil.copyfileobj, file, output)
        return StorageResult.success(None)

-    async def get_slice(self, path, start, end)-> StorageResult:
+    async def get_slice(self, path, start, end) -> StorageResult:
        full_path = await self.full_path(path)
        with open(full_path, "rb") as file:
            file.seek(start)
            return StorageResult.success(file.read(end - start))

-    async def get_all(self, path)-> StorageResult:
+    async def get_all(self, path) -> StorageResult:
        full_path = await self.full_path(path)
        with open(full_path, "rb") as file:
            return StorageResult.success(file.read())

-    async def delete(self, path)-> StorageResult:
+    async def delete(self, path) -> StorageResult:
        full_path = await self.full_path(path)
        try:
            full_path.unlink()
@@ -181,11 +188,11 @@ async def bulk_delete(self, paths):
        for path in paths:
            await self.delete(path)

-    async def exists(self, path)-> StorageResult:
+    async def exists(self, path) -> StorageResult:
        full_path = await self.full_path(path)
        return StorageResult.success(full_path.exists())

-    async def file_num_bytes(self, path)-> StorageResult:
+    async def file_num_bytes(self, path) -> StorageResult:
        full_path = await self.full_path(path)
        try:
            return StorageResult.success(full_path.stat().st_size)
@@ -196,9 +203,10 @@ async def file_num_bytes(self, path)-> StorageResult:
            )

    @property
-    def get_uri(self)-> Uri:
+    def get_uri(self) -> Uri:
        return self.uri


class LocalStorageFactory(StorageFactory):
    def backend(self) -> StorageBackend:
        return StorageBackend.LocalFile
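Most of AsyncDebouncer is collapsed above; its role is to let concurrent reads of the same key share one in-flight future instead of hitting the underlying storage repeatedly. A generic sketch of that idea, not the repository's implementation:

import asyncio


class SingleFlight:
    """Concurrent callers with the same key await one shared task."""

    def __init__(self):
        self._inflight = {}

    async def get_or_create(self, key, build_coro):
        task = self._inflight.get(key)
        if task is None:
            # First caller for this key starts the real work.
            task = asyncio.ensure_future(build_coro())
            self._inflight[key] = task
        try:
            return await task
        finally:
            self._inflight.pop(key, None)


async def slow_read(path):
    await asyncio.sleep(0.1)
    return f"contents of {path}"


async def main():
    sf = SingleFlight()
    results = await asyncio.gather(
        sf.get_or_create("a.txt", lambda: slow_read("a.txt")),
        sf.get_or_create("a.txt", lambda: slow_read("a.txt")),  # shares the first call
    )
    print(results)


asyncio.run(main())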
4 changes: 3 additions & 1 deletion querent/storage/storage_base.py
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import IO
+from typing import List

from querent.common.uri import Uri
from querent.storage.payload import PutPayload
from querent.storage.storage_result import StorageResult


class Storage(ABC):
    @abstractmethod
    async def check_connectivity(self) -> None:
@@ -36,7 +38,7 @@ async def delete(self, path: Path) -> StorageResult:
        pass

    @abstractmethod
-    async def bulk_delete(self, paths: list[Path]) -> None:
+    async def bulk_delete(self, paths: List[Path]) -> None:
        pass

    @abstractmethod
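The bulk_delete signature swaps the built-in generic for typing.List, which matters on the Python 3.8 runtime the new workflow pins: subscripting the built-in list in an evaluated annotation only works on 3.9+ (or with postponed annotation evaluation). A small sketch:

# Runs on Python 3.8, matching the python-version pinned in the workflow above.
from pathlib import Path
from typing import List


def bulk_delete(paths: List[Path]) -> None:  # list[Path] here raises TypeError on 3.8
    for path in paths:
        print("would delete", path)


bulk_delete([Path("a.txt"), Path("b.txt")])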