From 64f8cfe551b1fa915ba029dcf2b522d9b496f4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=AC=A2=20Samuel=20Colvin?= Date: Tue, 10 Jan 2023 15:46:01 +0000 Subject: [PATCH] allow a custom host/endpoint for S3 (#36) * allow a custom host/endpoint for S3 * make endpoint a property * linting * coverage and mypy * more tests * linting --- aioaws/_types.py | 2 ++ aioaws/_utils.py | 18 +++++++++++++----- aioaws/core.py | 38 ++++++++++++++++++++++++-------------- aioaws/s3.py | 6 ++++-- tests/dummy_server.py | 5 +++++ tests/test_s3.py | 3 ++- tests/test_utils.py | 22 +++++++++++++++++++++- 7 files changed, 71 insertions(+), 23 deletions(-) diff --git a/aioaws/_types.py b/aioaws/_types.py index 1c19e27..cefdcfa 100644 --- a/aioaws/_types.py +++ b/aioaws/_types.py @@ -5,6 +5,8 @@ class BaseConfigProtocol(Protocol): aws_access_key: str aws_secret_key: str aws_region: str + # aws_host is optional and will be inferred if omitted + # aws_host: str class S3ConfigProtocol(Protocol): diff --git a/aioaws/_utils.py b/aioaws/_utils.py index 7a266b5..691fda5 100644 --- a/aioaws/_utils.py +++ b/aioaws/_utils.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from ._types import BaseConfigProtocol -__all__ = 'get_config_attr', 'to_unix_s', 'utcnow', 'ManyTasks', 'pretty_response' +__all__ = 'get_config_attr', 'to_unix_s', 'utcnow', 'ManyTasks', 'pretty_xml', 'pretty_response' EPOCH = datetime(1970, 1, 1) EPOCH_TZ = EPOCH.replace(tzinfo=timezone.utc) @@ -55,16 +55,24 @@ async def finish(self) -> Iterable[Any]: return await asyncio.gather(*self._tasks) -def pretty_response(r: Response) -> None: # pragma: no cover - from xml.etree import ElementTree +def pretty_xml(response_xml: bytes) -> str: + import xml.dom.minidom + + try: + pretty = xml.dom.minidom.parseString(response_xml).toprettyxml(indent=' ') + except Exception: # pragma: no cover + return response_xml.decode() + else: + return f'{pretty} (XML formatted by aioaws)' + +def pretty_response(r: Response) -> None: # pragma: no cover from devtools import debug - xml_root = ElementTree.fromstring(r.content) debug( status=r.status_code, url=str(r.url), headers=dict(r.request.headers), history=r.history, - xml={el.tag: el.text for el in xml_root}, + xml=pretty_xml(r.content), ) diff --git a/aioaws/core.py b/aioaws/core.py index c03dbb1..2dd736b 100644 --- a/aioaws/core.py +++ b/aioaws/core.py @@ -10,7 +10,7 @@ from httpx import URL, AsyncClient, Response -from ._utils import get_config_attr, utcnow +from ._utils import get_config_attr, pretty_xml, utcnow if TYPE_CHECKING: from ._types import BaseConfigProtocol @@ -34,19 +34,25 @@ def __init__(self, client: AsyncClient, config: 'BaseConfigProtocol', service: L self.aws_secret_key = get_config_attr(config, 'aws_secret_key') self.service = service self.region = get_config_attr(config, 'aws_region') - if self.service == 'ses': - self.host = f'email.{self.region}.amazonaws.com' - else: - assert self.service == 's3', self.service - bucket = get_config_attr(config, 'aws_s3_bucket') - if '.' in bucket: - # assumes the bucket is a domain and is already as a CNAME record for S3 - self.host = bucket + try: + self.host = get_config_attr(config, 'aws_host') + except TypeError: + if self.service == 'ses': + self.host = f'email.{self.region}.amazonaws.com' else: - # see https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html - self.host = f'{bucket}.s3.{self.region}.amazonaws.com' - - self.endpoint = f'https://{self.host}' + assert self.service == 's3', self.service + bucket = get_config_attr(config, 'aws_s3_bucket') + if '.' in bucket: + # assumes the bucket is a domain and is already as a CNAME record for S3 + self.host = bucket + else: + # see https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html + self.host = f'{bucket}.s3.{self.region}.amazonaws.com' + self.schema = 'https' + + @property + def endpoint(self) -> str: + return f'{self.schema}://{self.host}' async def get(self, path: str = '', *, params: Optional[Dict[str, Any]] = None) -> Response: return await self.request('GET', path=path, params=params) @@ -222,4 +228,8 @@ def __init__(self, r: Response): self.status = r.status_code def __str__(self) -> str: - return f'{self.args[0]}, response:\n{self.response.text}' + if self.response.headers.get('content-type') == 'application/xml': + text = pretty_xml(self.response.content) + else: + text = self.response.text + return f'{self.args[0]}, response:\n{text}' diff --git a/aioaws/s3.py b/aioaws/s3.py index e4a7ff4..cba8355 100644 --- a/aioaws/s3.py +++ b/aioaws/s3.py @@ -11,7 +11,7 @@ from httpx import URL, AsyncClient from pydantic import BaseModel, validator -from ._utils import ManyTasks, utcnow +from ._utils import ManyTasks, pretty_xml, utcnow from .core import AwsClient if TYPE_CHECKING: @@ -32,6 +32,8 @@ class S3Config: aws_secret_key: str aws_region: str aws_s3_bucket: str + # custom host to connect with + aws_host: Optional[str] = None class S3File(BaseModel): @@ -80,7 +82,7 @@ async def list(self, prefix: Optional[str] = None) -> AsyncIterable[S3File]: if (t := xml_root.find('NextContinuationToken')) is not None: continuation_token = t.text else: - raise RuntimeError(f'unexpected response from S3: {r.text!r}') + raise RuntimeError(f'unexpected response from S3:\n{pretty_xml(r.content)}') async def delete(self, *files: Union[str, S3File]) -> List[str]: """ diff --git a/tests/dummy_server.py b/tests/dummy_server.py index 660c6ae..7856c35 100644 --- a/tests/dummy_server.py +++ b/tests/dummy_server.py @@ -124,9 +124,14 @@ async def aws_certs(request): return Response(body=aws_certs_body, content_type='content/unknown') +async def xml_error(request): + return Response(body=s3_list_response_template, content_type='application/xml', status=456) + + routes = [ web.route('*', '/s3/', s3_root), web.get('/s3_demo_image_url/{image:.*}', s3_demo_image), web.post('/ses/', ses_send), web.get('/sns/certs/', aws_certs), + web.get('/xml-error/', xml_error), ] diff --git a/tests/test_s3.py b/tests/test_s3.py index 2086201..d6115ce 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -18,7 +18,8 @@ def test_upload_url_after_overriding_aws_client_endpoint(mocker): mocker.patch('aioaws.s3.utcnow', return_value=datetime(2032, 1, 1)) s3 = S3Client('-', S3Config('testing', 'testing', 'testing', 'testing.com')) - s3._aws_client.endpoint = 'http://localhost:4766' + s3._aws_client.host = 'localhost:4766' + s3._aws_client.schema = 'http' d = s3.signed_upload_url( path='testing/', filename='test.png', content_type='image/png', size=123, expires=datetime(2032, 1, 1) ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 85a04bf..00732c7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import pytest +from httpx import AsyncClient -from aioaws import _types, _utils +from aioaws import _types, _utils, core def test_get_config_attr(): @@ -20,3 +21,22 @@ class Foo: def test_types(): assert hasattr(_types, 'BaseConfigProtocol') assert hasattr(_types, 'S3ConfigProtocol') + + +@pytest.mark.asyncio +async def test_response_error_xml(client: AsyncClient): + response = await client.get(f'http://localhost:{client.port}/xml-error/') + assert response.status_code == 456 + e = core.RequestError(response) + assert str(e).endswith('(XML formatted by aioaws)') + + +@pytest.mark.asyncio +async def test_response_error_not_xml(client: AsyncClient): + response = await client.get(f'http://localhost:{client.port}/status/400/') + assert response.status_code == 400 + e = core.RequestError(response) + assert str(e) == ( + f'unexpected response from GET "http://localhost:{client.port}/status/400/": 400, response:\n' + 'test response with status 400' + )