-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""CSV Ingestor""" | ||
from typing import List, AsyncGenerator | ||
import csv | ||
import io | ||
|
||
from querent.processors.async_processor import AsyncProcessor | ||
from querent.ingestors.ingestor_factory import IngestorFactory | ||
from querent.ingestors.base_ingestor import BaseIngestor | ||
from querent.config.ingestor_config import IngestorBackend | ||
from querent.common.types.collected_bytes import CollectedBytes | ||
|
||
|
||
class CsvIngestorFactory(IngestorFactory): | ||
"""Ingestor factory for CSV""" | ||
|
||
SUPPORTED_EXTENSIONS = {"csv"} | ||
|
||
async def supports(self, file_extension: str) -> bool: | ||
return file_extension.lower() in self.SUPPORTED_EXTENSIONS | ||
|
||
async def create( | ||
self, file_extension: str, processors: List[AsyncProcessor] | ||
) -> BaseIngestor: | ||
if not self.supports(file_extension): | ||
return None | ||
return CsvIngestor(processors) | ||
|
||
|
||
class CsvIngestor(BaseIngestor): | ||
"""Ingestor for CSV""" | ||
|
||
def __init__(self, processors: List[AsyncProcessor]): | ||
super().__init__(IngestorBackend.CSV) | ||
self.processors = processors | ||
|
||
async def ingest( | ||
self, poll_function: AsyncGenerator[CollectedBytes, None] | ||
) -> AsyncGenerator[str, None]: | ||
current_file = None | ||
collected_bytes = b"" | ||
try: | ||
async for chunk_bytes in poll_function: | ||
if chunk_bytes.is_error(): | ||
# TODO handle error | ||
continue | ||
if current_file is None: | ||
current_file = chunk_bytes.file | ||
elif current_file != chunk_bytes.file: | ||
# we have a new file, process the old one | ||
async for text in self.extract_and_process_csv( | ||
CollectedBytes(file=current_file, data=collected_bytes) | ||
): | ||
yield text | ||
collected_bytes = b"" | ||
current_file = chunk_bytes.file | ||
collected_bytes += chunk_bytes.data | ||
except Exception as e: | ||
# TODO handle exception | ||
print(e) | ||
yield "" | ||
finally: | ||
# process the last file | ||
async for text in self.extract_and_process_csv( | ||
CollectedBytes(file=current_file, data=collected_bytes) | ||
): | ||
yield text | ||
|
||
async def extract_and_process_csv( | ||
self, collected_bytes: CollectedBytes | ||
) -> AsyncGenerator[str, None]: | ||
text = await self.extract_text_from_csv(collected_bytes) | ||
# print(text) | ||
processed_text = await self.process_data(text) | ||
yield processed_text | ||
|
||
async def extract_text_from_csv( | ||
self, collected_bytes: CollectedBytes | ||
) -> csv.reader: | ||
text_data = collected_bytes.data.decode("utf-8") | ||
print(text_data) | ||
text = csv.reader(io.StringIO(text_data)) | ||
return text | ||
|
||
async def process_data(self, text: str) -> List[str]: | ||
processed_data = text | ||
for processor in self.processors: | ||
processed_data = await processor.process(processed_data) | ||
return processed_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Name,State,City | ||
Ansh,Punjab,Anandpur | ||
Ayush,Odisha,Cuttack |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
FirstName,LastName,Country | ||
John,Doe,Usa | ||
Leo,Messi,Argentina | ||
Cristiano,Ronaldo,Portugal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import asyncio | ||
import pytest | ||
from pathlib import Path | ||
|
||
from querent.collectors.fs.fs_collector import FSCollectorFactory | ||
from querent.config.collector_config import FSCollectorConfig | ||
from querent.common.uri import Uri | ||
from querent.ingestors.ingestor_manager import IngestorFactoryManager | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_collect_and_ingest_csv_data(): | ||
collector_factory = FSCollectorFactory() | ||
uri = Uri("file://" + str(Path("./tests/data/csv/").resolve())) | ||
config = FSCollectorConfig(root_path=uri.path) | ||
collector = collector_factory.resolve(uri, config) | ||
|
||
ingestor_factory_manager = IngestorFactoryManager() | ||
ingestor_factory = await ingestor_factory_manager.get_factory("csv") | ||
ingestor = await ingestor_factory.create("csv", []) | ||
|
||
ingested_call = ingestor.ingest(collector.poll()) | ||
counter = 0 | ||
|
||
async def poll_and_print(): | ||
counter = 0 | ||
async for ingested in ingested_call: | ||
assert ingested is not None | ||
# we have an iterable in ingested | ||
for row in ingested: | ||
counter += 1 | ||
assert counter == 7 | ||
|
||
await poll_and_print() | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(test_collect_and_ingest_csv_data()) |