diff --git a/querent/collectors/aws/__init__.py b/querent/collectors/aws/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/querent/collectors/aws/aws_collector.py b/querent/collectors/aws/aws_collector.py
new file mode 100644
index 00000000..669e5041
--- /dev/null
+++ b/querent/collectors/aws/aws_collector.py
@@ -0,0 +1,62 @@
+from typing import AsyncGenerator
+import io
+
+from querent.config.collector_config import CollectorBackend, S3CollectConfig
+from querent.collectors.collector_base import Collector
+from querent.collectors.collector_factory import CollectorFactory
+from querent.collectors.collector_result import CollectorResult
+from querent.common.uri import Uri
+import boto3
+
+
+class AWSCollector(Collector):
+    def __init__(self, config: S3CollectConfig, prefix: str):
+        self.bucket_name = config.bucket
+        self.region = config.region
+        self.access_key = config.access_key
+        self.secret_key = config.secret_key
+        self.chunk_size = config.chunk
+        self.s3_client = boto3.client(
+            's3',
+            aws_access_key_id=self.access_key,
+            aws_secret_access_key=self.secret_key,
+            region_name=self.region
+        )
+        self.prefix = str(prefix)
+
+    async def connect(self):
+        pass  # No asynchronous connection needed for boto3
+
+    async def disconnect(self):
+        pass  # No asynchronous disconnect needed for boto3
+
+    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
+        response = self.s3_client.list_objects_v2(
+            Bucket=self.bucket_name)
+
+        for obj in response.get('Contents', []):
+            file = self.download_object_as_byte_stream(obj['Key'])
+            async for chunk in self.read_chunks(file):
+                yield CollectorResult({"object_key": obj['Key'], "chunk": chunk})
+
+    async def read_chunks(self, file):
+        while True:
+            chunk = file.read(self.chunk_size)
+            if not chunk:
+                break
+            yield chunk
+
+    def download_object_as_byte_stream(self, object_key):
+        byte_stream = io.BytesIO()
+        self.s3_client.download_fileobj(
+            self.bucket_name, object_key, byte_stream)
+        byte_stream.seek(0)  # Rewind the stream to the beginning
+        return byte_stream
+
+
+class AWSCollectorFactory(CollectorFactory):
+    def backend(self) -> CollectorBackend:
+        return CollectorBackend.S3
+
+    def resolve(self, uri: Uri, config: S3CollectConfig) -> Collector:
+        return AWSCollector(config, uri)
diff --git a/querent/collectors/collector_resolver.py b/querent/collectors/collector_resolver.py
index 6d1be8d0..98fc28cd 100644
--- a/querent/collectors/collector_resolver.py
+++ b/querent/collectors/collector_resolver.py
@@ -1,4 +1,6 @@
 from typing import Optional
+from querent.collectors.gcs.gcs_collector import GCSCollectorFactory
+from querent.collectors.aws.aws_collector import AWSCollectorFactory
 from querent.collectors.fs.fs_collector import FSCollectorFactory
 from querent.collectors.webscaper.web_scraper_collector import WebScraperFactory
 from querent.config.collector_config import CollectConfig, CollectorBackend
@@ -14,7 +16,9 @@ class CollectorResolver:
     def __init__(self):
         self.collector_factories = {
             CollectorBackend.LocalFile: FSCollectorFactory(),
+            CollectorBackend.S3: AWSCollectorFactory(),
             CollectorBackend.WebScraper: WebScraperFactory(),
+            CollectorBackend.Gcs: GCSCollectorFactory(),
             # Add other collector factories as needed
         }
 
@@ -32,7 +36,11 @@ def resolve(self, uri: Uri, config: CollectConfig) -> Optional[Collector]:
     def _determine_backend(self, protocol: Protocol) -> CollectorBackend:
         if protocol.is_file_storage():
             return CollectorBackend.LocalFile
-        if protocol.is_webscrapper():
+        elif protocol.is_s3():
+            return CollectorBackend.S3
+        elif protocol.is_gcs():
+            return CollectorBackend.Gcs
+        elif protocol.is_webscraper():
             return CollectorBackend.WebScraper
         else:
             raise CollectorResolverError(
diff --git a/querent/collectors/gcs/__init__.py b/querent/collectors/gcs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/querent/collectors/gcs/gcs_collector.py b/querent/collectors/gcs/gcs_collector.py
new file mode 100644
index 00000000..d5614483
--- /dev/null
+++ b/querent/collectors/gcs/gcs_collector.py
@@ -0,0 +1,77 @@
+from typing import AsyncGenerator
+import io
+import os
+
+from querent.config.collector_config import GcsCollectConfig
+from querent.config.collector_config import CollectorBackend
+from querent.collectors.collector_base import Collector
+from querent.collectors.collector_factory import CollectorFactory
+from querent.collectors.collector_result import CollectorResult
+from querent.common.uri import Uri
+from google.cloud import storage
+from dotenv import load_dotenv
+
+load_dotenv()
+
+credentials_info = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
+bucket_name = os.getenv("GOOGLE_BUCKET_NAME")
+
+
+class GCSCollector(Collector):
+    def __init__(self, config: GcsCollectConfig):
+        self.bucket_name = config.bucket
+        self.credentials = config.credentials
+        self.chunk_size = 1024
+        self.client = None
+
+    async def connect(self):
+        if not self.client:
+            self.client = storage.Client.from_service_account_json(
+                self.credentials)
+
+    async def disconnect(self):
+        if self.client is not None:
+            self.client.close()
+            self.client = None
+
+    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
+        # Make sure to connect the client before using it
+        await self.connect()
+
+        try:
+            bucket = self.client.get_bucket(self.bucket_name)
+            blobs = bucket.list_blobs()
+            for blob in blobs:
+                file = self.download_blob_as_byte_stream(blob)
+                async for chunk in self.read_chunks(file):
+                    yield CollectorResult({"object_key": blob.name, "chunk": chunk})
+        finally:
+            # Disconnect the client when done
+            await self.disconnect()
+
+    async def read_chunks(self, file):
+        while True:
+            chunk = file.read(self.chunk_size)
+            if not chunk:
+                break
+            yield chunk
+
+    def download_blob_as_byte_stream(self, blob):
+        # Download the blob into memory and rewind, mirroring the AWS collector
+        byte_stream = io.BytesIO()
+        blob.download_to_file(byte_stream)
+        byte_stream.seek(0)
+        return byte_stream
+
+
+class GCSCollectorFactory(CollectorFactory):
+    def backend(self) -> CollectorBackend:
+        return CollectorBackend.Gcs
+
+    def resolve(self, uri: Uri, config: GcsCollectConfig) -> Collector:
+        # Bucket and credentials currently come from the environment (.env)
+        config = GcsCollectConfig(
+            bucket=bucket_name,
+            credentials=credentials_info
+        )
+        return GCSCollector(config)
diff --git a/querent/collectors/webscaper/web_scraper_collector.py b/querent/collectors/webscaper/web_scraper_collector.py
index 3df46ec4..c0ef8b31 100644
--- a/querent/collectors/webscaper/web_scraper_collector.py
+++ b/querent/collectors/webscaper/web_scraper_collector.py
@@ -22,8 +22,8 @@ async def poll(self):
 
     async def scrape_website(self, website_url: str):
         content = WebpageExtractor().extract_with_bs4(website_url)
-        max_length = len(" ".join(content.split(" ")[:600]))
-        return 
CollectedBytes(data=content[:max_length], file=None, error=None) + max_length = len(' '.join(content.split(" ")[:600])) + return CollectedBytes(file=None, data=content[:max_length], error=None) class WebScraperFactory(CollectorFactory): diff --git a/querent/common/uri.py b/querent/common/uri.py index 8065e9ea..d0e7636a 100644 --- a/querent/common/uri.py +++ b/querent/common/uri.py @@ -9,7 +9,7 @@ class Protocol(enum.Enum): Azure = "azure" File = "file" - Grpc = "grpc" + Gcs = "gcs" PostgreSQL = "postgresql" Ram = "ram" S3 = "s3" @@ -21,8 +21,8 @@ def is_azure(self) -> bool: def is_file(self) -> bool: return self == Protocol.File - def is_grpc(self) -> bool: - return self == Protocol.Grpc + def is_gcs(self) -> bool: + return self == Protocol.Gcs def is_postgresql(self) -> bool: return self == Protocol.PostgreSQL @@ -42,7 +42,7 @@ def is_object_storage(self) -> bool: def is_database(self) -> bool: return self == Protocol.PostgreSQL - def is_webscrapper(self) -> bool: + def is_webscraper(self) -> bool: return self == Protocol.Webscraper @@ -70,7 +70,7 @@ def extension(self) -> Optional[str]: @property def path(self) -> str: - return self.uri[self.protocol_idx + len(self.PROTOCOL_SEPARATOR) :] + return self.uri[self.protocol_idx + len(self.PROTOCOL_SEPARATOR):] def as_str(self) -> str: return self.uri diff --git a/querent/config/collector_config.py b/querent/config/collector_config.py index 0040870b..84d2dc41 100644 --- a/querent/config/collector_config.py +++ b/querent/config/collector_config.py @@ -27,19 +27,21 @@ class S3CollectConfig(BaseModel): region: str access_key: str secret_key: str + chunk: int = 1024 class GcsCollectConfig(BaseModel): bucket: str - region: str - access_key: str - secret_key: str + credentials: str + # chunk: int = 1024 + class WebScraperConfig(BaseModel): website_url: str = Field( ..., description="The URL of the website to scrape." 
) + class CollectConfigWrapper(BaseModel): backend: CollectorBackend config: Optional[BaseModel] = None @@ -59,4 +61,5 @@ def from_collect_config(cls, collect_config: CollectConfig): backend=CollectorBackend.WebScraper, config=WebScraperConfig() ) else: - raise ValueError(f"Unsupported collector backend: {collect_config.backend}") + raise ValueError( + f"Unsupported collector backend: {collect_config.backend}") diff --git a/querent/connectors/aws_connector.py b/querent/connectors/aws_connector.py new file mode 100644 index 00000000..22b29edb --- /dev/null +++ b/querent/connectors/aws_connector.py @@ -0,0 +1,70 @@ +import boto3 +from botocore.exceptions import NoCredentialsError, PartialCredentialsError, BotoCoreError + +def initialize_s3_resource(): + try: + s3_resource = boto3.resource('s3', + aws_access_key_id='YOUR_ACCESS_KEY', + aws_secret_access_key='YOUR_SECRET_KEY', + region_name='YOUR_REGION') + return s3_resource + except (PartialCredentialsError, BotoCoreError) as e: + print(f"Error initializing S3 resource: {e}") + return None + +def list_objects(bucket_name): + try: + s3_resource = initialize_s3_resource() + if s3_resource is None: + return [] + + bucket = s3_resource.Bucket(bucket_name) + objects = list(bucket.objects.all()) + return objects + except BotoCoreError as e: + print(f"Error listing objects: {e}") + return [] + +def download_file(bucket_name, s3_key, local_path): + try: + s3_resource = initialize_s3_resource() + if s3_resource is None: + return False + + s3_resource.Bucket(bucket_name).download_file(s3_key, local_path) + return True + except (BotoCoreError, NoCredentialsError) as e: + print(f"Error downloading file: {e}") + return False + +def upload_file(bucket_name, s3_key, local_path): + try: + s3_resource = initialize_s3_resource() + if s3_resource is None: + return False + + s3_resource.Bucket(bucket_name).upload_file(local_path, s3_key) + return True + except (BotoCoreError, NoCredentialsError) as e: + print(f"Error uploading file: {e}") + return False + +def delete_object(bucket_name, s3_key): + try: + s3_resource = initialize_s3_resource() + if s3_resource is None: + return False + + s3_resource.Object(bucket_name, s3_key).delete() + return True + except (BotoCoreError, NoCredentialsError) as e: + print(f"Error deleting object: {e}") + return False + +# Usage +bucket_name = 'your-s3-bucket-name' +object_list = list_objects(bucket_name) +for obj in object_list: + print(obj.key) + +# Remember to replace 'YOUR_ACCESS_KEY', 'YOUR_SECRET_KEY', and 'YOUR_REGION' with your actual credentials. 
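Note on querent/connectors/aws_connector.py above: each helper re-initializes the S3 resource, and the usage block at the bottom of the module runs as soon as the file is imported. A minimal sketch of one way to address both, caching a single resource and guarding the example listing behind __main__; this is a reviewer suggestion with placeholder credentials, not part of the patch:

    import boto3
    from botocore.exceptions import BotoCoreError, NoCredentialsError, PartialCredentialsError

    _s3_resource = None

    def get_s3_resource():
        # Create the boto3 resource once and reuse it across helper calls.
        global _s3_resource
        if _s3_resource is None:
            _s3_resource = boto3.resource(
                's3',
                aws_access_key_id='YOUR_ACCESS_KEY',
                aws_secret_access_key='YOUR_SECRET_KEY',
                region_name='YOUR_REGION')
        return _s3_resource

    def list_objects(bucket_name):
        try:
            return list(get_s3_resource().Bucket(bucket_name).objects.all())
        except (BotoCoreError, NoCredentialsError, PartialCredentialsError) as e:
            print(f"Error listing objects: {e}")
            return []

    if __name__ == "__main__":
        # Run the example listing only when the module is executed directly.
        for obj in list_objects('your-s3-bucket-name'):
            print(obj.key)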
diff --git a/querent/tools/web_page_extractor.py b/querent/tools/web_page_extractor.py index fbc01fa6..6e52405d 100644 --- a/querent/tools/web_page_extractor.py +++ b/querent/tools/web_page_extractor.py @@ -76,7 +76,8 @@ def extract_with_3k(self, url): article = Article(url, config=config) article.set_html(html_content) article.parse() - content = article.text.replace("\t", " ").replace("\n", " ").strip() + content = article.text.replace( + '\t', ' ').replace('\n', ' ').strip() return content[:1500] @@ -131,7 +132,8 @@ def extract_with_bs4(self, url): ["main", "article", "section", "div"] ) if main_content_areas: - main_content = max(main_content_areas, key=lambda x: len(x.text)) + main_content = max(main_content_areas, + key=lambda x: len(x.text)) content_tags = ["p", "h1", "h2", "h3", "h4", "h5", "h6"] content = " ".join( [ @@ -224,18 +226,21 @@ def extract_with_lxml(self, url): tree = html.fromstring(html_content) paragraphs = tree.cssselect("p, h1, h2, h3, h4, h5, h6") content = " ".join( - [para.text_content() for para in paragraphs if para.text_content()] + [para.text_content() + for para in paragraphs if para.text_content()] ) content = content.replace("\t", " ").replace("\n", " ").strip() return content except ArticleException as ae: - logger.error("Error while extracting text from HTML (lxml): {str(ae)}") + logger.error( + "Error while extracting text from HTML (lxml): {str(ae)}") return "" except RequestException as re: - logger.error(f"Error while making the request to the URL (lxml): {str(re)}") + logger.error( + f"Error while making the request to the URL (lxml): {str(re)}") return "" except Exception as e: diff --git a/requirements.txt b/requirements.txt index 1c725de2..50611946 100644 --- a/requirements.txt +++ b/requirements.txt @@ -155,4 +155,7 @@ duckduckgo-search==3.8.3 asyncio==3.4.3 aiofiles pytest-asyncio +google-cloud-storage +google-cloud +boto3 pymupdf diff --git a/tests/HP6 - Harry Potter and the Half-Blood Prince.pdf b/tests/HP6 - Harry Potter and the Half-Blood Prince.pdf new file mode 100644 index 00000000..4a2ebe35 Binary files /dev/null and b/tests/HP6 - Harry Potter and the Half-Blood Prince.pdf differ diff --git a/tests/test_local_collector.py b/tests/collector_tests/test_local_collector.py similarity index 100% rename from tests/test_local_collector.py rename to tests/collector_tests/test_local_collector.py diff --git a/tests/test_aws_collector.py b/tests/test_aws_collector.py new file mode 100644 index 00000000..80b3fee1 --- /dev/null +++ b/tests/test_aws_collector.py @@ -0,0 +1,69 @@ +import asyncio + +from querent.config.collector_config import S3CollectConfig +from querent.collectors.collector_resolver import CollectorResolver +from querent.collectors.aws.aws_collector import AWSCollectorFactory +from querent.common.uri import Uri +from querent.config.collector_config import CollectorBackend +import pytest +import os +from dotenv import load_dotenv + +load_dotenv() + + +aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') +aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') +aws_region = os.getenv('AWS_REGION') +aws_bucket_name = os.getenv('AWS_BUCKET_NAME') + + +@pytest.fixture +def aws_config(): + return { + "bucket": "pstreamsbucket1", + "region": "ap-south-1", + "access_key": aws_access_key_id, + "secret_key": aws_secret_access_key, + } + + +def test_aws_collector_factory(): + factory = AWSCollectorFactory() + assert factory.backend() == CollectorBackend.S3 + +# Modify this function to test the AWS collector + + +@pytest.mark.asyncio +async 
def test_aws_collector(aws_config):
+    config = S3CollectConfig(**aws_config)
+    uri = Uri("s3://" + aws_config["bucket"] + "/prefix/")
+    resolver = CollectorResolver()
+    collector = resolver.resolve(uri, config)
+    assert collector is not None
+
+    await collector.connect()
+
+    async def poll_and_print():
+        async for result in collector.poll():
+            assert not result.is_error()
+            chunk = result.unwrap()
+            assert chunk is not None
+
+    await poll_and_print()
+
+    # Modify this function to add files to S3 bucket
+    # async def add_files():
+    #     # Add files to your S3 bucket here
+    #     pass
+
+# async def main():
+#     await asyncio.gather(poll_and_print())
+
+# # asyncio.run(main())
+
+
+# if __name__ == "__main__":
+#     # Modify this line to call the appropriate test function
+#     pytest.main(["-k", "test_aws_collector"])
diff --git a/tests/test_gcs_collector.py b/tests/test_gcs_collector.py
new file mode 100644
index 00000000..777b61da
--- /dev/null
+++ b/tests/test_gcs_collector.py
@@ -0,0 +1,63 @@
+import asyncio
+from querent.collectors.collector_resolver import CollectorResolver
+from querent.collectors.gcs.gcs_collector import GCSCollectorFactory
+from querent.common.uri import Uri
+from querent.config.collector_config import CollectorBackend
+import pytest
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+@pytest.fixture
+def gcs_config():
+    credentials_info = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
+    bucket_name = os.getenv("GOOGLE_BUCKET_NAME")
+    return {
+        "bucket": "www.billbo.ai",
+        "credentials": credentials_info
+    }
+
+
+def test_gcs_collector_factory():
+    factory = GCSCollectorFactory()
+    assert factory.backend() == CollectorBackend.Gcs
+
+# Modify this function to test the GCS collector
+
+# To do: uncomment the following code when you have the bucket name and the credentials.json file for testing.
+
+
+@pytest.mark.asyncio
+async def test_gcs_collector(gcs_config):
+    config = gcs_config
+    uri = Uri("gcs://" + config["bucket"] + "/prefix/")
+    resolver = CollectorResolver()
+    collector = resolver.resolve(uri, config)
+    assert collector is not None
+
+    await collector.connect()
+
+    async def poll_and_print():
+        async for result in collector.poll():
+            assert not result.is_error()
+            chunk = result.unwrap()
+            assert chunk is not None
+
+    await poll_and_print()
+    # Modify this function to add files to your GCS bucket
+
+    async def add_files():
+        # Add files to your GCS bucket here
+        pass
+
+    async def main():
+        await asyncio.gather(add_files(), poll_and_print())
+
+    # asyncio.run(main())
+
+
+if __name__ == "__main__":
+    # Modify this line to call the appropriate test function
+    pytest.main(["-k", "test_gcs_collector"])
diff --git a/tests/test_webscrapper.py b/tests/test_webscrapper.py
index 6415bb9b..52eeb264 100644
--- a/tests/test_webscrapper.py
+++ b/tests/test_webscrapper.py
@@ -29,8 +29,12 @@ def test_scrapping_data():
     collector = resolver.resolve(uri, webscrapperConfig)
    assert collector is not None
 
+    print("Reached here")
+
    async def poll_and_print():
+        print("Part 2")
        async for result in collector.poll():
+            print("Hola...")
            assert not result.is_error()
            print(result.unwrap())
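A closing note on both new collectors (aws_collector.py and gcs_collector.py above): list_objects_v2, download_fileobj and the GCS blob download are synchronous SDK calls issued inside async poll(), so the event loop is blocked while each object is listed or downloaded. A minimal sketch of one way to offload that work, assuming Python 3.9+ for asyncio.to_thread and reusing the method names from aws_collector.py; a suggestion only, not part of this patch:

    import asyncio

    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
        # Run the blocking boto3 calls in a worker thread so poll() stays responsive.
        response = await asyncio.to_thread(
            self.s3_client.list_objects_v2, Bucket=self.bucket_name)

        for obj in response.get('Contents', []):
            file = await asyncio.to_thread(
                self.download_object_as_byte_stream, obj['Key'])
            async for chunk in self.read_chunks(file):
                yield CollectorResult({"object_key": obj['Key'], "chunk": chunk})

The same pattern applies to the GCS collector's bucket listing and blob download.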