Skip to content

Commit

Permalink
allow a custom host/endpoint for S3 (#36)
Browse files Browse the repository at this point in the history
* allow a custom host/endpoint for S3

* make endpoint a property

* linting

* coverage and mypy

* more tests

* linting
  • Loading branch information
samuelcolvin authored Jan 10, 2023
1 parent 6a38c84 commit 64f8cfe
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 23 deletions.
2 changes: 2 additions & 0 deletions aioaws/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ class BaseConfigProtocol(Protocol):
aws_access_key: str
aws_secret_key: str
aws_region: str
# aws_host is optional and will be inferred if omitted
# aws_host: str


class S3ConfigProtocol(Protocol):
Expand Down
18 changes: 13 additions & 5 deletions aioaws/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
if TYPE_CHECKING:
from ._types import BaseConfigProtocol

__all__ = 'get_config_attr', 'to_unix_s', 'utcnow', 'ManyTasks', 'pretty_response'
__all__ = 'get_config_attr', 'to_unix_s', 'utcnow', 'ManyTasks', 'pretty_xml', 'pretty_response'

EPOCH = datetime(1970, 1, 1)
EPOCH_TZ = EPOCH.replace(tzinfo=timezone.utc)
Expand Down Expand Up @@ -55,16 +55,24 @@ async def finish(self) -> Iterable[Any]:
return await asyncio.gather(*self._tasks)


def pretty_response(r: Response) -> None: # pragma: no cover
from xml.etree import ElementTree
def pretty_xml(response_xml: bytes) -> str:
import xml.dom.minidom

try:
pretty = xml.dom.minidom.parseString(response_xml).toprettyxml(indent=' ')
except Exception: # pragma: no cover
return response_xml.decode()
else:
return f'{pretty} (XML formatted by aioaws)'


def pretty_response(r: Response) -> None: # pragma: no cover
from devtools import debug

xml_root = ElementTree.fromstring(r.content)
debug(
status=r.status_code,
url=str(r.url),
headers=dict(r.request.headers),
history=r.history,
xml={el.tag: el.text for el in xml_root},
xml=pretty_xml(r.content),
)
38 changes: 24 additions & 14 deletions aioaws/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from httpx import URL, AsyncClient, Response

from ._utils import get_config_attr, utcnow
from ._utils import get_config_attr, pretty_xml, utcnow

if TYPE_CHECKING:
from ._types import BaseConfigProtocol
Expand All @@ -34,19 +34,25 @@ def __init__(self, client: AsyncClient, config: 'BaseConfigProtocol', service: L
self.aws_secret_key = get_config_attr(config, 'aws_secret_key')
self.service = service
self.region = get_config_attr(config, 'aws_region')
if self.service == 'ses':
self.host = f'email.{self.region}.amazonaws.com'
else:
assert self.service == 's3', self.service
bucket = get_config_attr(config, 'aws_s3_bucket')
if '.' in bucket:
# assumes the bucket is a domain and is already as a CNAME record for S3
self.host = bucket
try:
self.host = get_config_attr(config, 'aws_host')
except TypeError:
if self.service == 'ses':
self.host = f'email.{self.region}.amazonaws.com'
else:
# see https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html
self.host = f'{bucket}.s3.{self.region}.amazonaws.com'

self.endpoint = f'https://{self.host}'
assert self.service == 's3', self.service
bucket = get_config_attr(config, 'aws_s3_bucket')
if '.' in bucket:
# assumes the bucket is a domain and is already as a CNAME record for S3
self.host = bucket
else:
# see https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html
self.host = f'{bucket}.s3.{self.region}.amazonaws.com'
self.schema = 'https'

@property
def endpoint(self) -> str:
return f'{self.schema}://{self.host}'

async def get(self, path: str = '', *, params: Optional[Dict[str, Any]] = None) -> Response:
return await self.request('GET', path=path, params=params)
Expand Down Expand Up @@ -222,4 +228,8 @@ def __init__(self, r: Response):
self.status = r.status_code

def __str__(self) -> str:
return f'{self.args[0]}, response:\n{self.response.text}'
if self.response.headers.get('content-type') == 'application/xml':
text = pretty_xml(self.response.content)
else:
text = self.response.text
return f'{self.args[0]}, response:\n{text}'
6 changes: 4 additions & 2 deletions aioaws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from httpx import URL, AsyncClient
from pydantic import BaseModel, validator

from ._utils import ManyTasks, utcnow
from ._utils import ManyTasks, pretty_xml, utcnow
from .core import AwsClient

if TYPE_CHECKING:
Expand All @@ -32,6 +32,8 @@ class S3Config:
aws_secret_key: str
aws_region: str
aws_s3_bucket: str
# custom host to connect with
aws_host: Optional[str] = None


class S3File(BaseModel):
Expand Down Expand Up @@ -80,7 +82,7 @@ async def list(self, prefix: Optional[str] = None) -> AsyncIterable[S3File]:
if (t := xml_root.find('NextContinuationToken')) is not None:
continuation_token = t.text
else:
raise RuntimeError(f'unexpected response from S3: {r.text!r}')
raise RuntimeError(f'unexpected response from S3:\n{pretty_xml(r.content)}')

async def delete(self, *files: Union[str, S3File]) -> List[str]:
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/dummy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,14 @@ async def aws_certs(request):
return Response(body=aws_certs_body, content_type='content/unknown')


async def xml_error(request):
return Response(body=s3_list_response_template, content_type='application/xml', status=456)


routes = [
web.route('*', '/s3/', s3_root),
web.get('/s3_demo_image_url/{image:.*}', s3_demo_image),
web.post('/ses/', ses_send),
web.get('/sns/certs/', aws_certs),
web.get('/xml-error/', xml_error),
]
3 changes: 2 additions & 1 deletion tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
def test_upload_url_after_overriding_aws_client_endpoint(mocker):
mocker.patch('aioaws.s3.utcnow', return_value=datetime(2032, 1, 1))
s3 = S3Client('-', S3Config('testing', 'testing', 'testing', 'testing.com'))
s3._aws_client.endpoint = 'http://localhost:4766'
s3._aws_client.host = 'localhost:4766'
s3._aws_client.schema = 'http'
d = s3.signed_upload_url(
path='testing/', filename='test.png', content_type='image/png', size=123, expires=datetime(2032, 1, 1)
)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from httpx import AsyncClient

from aioaws import _types, _utils
from aioaws import _types, _utils, core


def test_get_config_attr():
Expand All @@ -20,3 +21,22 @@ class Foo:
def test_types():
assert hasattr(_types, 'BaseConfigProtocol')
assert hasattr(_types, 'S3ConfigProtocol')


@pytest.mark.asyncio
async def test_response_error_xml(client: AsyncClient):
response = await client.get(f'http://localhost:{client.port}/xml-error/')
assert response.status_code == 456
e = core.RequestError(response)
assert str(e).endswith('(XML formatted by aioaws)')


@pytest.mark.asyncio
async def test_response_error_not_xml(client: AsyncClient):
response = await client.get(f'http://localhost:{client.port}/status/400/')
assert response.status_code == 400
e = core.RequestError(response)
assert str(e) == (
f'unexpected response from GET "http://localhost:{client.port}/status/400/": 400, response:\n'
'test response with status 400'
)

0 comments on commit 64f8cfe

Please sign in to comment.