diff --git a/aioaws/s3.py b/aioaws/s3.py index cba8355..3b0c0e5 100644 --- a/aioaws/s3.py +++ b/aioaws/s3.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, validator from ._utils import ManyTasks, pretty_xml, utcnow -from .core import AwsClient +from .core import AwsClient, RequestError if TYPE_CHECKING: from ._types import S3ConfigProtocol @@ -159,6 +159,19 @@ def signed_download_url(self, path: str, version: Optional[str] = None, max_age: url = url.copy_add_param('v', version) return str(url) + async def download(self, file: Union[str, S3File], version: Optional[str] = None) -> bytes: + if isinstance(file, str): + path = file + else: + path = file.key + + url = self.signed_download_url(path, version=version) + r = await self._aws_client.client.get(url) + if r.status_code == 200: + return r.content + else: + raise RequestError(r) + def signed_upload_url( self, *, diff --git a/tests/conftest.py b/tests/conftest.py index bf9c27c..f127187 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,7 +40,7 @@ def _merge_url(self, url): new_url = url.copy_with(scheme=self.scheme, host=self.host, port=self.port) if 's3.' in url.host: - return new_url.copy_with(path='/s3/') + return new_url.copy_with(path=f'/s3{new_url.path}') elif 'email.' in url.host: return new_url.copy_with(path='/ses/') elif url.host.startswith('sns.'): diff --git a/tests/dummy_server.py b/tests/dummy_server.py index 7856c35..ec0c973 100644 --- a/tests/dummy_server.py +++ b/tests/dummy_server.py @@ -1,11 +1,9 @@ import re -from io import BytesIO from typing import List from xml.etree import ElementTree from aiohttp import web from aiohttp.web_response import Response -from PIL import Image, ImageDraw from aioaws.testing import ses_email_data, ses_send_response @@ -70,13 +68,8 @@ async def s3_root(request: web.Request): return Response(body=body, content_type='text/xml') -async def s3_demo_image(request): - width, height = 2000, 1200 - stream = BytesIO() - image = Image.new('RGB', (width, height), (50, 100, 150)) - ImageDraw.Draw(image).line((0, 0) + image.size, fill=128) - image.save(stream, format='JPEG', optimize=True) - return Response(body=stream.getvalue()) +async def s3_file(request: web.Request): + return Response(body='this is demo file content') async def ses_send(request): @@ -130,7 +123,7 @@ async def xml_error(request): routes = [ web.route('*', '/s3/', s3_root), - web.get('/s3_demo_image_url/{image:.*}', s3_demo_image), + web.get('/s3/testing.txt', s3_file), web.post('/ses/', ses_send), web.get('/sns/certs/', aws_certs), web.get('/xml-error/', xml_error), diff --git a/tests/requirements.txt b/tests/requirements.txt index 504eeb1..506b560 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,5 @@ aiohttp==3.8.3 foxglove-web==0.0.36 -pillow==9.4.0 coverage==7.0.4 dirty-equals==0.5.0 pytest==7.2.0 diff --git a/tests/test_s3.py b/tests/test_s3.py index d6115ce..e383a61 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone import pytest -from dirty_equals import IsNow +from dirty_equals import IsNow, IsStr from foxglove.test_server import DummyServer from httpx import AsyncClient @@ -135,6 +135,30 @@ async def test_list_delete_many(client: AsyncClient, aws: DummyServer): ] +@pytest.mark.asyncio +async def test_download_ok(client: AsyncClient, aws: DummyServer): + s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing')) + content = await s3.download('testing.txt') + assert content == b'this is demo file content' + assert aws.log == [IsStr(regex=r'GET /s3/testing\.txt\?.+ > 200')] + + +@pytest.mark.asyncio +async def test_download_ok_file(client: AsyncClient, aws: DummyServer): + s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing')) + content = await s3.download(S3File(Key='testing.txt', LastModified=0, Size=1, ETag='x', StorageClass='x')) + assert content == b'this is demo file content' + assert aws.log == [IsStr(regex=r'GET /s3/testing\.txt\?.+ > 200')] + + +@pytest.mark.asyncio +async def test_download_error(client: AsyncClient, aws: DummyServer): + s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing')) + with pytest.raises(RequestError): + await s3.download('missing.txt') + assert aws.log == [IsStr(regex=r'GET /s3/missing\.txt\?.+ > 404')] + + @pytest.mark.asyncio async def test_list_bad(client: AsyncClient): s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing'))