Skip to content

Commit

Permalink
Wrote text ingestors, still incomplete tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ansh5461 committed Aug 31, 2023
1 parent 8cd2553 commit bc70e9a
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 11 deletions.
9 changes: 9 additions & 0 deletions querent/ingestors/ingestor_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum


class IngestorError(Enum):
EOF = "End of File"
ETCD = "ETCD Error"
NETWORK = "Network Error"
TIMEOUT = "Timeout"
UNKNOWN = "Unknown Error"
11 changes: 7 additions & 4 deletions querent/ingestors/ingestor_manager.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@

from typing import Optional
from querent.config.ingestor_config import IngestorBackend
from querent.ingestors.base_ingestor import BaseIngestor
from querent.ingestors.ingestor_factory import IngestorFactory, UnsupportedIngestor
from querent.ingestors.pdfs.pdf_ingestor_v1 import PdfIngestorFactory
from querent.ingestors.texts.text_ingestor import TextIngestorFactory


class IngestorFactoryManager:
def __init__(self):
self.ingestor_factories = {
IngestorBackend.PDF.value: PdfIngestorFactory(),
#Ingestor.TEXT.value: TextIngestor(),
IngestorBackend.TEXT.value: TextIngestorFactory()
# Ingestor.TEXT.value: TextIngestor(),
# Add more mappings as needed
}

async def get_factory(self, file_extension: str) -> IngestorFactory:
return self.ingestor_factories.get(file_extension.lower(), UnsupportedIngestor("Unsupported file extension"))
return self.ingestor_factories.get(
file_extension.lower(), UnsupportedIngestor("Unsupported file extension")
)

async def get_ingestor(self, file_extension: str) -> Optional[BaseIngestor]:
factory = self.get_factory(file_extension)
return factory.create(file_extension)

async def supports(self, file_extension: str) -> bool:
factory = self.get_factory(file_extension)
return factory.supports(file_extension)
8 changes: 6 additions & 2 deletions querent/ingestors/pdfs/pdf_ingestor_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class PdfIngestorFactory(IngestorFactory):
async def supports(self, file_extension: str) -> bool:
return file_extension.lower() in self.SUPPORTED_EXTENSIONS

async def create(self, file_extension: str, processors: List[AsyncProcessor]) -> BaseIngestor:
async def create(
self, file_extension: str, processors: List[AsyncProcessor]
) -> BaseIngestor:
if not self.supports(file_extension):
return None
return PdfIngestor(processors)
Expand Down Expand Up @@ -58,7 +60,9 @@ async def ingest(
except Exception as e:
yield []

async def extract_and_process_pdf(self, collected_bytes: CollectedBytes) -> List[str]:
async def extract_and_process_pdf(
self, collected_bytes: CollectedBytes
) -> List[str]:
text = await self.extract_text_from_pdf(collected_bytes)
return await self.process_data(text)

Expand Down
75 changes: 75 additions & 0 deletions querent/ingestors/texts/text_ingestor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List, AsyncGenerator
from querent.common.types.collected_bytes import CollectedBytes
from querent.ingestors.base_ingestor import BaseIngestor
from querent.ingestors.ingestor_factory import IngestorFactory
from querent.processors.async_processor import AsyncProcessor
from querent.config.ingestor_config import IngestorBackend


class TextIngestorFactory(IngestorFactory):
SUPPORTED_EXTENSIONS = {"txt"}

async def supports(self, file_extension: str) -> bool:
return file_extension.lower() in self.SUPPORTED_EXTENSIONS

async def create(
self, file_extension: str, processors: List[AsyncProcessor]
) -> BaseIngestor:
if not self.supports(file_extension):
return None

return TextIngestor(processors)


class TextIngestor(BaseIngestor):
def __init__(self, processors: List[AsyncProcessor]):
self.processors = processors
super.__init__(IngestorBackend.TEXT)

async def ingest(
self, poll_function: AsyncGenerator[CollectedBytes, None]
) -> AsyncGenerator[List[str], None]:
try:
collected_bytes = b""
current_file = None

async for chunk_bytes in poll_function:
if chunk_bytes.is_error():
continue

if chunk_bytes.file != current_file:
if current_file:
text = await self.extract_and_process_text(
CollectedBytes(file=current_file, data=collected_bytes)
)
yield text

collected_bytes = b""
current_file = chunk_bytes.file

collected_bytes += chunk_bytes.data

if current_file:
text = await self.extract_and_process_text(
CollectedBytes(file=current_file, data=collected_bytes)
)
yield text

except Exception as e:
yield []

async def extract_and_process_text(
self, collected_bytes: CollectedBytes
) -> List[str]:
text = await self.extract_text_from_file(collected_bytes)
return await self.process_data(text=text)

async def extract_text_from_file(collected_bytes: CollectedBytes) -> str:
text = collected_bytes.data.decode("utf-8")
return text

async def process_data(self, text: str) -> List[str]:
processed_data = text
for processor in self.processors:
processed_data = await processor.process(processed_data)
return processed_data
14 changes: 9 additions & 5 deletions tests/test_pdf_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,34 @@
from querent.ingestors.ingestor_manager import IngestorFactoryManager
import pytest


@pytest.mark.asyncio
async def test_collect_and_ingest_pdf():
# Set up the collector
collector_factory = FSCollectorFactory()
uri = Uri("file://" + str(Path("./tests/data/pdf/").resolve()))
config = FSCollectorConfig(root_path=uri.path)
collector = collector_factory.resolve(uri, config)

# Set up the ingestor
ingestor_factory_manager = IngestorFactoryManager()
ingestor_factory = await ingestor_factory_manager.get_factory("pdf") # Notice the use of await here
ingestor_factory = await ingestor_factory_manager.get_factory(
"pdf"
) # Notice the use of await here
ingestor = await ingestor_factory.create("pdf", [])

# Collect and ingest the PDF
ingested_call = ingestor.ingest(collector.poll())
counter = 0

async def poll_and_print():
counter = 0
async for ingested in ingested_call:
assert ingested is not None
if len(ingested) == 0:
if len(ingested) == 0:
counter += 1
assert counter == 1

await poll_and_print() # Notice the use of await here


Expand Down

0 comments on commit bc70e9a

Please sign in to comment.