Skip to content

Commit

Permalink
add basic SQS consumer functionality (#33)
Browse files Browse the repository at this point in the history
* add SQS receive support

* lint

* pr feedback

* import httpx -> from httpx import ...

* _AWSv4Auth -> AWSv4Auth

* minimize diff

* minimize diff

* minimize diff

* lint

* try to minimize further

* fixes & more tests

* move AWSV4AuthFlow up

* aws_access_key_id -> aws_access_key

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
adriangb and samuelcolvin authored Jan 10, 2023
1 parent 64f8cfe commit 1e2455f
Show file tree
Hide file tree
Showing 3 changed files with 543 additions and 16 deletions.
79 changes: 63 additions & 16 deletions aioaws/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from binascii import hexlify
from datetime import datetime
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Optional, Tuple
from urllib.parse import quote as url_quote

from httpx import URL, AsyncClient, Response
from httpx import URL, AsyncClient, Auth, Request, Response

from ._utils import get_config_attr, pretty_xml, utcnow

Expand Down Expand Up @@ -50,6 +50,13 @@ def __init__(self, client: AsyncClient, config: 'BaseConfigProtocol', service: L
self.host = f'{bucket}.s3.{self.region}.amazonaws.com'
self.schema = 'https'

self._auth = AWSv4Auth(
aws_secret_key=self.aws_secret_key,
aws_access_key=self.aws_access_key,
region=self.region,
service=self.service,
)

@property
def endpoint(self) -> str:
return f'{self.schema}://{self.host}'
Expand Down Expand Up @@ -98,7 +105,7 @@ async def request(
method,
url,
content=data,
headers=self._auth_headers(method, url, data=data, content_type=content_type),
headers=self._auth.auth_headers(method, url, data=data, content_type=content_type),
)
if r.status_code != 200:
# from ._utils import pretty_response
Expand All @@ -113,31 +120,45 @@ def add_signed_download_params(self, method: Literal['GET', 'POST'], url: URL, e
url = url.copy_merge_params(
{
'X-Amz-Algorithm': _AUTH_ALGORITHM,
'X-Amz-Credential': self._aws4_credential(now),
'X-Amz-Credential': self._auth.aws4_credential(now),
'X-Amz-Date': _aws4_x_amz_date(now),
'X-Amz-Expires': str(expires),
'X-Amz-SignedHeaders': 'host',
}
)
_, signature = self._aws4_signature(now, method, url, {'host': self.host}, 'UNSIGNED-PAYLOAD')
_, signature = self._auth.aws4_signature(now, method, url, {'host': self.host}, 'UNSIGNED-PAYLOAD')
return url.copy_add_param('X-Amz-Signature', signature)

def upload_extra_conditions(self, dt: datetime) -> List[Dict[str, str]]:
return [
{'x-amz-credential': self._aws4_credential(dt)},
{'x-amz-credential': self._auth.aws4_credential(dt)},
{'x-amz-algorithm': _AUTH_ALGORITHM},
{'x-amz-date': _aws4_x_amz_date(dt)},
]

def signed_upload_fields(self, dt: datetime, string_to_sign: str) -> Dict[str, str]:
return {
'X-Amz-Algorithm': _AUTH_ALGORITHM,
'X-Amz-Credential': self._aws4_credential(dt),
'X-Amz-Credential': self._auth.aws4_credential(dt),
'X-Amz-Date': _aws4_x_amz_date(dt),
'X-Amz-Signature': self._aws4_sign_string(string_to_sign, dt),
'X-Amz-Signature': self._auth.aws4_sign_string(string_to_sign, dt),
}

def _auth_headers(

class AWSv4Auth:
def __init__(
self,
aws_secret_key: str,
aws_access_key: str,
region: str,
service: str,
) -> None:
self.aws_secret_key = aws_secret_key
self.aws_access_key = aws_access_key
self.region = region
self.service = service

def auth_headers(
self,
method: Literal['GET', 'POST'],
url: URL,
Expand All @@ -153,20 +174,20 @@ def _auth_headers(
headers = {
'content-md5': base64.b64encode(hashlib.md5(data).digest()).decode(),
'content-type': content_type,
'host': self.host,
'host': url.host,
'x-amz-date': _aws4_x_amz_date(now),
}

payload_sha256_hash = hashlib.sha256(data).hexdigest()
signed_headers, signature = self._aws4_signature(now, method, url, headers, payload_sha256_hash)
credential = self._aws4_credential(now)
signed_headers, signature = self.aws4_signature(now, method, url, headers, payload_sha256_hash)
credential = self.aws4_credential(now)
authorization_header = (
f'{_AUTH_ALGORITHM} Credential={credential},SignedHeaders={signed_headers},Signature={signature}'
)
headers.update({'authorization': authorization_header, 'x-amz-content-sha256': payload_sha256_hash})
return headers

def _aws4_signature(
def aws4_signature(
self, dt: datetime, method: Literal['GET', 'POST'], url: URL, headers: Dict[str, str], payload_hash: str
) -> Tuple[str, str]:
header_keys = sorted(headers)
Expand All @@ -187,9 +208,9 @@ def _aws4_signature(
hashlib.sha256(canonical_request.encode()).hexdigest(),
)
string_to_sign = '\n'.join(string_to_sign_parts)
return signed_headers, self._aws4_sign_string(string_to_sign, dt)
return signed_headers, self.aws4_sign_string(string_to_sign, dt)

def _aws4_sign_string(self, string_to_sign: str, dt: datetime) -> str:
def aws4_sign_string(self, string_to_sign: str, dt: datetime) -> str:
key_parts = (
b'AWS4' + self.aws_secret_key.encode(),
_aws4_date_stamp(dt),
Expand All @@ -204,10 +225,36 @@ def _aws4_sign_string(self, string_to_sign: str, dt: datetime) -> str:
def _aws4_scope(self, dt: datetime) -> str:
return f'{_aws4_date_stamp(dt)}/{self.region}/{self.service}/{_AWS_AUTH_REQUEST}'

def _aws4_credential(self, dt: datetime) -> str:
def aws4_credential(self, dt: datetime) -> str:
return f'{self.aws_access_key}/{self._aws4_scope(dt)}'


class AWSV4AuthFlow(Auth):
def __init__(
self,
aws_secret_key: str,
aws_access_key: str,
region: str,
service: str,
) -> None:
self._authorizer = AWSv4Auth(
aws_secret_key=aws_secret_key,
aws_access_key=aws_access_key,
region=region,
service=service,
)

def auth_flow(self, request: Request) -> Generator[Request, Response, None]:
auth_headers = self._authorizer.auth_headers(
method=request.method.upper(), # type: ignore
url=request.url,
data=request.content,
content_type=request.headers.get('Content-Type'),
)
request.headers.update(auth_headers)
yield request


def _aws4_date_stamp(dt: datetime) -> str:
return dt.strftime('%Y%m%d')

Expand Down
180 changes: 180 additions & 0 deletions aioaws/sqs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Iterable, Mapping, Optional, Union

from httpx import AsyncClient, Timeout
from pydantic import BaseModel, Field

from .core import AWSV4AuthFlow


class AWSAuthConfig(BaseModel):
aws_access_key: str
aws_secret_key: str
aws_region: str


class SQSMessage(BaseModel):
message_id: str
receipt_handle: str
md5_of_body: str
body: str
attributes: Mapping[str, Any]


class PollConfig(BaseModel):
wait_time: int = Field(10, gt=0)
max_messages: int = Field(1, ge=1, le=10)


@dataclass
class _QueueName:
name: str


@dataclass
class _QueueURL:
url: str


MAX_VISIBILITY_TIMEOUT = 12 * 60 * 60 # 12 hours in seconds


class SQSClient:
def __init__(
self,
queue_name_or_url: str,
auth: AWSAuthConfig,
*,
client: AsyncClient,
) -> None:
self._queue_name_or_url: Union[_QueueName, _QueueURL]
if queue_name_or_url[:4] == 'http':
self._queue_name_or_url = _QueueURL(queue_name_or_url)
else:
self._queue_name_or_url = _QueueName(queue_name_or_url)
self._client = client
self._auth = AWSV4AuthFlow(
aws_access_key=auth.aws_access_key,
aws_secret_key=auth.aws_secret_key,
region=auth.aws_region,
service='sqs',
)
self._service_url = f'https://sqs.{auth.aws_region}.amazonaws.com'

async def _get_queue_url_from_name_and_region(
self,
queue_name: str,
client: AsyncClient,
auth: AWSV4AuthFlow,
) -> str:
resp = await client.get(
url=self._service_url,
params={
'Action': 'GetQueueUrl',
'QueueName': queue_name,
},
auth=auth,
headers={'Accept': 'application/json'},
)
resp.raise_for_status()
return resp.json()['GetQueueUrlResponse']['GetQueueUrlResult']['QueueUrl']

async def _get_queue_url(self) -> str:
if isinstance(self._queue_name_or_url, _QueueName):
self._queue_name_or_url = _QueueURL(
await self._get_queue_url_from_name_and_region(
self._queue_name_or_url.name,
self._client,
auth=self._auth,
)
)
return self._queue_name_or_url.url

async def poll(
self,
*,
config: Optional[PollConfig] = None,
) -> AsyncIterator[Iterable[SQSMessage]]:
config = config or PollConfig()
queue_url = await self._get_queue_url()
while True:
resp = await self._client.get(
url=queue_url,
params={
'Action': 'ReceiveMessage',
'MaxNumberOfMessages': config.max_messages,
'WaitTimeSeconds': config.wait_time,
},
headers={
'Accept': 'application/json',
},
timeout=Timeout(
5, # htppx's default timeout
# arbitrary selection of 1.5x wait time
# to avoid http timeouts while long polling
read=config.wait_time * 1.5,
),
auth=self._auth,
)
resp.raise_for_status()
yield [
SQSMessage.construct(
message_id=message_data['MessageId'],
receipt_handle=message_data['ReceiptHandle'],
md5_of_body=message_data['MD5OfBody'],
body=message_data['Body'],
attributes=message_data['Attributes'],
)
for message_data in resp.json()['ReceiveMessageResponse']['ReceiveMessageResult']['messages'] or ()
]

async def change_visibility(self, message: SQSMessage, timeout: int) -> None:
queue_url = await self._get_queue_url()
if timeout >= MAX_VISIBILITY_TIMEOUT:
raise ValueError(f'timeout value range is 0 to {MAX_VISIBILITY_TIMEOUT}, got {timeout}')
await self._client.post(
url=queue_url,
params={
'Action': 'ChangeMessageVisibility',
'VisibilityTimeout': timeout,
'ReceiptHandle': message.receipt_handle,
},
auth=self._auth,
headers={
'Accept': 'application/json',
},
)

async def delete_message(self, message: SQSMessage) -> None:
queue_url = await self._get_queue_url()
resp = await self._client.post(
url=queue_url,
params={
'Action': 'DeleteMessage',
'ReceiptHandle': message.receipt_handle,
},
auth=self._auth,
headers={
'Accept': 'application/json',
},
)
resp.raise_for_status()


@asynccontextmanager
async def create_sqs_client(
queue: str,
auth: AWSAuthConfig,
*,
client: Optional[AsyncClient] = None,
) -> AsyncIterator[SQSClient]:
async with AsyncExitStack() as stack:
if client is None:
client = await stack.enter_async_context(AsyncClient())
assert client is not None # for mypy
yield SQSClient(
queue_name_or_url=queue,
auth=auth,
client=client,
)
Loading

0 comments on commit 1e2455f

Please sign in to comment.