Commit

Setting up the env variables for AWS and GCS credentials
the-non-expert committed Sep 2, 2023
1 parent 452f0da commit ef308a0
Showing 7 changed files with 93 additions and 135 deletions.
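Both collectors and their tests now read credentials through python-dotenv, so the commit assumes a local .env file (kept out of the repository) that defines the variables referenced in the diffs below. A minimal sketch with placeholder values; the file name and its location in the working directory are the python-dotenv defaults, not something this commit specifies:

# .env (placeholder values only; load_dotenv() picks this file up from the working directory)
AWS_ACCESS_KEY_ID=your-access-key-id
AWS_SECRET_ACCESS_KEY=your-secret-access-key
AWS_REGION=your-aws-region
AWS_BUCKET_NAME=your-bucket-name
GOOGLE_APPLICATION_CREDENTIALS=path/to/service-account.json
GOOGLE_BUCKET_NAME=your-gcs-bucket-name

The hard-coded bucket names and keys visible in the removed lines below are exactly what these lookups replace.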
94 changes: 15 additions & 79 deletions querent/collectors/aws/aws_collector.py
@@ -1,77 +1,3 @@
# import asyncio
# from typing import AsyncGenerator

# import aiofiles
# from querent.config.collector_config import CollectorBackend, S3CollectConfig
# from querent.collectors.collector_base import Collector
# from querent.collectors.collector_factory import CollectorFactory
# from querent.collectors.collector_result import CollectorResult
# from querent.common.uri import Uri
# import aiohttp
# import aiobotocore
# from aiobotocore.session import get_session


# class AWSCollector(Collector):
#     def __init__(self, config: S3CollectConfig):
#         self.bucket_name = config.bucket
#         self.region = config.region
#         self.access_key = config.access_key
#         self.secret_key = config.secret_key
#         self.chunk_size = config.chunk

#     # async def connect(self):
#     #     # session = aiobotocore.get_session()
#     #     # self.s3_client = session.create_client(
#     #     #     's3', region_name=self.region, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key)
#     #     session = aiohttp.ClientSession()
#     #     s3_client = aiobotocore.get_session().create_client(
#     #         's3', region_name='self.region')
#     #     self.s3_client = s3_client

#     async def connect(self):
#         # session = aiohttp.ClientSession()
#         session = get_session()
#         async with session:
#             async with session.create_client(
#                     's3', region_name=self.region,
#                     aws_secret_access_key=self.secret_key,
#                     aws_access_key_id=self.access_key) as s3_client:
#                 self.s3_client = s3_client

#     async def disconnect(self):
#         await self.session.close()
#         await self.s3_client.close()

#     async def poll(self) -> AsyncGenerator[CollectorResult, None]:
#         async with self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=self.prefix) as response:
#             for obj in response.get('Contents', []):
#                 async with self.download_object(obj['Key']) as file:
#                     async for chunk in self.read_chunks(file):
#                         yield CollectorResult({"object_key": obj['Key'], "chunk": chunk})

#     async def read_chunks(self, file):
#         while True:
#             chunk = await file.read(self.chunk_size)
#             if not chunk:
#                 break
#             yield chunk

#     async def download_object(self, object_key):
#         async with aiofiles.open(object_key, 'wb') as file:
#             await self.s3_client.download_fileobj(self.bucket_name, object_key, file)


# class AWSCollectorFactory(CollectorFactory):
#     def backend(self) -> CollectorBackend:
#         return CollectorBackend.S3

#     def resolve(self, uri: Uri, config: S3CollectConfig) -> Collector:
#         config = S3CollectConfig(bucket='your_bucket_name', region='your_aws_region',
#                                  access_key='your_access_key', secret_key='your_secret_key')
#         return AWSCollector(config)


import asyncio
from typing import AsyncGenerator

@@ -82,10 +8,19 @@
from querent.collectors.collector_result import CollectorResult
from querent.common.uri import Uri
import boto3
import os
from dotenv import load_dotenv

load_dotenv()

aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_region = os.getenv('AWS_REGION')
aws_bucket_name = os.getenv('AWS_BUCKET_NAME')


class AWSCollector(Collector):
    def __init__(self, config: S3CollectConfig):
    def __init__(self, config: S3CollectConfig, prefix: str):
        self.bucket_name = config.bucket
        self.region = config.region
        self.access_key = config.access_key
@@ -97,6 +32,7 @@ def __init__(self, config: S3CollectConfig):
            aws_secret_access_key=self.secret_key,
            region_name=self.region
        )
        self.prefix = prefix

    async def connect(self):
        pass  # No asynchronous connection needed for boto3
@@ -115,7 +51,7 @@ async def poll(self) -> AsyncGenerator[CollectorResult, None]:

    async def read_chunks(self, file):
        while True:
            chunk = await file.read(self.chunk_size)
            chunk = file.read(self.chunk_size)
            if not chunk:
                break
            yield chunk
@@ -132,6 +68,6 @@ def backend(self) -> CollectorBackend:
        return CollectorBackend.S3

    def resolve(self, uri: Uri, config: S3CollectConfig) -> Collector:
        config = S3CollectConfig(bucket='pstreamsbucket1', region='ap-south-1',
                                 access_key='AKIA5ZFZH6CA6LDWIPV5', secret_key='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')
        return AWSCollector(config)
        config = S3CollectConfig(bucket=aws_bucket_name, region=aws_region,
                                 access_key=aws_access_key_id, secret_key=aws_secret_access_key)
        return AWSCollector(config, "")
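As orientation rather than part of the commit, here is a short sketch of exercising the env-driven factory above end to end. The s3:// URI form mirrors the gcs:// form used in the tests further down, and passing None for the config relies on resolve() rebuilding the config from the environment; both are assumptions of this sketch, as is running it at module level with asyncio.run:

import asyncio
import os

from dotenv import load_dotenv
from querent.collectors.aws.aws_collector import AWSCollectorFactory
from querent.common.uri import Uri

load_dotenv()


async def main():
    factory = AWSCollectorFactory()
    # resolve() rebuilds S3CollectConfig from the AWS_* env vars itself, so the
    # config argument is effectively unused here (an assumption of this sketch).
    collector = factory.resolve(Uri("s3://" + os.getenv("AWS_BUCKET_NAME") + "/"), None)
    await collector.connect()
    async for result in collector.poll():
        print(result.unwrap())  # each result wraps one chunk of an S3 object


asyncio.run(main())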
43 changes: 31 additions & 12 deletions querent/collectors/gcs/gcs_collector.py
@@ -1,3 +1,5 @@
#

import asyncio
from typing import AsyncGenerator

@@ -10,29 +12,45 @@
from querent.common.uri import Uri
import aiohttp
from google.cloud import storage
import os
from dotenv import load_dotenv

load_dotenv()

credentials_info = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
bucket_name = os.getenv("GOOGLE_BUCKET_NAME")


class GCSCollector(Collector):
    def __init__(self, config: GcsCollectConfig):
        self.bucket_name = config.bucket
        self.credentials = config.credentials_path
        self.chunk_size = config.chunk
        self.credentials = config.credentials
        self.client = None

    async def connect(self):
        self.client = storage.Client.from_service_account_json(
            self.credentials)
        self.bucket = self.client.get_bucket(self.bucket_name)
        if not self.client:
            self.client = storage.Client.from_service_account_json(
                self.credentials)

    async def disconnect(self):
        if self.client is not None:
            self.client.close()
            self.client = None

    async def poll(self) -> AsyncGenerator[CollectorResult, None]:
        blobs = self.bucket.list_blobs()
        for blob in blobs:
            async with self.download_blob(blob) as file:
                async for chunk in self.read_chunks(file):
                    yield CollectorResult({"object_key": blob.name, "chunk": chunk})
        # Make sure to connect the client before using it
        await self.connect()

        try:
            bucket = self.client.get_bucket(self.bucket_name)
            blobs = bucket.list_blobs()
            for blob in blobs:
                async with self.download_blob(blob) as file:
                    async for chunk in self.read_chunks(file):
                        yield CollectorResult({"object_key": blob.name, "chunk": chunk})
        finally:
            # Disconnect the client when done
            await self.disconnect()

    async def read_chunks(self, file):
        while True:
@@ -43,7 +61,6 @@ async def read_chunks(self, file):

    async def download_blob(self, blob):
        file = aiofiles.open(blob.name, 'wb')
        # Manually enter the context manager since aiofiles doesn't natively support async context management
        await file.__aenter__()
        with blob.open("rb") as blob_file:
            await file.write(await blob_file.read())
@@ -56,5 +73,7 @@ def backend(self) -> CollectorBackend:

    def resolve(self, uri: Uri, config: GcsCollectConfig) -> Collector:
        config = GcsCollectConfig(
            bucket='your_bucket_name', credentials_path='path_to_your_credentials.json')
            bucket=bucket_name,
            credentials=credentials_info
        )
        return GCSCollector(config)
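Likewise, a sketch (not in the diff) of driving the GCS factory directly. The gcs:// URI form is taken from the test file below; the module path and the None config argument (resolve() rebuilds the config from the GOOGLE_* variables) are assumptions:

import asyncio
import os

from dotenv import load_dotenv
from querent.collectors.gcs.gcs_collector import GCSCollectorFactory
from querent.common.uri import Uri

load_dotenv()


async def main():
    factory = GCSCollectorFactory()
    # resolve() fills GcsCollectConfig from GOOGLE_BUCKET_NAME and
    # GOOGLE_APPLICATION_CREDENTIALS, so the config argument is only a placeholder here.
    collector = factory.resolve(Uri("gcs://" + os.getenv("GOOGLE_BUCKET_NAME") + "/prefix/"), None)
    async for result in collector.poll():  # poll() connects and disconnects internally
        print(result.unwrap())


asyncio.run(main())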
7 changes: 4 additions & 3 deletions querent/config/collector_config.py
@@ -32,15 +32,16 @@ class S3CollectConfig(BaseModel):

class GcsCollectConfig(BaseModel):
    bucket: str
    region: str
    access_key: str
    secret_key: str
    credentials: str
    # chunk: int = 1024


class WebScraperConfig(BaseModel):
    website_url: str = Field(
        ..., description="The URL of the website to scrape."
    )


class CollectConfigWrapper(BaseModel):
    backend: CollectorBackend
    config: Optional[BaseModel] = None
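For reference, a small sketch (not in the commit) of filling the two pydantic models from the same environment variables; the S3CollectConfig field names are taken from the AWS collector above and the GcsCollectConfig fields from this diff:

import os

from dotenv import load_dotenv
from querent.config.collector_config import GcsCollectConfig, S3CollectConfig

load_dotenv()

s3_config = S3CollectConfig(
    bucket=os.getenv("AWS_BUCKET_NAME"),
    region=os.getenv("AWS_REGION"),
    access_key=os.getenv("AWS_ACCESS_KEY_ID"),
    secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

gcs_config = GcsCollectConfig(
    bucket=os.getenv("GOOGLE_BUCKET_NAME"),
    credentials=os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
)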
13 changes: 0 additions & 13 deletions querent/config/protean-tooling-368008-8b160be0bb98.json

This file was deleted.

Binary file not shown.
18 changes: 14 additions & 4 deletions tests/collector_tests/test_aws_collector.py
@@ -4,15 +4,23 @@
from querent.common.uri import Uri
from querent.config.collector_config import CollectorBackend
import pytest
import os
from dotenv import load_dotenv

load_dotenv()


@pytest.fixture
def aws_config():
    aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
    aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
    aws_region = os.getenv('AWS_REGION')
    aws_bucket_name = os.getenv('AWS_BUCKET_NAME')
    return {
        "bucket": "pstreamsbucket1",
        "region": "ap-south-1",
        "access_key": "AKIA5ZFZH6CA6LDWIPV5",
        "secret_key": "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",
        "bucket": aws_bucket_name,
        "region": aws_region,
        "access_key": aws_access_key_id,
        "secret_key": aws_secret_access_key,
    }


@@ -38,6 +46,8 @@ async def poll_and_print():
            chunk = result.unwrap()
            assert chunk is not None

    await poll_and_print()

    # Modify this function to add files to S3 bucket
    async def add_files():
        # Add files to your S3 bucket here
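The add_files() placeholder above is left empty by the commit. One hedged sketch of what it could do, calling boto3 directly; the object key and body are invented for illustration:

import os

import boto3


async def add_files():
    s3 = boto3.client(
        "s3",
        aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
        region_name=os.getenv("AWS_REGION"),
    )
    # Upload a small test object for the collector to pick up.
    s3.put_object(Bucket=os.getenv("AWS_BUCKET_NAME"), Key="prefix/sample.txt", Body=b"hello from the test")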
53 changes: 29 additions & 24 deletions tests/collector_tests/test_gcs_collector.py
@@ -5,13 +5,15 @@
from querent.config.collector_config import CollectorBackend
import pytest
import os
from dotenv import load_dotenv

credentials_info = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
bucket_name = os.environ.get("GOOGLE_BUCKET_NAME")
load_dotenv()


@pytest.fixture
def gcs_config():
    credentials_info = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
    bucket_name = os.getenv("GOOGLE_BUCKET_NAME")
    return {
        "bucket": bucket_name,
        "credentials_path": credentials_info
@@ -27,32 +29,35 @@ def test_gcs_collector_factory():
# To do: uncomment the following code when you have the bucket name and the credentials.json file for testing.


# @pytest.mark.asyncio
# async def test_gcs_collector(gcs_config):
#     uri = Uri("gcs://" + gcs_config["bucket"] + "/prefix/")
#     resolver = CollectorResolver()
#     collector = resolver.resolve(uri, gcs_config)
#     assert collector is not None
@pytest.mark.asyncio
async def test_gcs_collector(gcs_config):
    config = gcs_config
    uri = Uri("gcs://" + config["bucket"] + "/prefix/")
    resolver = CollectorResolver()
    collector = resolver.resolve(uri, config)
    assert collector is not None

#     await collector.connect()
    await collector.connect()

#     async def poll_and_print():
#         async for result in collector.poll():
#             assert not result.is_error()
#             chunk = result.unwrap()
#             assert chunk is not None
    async def poll_and_print():
        async for result in collector.poll():
            assert not result.is_error()
            chunk = result.unwrap()
            assert chunk is not None

#     # Modify this function to add files to your GCS bucket
#     async def add_files():
#         # Add files to your GCS bucket here
#         pass
    await poll_and_print()
    # Modify this function to add files to your GCS bucket

#     async def main():
#         await asyncio.gather(add_files(), poll_and_print())
    async def add_files():
        # Add files to your GCS bucket here
        pass

#     # asyncio.run(main())
    async def main():
        await asyncio.gather(add_files(), poll_and_print())

    # asyncio.run(main())

# if __name__ == "__main__":
#     # Modify this line to call the appropriate test function
#     pytest.main(["-k", "test_gcs_collector"])

if __name__ == "__main__":
    # Modify this line to call the appropriate test function
    pytest.main(["-k", "test_gcs_collector"])
