Skip to content

Commit

Permalink
Resolved conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Ansh5461 committed Sep 7, 2023
1 parent c4d440f commit 7f6aeb4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 9 additions & 1 deletion querent/ingestors/ingestor_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Ingestor manager, for managing all the factories with backend
"""
from typing import Optional
from querent.config.ingestor_config import IngestorBackend
from querent.ingestors.base_ingestor import BaseIngestor
Expand All @@ -10,10 +13,12 @@


class IngestorFactoryManager:
"""Factory manager"""

def __init__(self):
self.ingestor_factories = {
IngestorBackend.PDF.value: PdfIngestorFactory(),
IngestorBackend.TEXT.value: TextIngestorFactory()
IngestorBackend.TEXT.value: TextIngestorFactory(),
IngestorBackend.MP3.value: AudioIngestorFactory(),
IngestorBackend.WAV.value: AudioIngestorFactory(),
IngestorBackend.JSON.value: JsonIngestorFactory(),
Expand All @@ -24,14 +29,17 @@ def __init__(self):
}

async def get_factory(self, file_extension: str) -> IngestorFactory:
"""get_factory to match factory based on file extension"""
return self.ingestor_factories.get(
file_extension.lower(), UnsupportedIngestor("Unsupported file extension")
)

async def get_ingestor(self, file_extension: str) -> Optional[BaseIngestor]:
"""get_ingestor to get factory for that extension"""
factory = self.get_factory(file_extension)
return factory.create(file_extension)

async def supports(self, file_extension: str) -> bool:
"""check if extension supports factory"""
factory = self.get_factory(file_extension)
return factory.supports(file_extension)
8 changes: 5 additions & 3 deletions tests/test_audio_ingestor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
"""Test cases for audio ingestors"""
from pathlib import Path
import pytest
import asyncio

from querent.collectors.fs.fs_collector import FSCollectorFactory
from querent.config.collector_config import FSCollectorConfig
from querent.common.uri import Uri
from querent.ingestors.ingestor_manager import IngestorFactoryManager
import pytest


@pytest.mark.asyncio
Expand All @@ -30,7 +32,7 @@ async def poll_and_print():
if len(ingested) == 0:
counter += 1

assert counter == 0
assert counter == 0

await poll_and_print()

Expand Down

0 comments on commit 7f6aeb4

Please sign in to comment.