Skip to content

Commit

Permalink
add download method to S3 (#38)
Browse files Browse the repository at this point in the history
* add download method to s3

* more tests
  • Loading branch information
samuelcolvin authored Jan 11, 2023
1 parent 1e2455f commit 402c3f9
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 14 deletions.
15 changes: 14 additions & 1 deletion aioaws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic import BaseModel, validator

from ._utils import ManyTasks, pretty_xml, utcnow
from .core import AwsClient
from .core import AwsClient, RequestError

if TYPE_CHECKING:
from ._types import S3ConfigProtocol
Expand Down Expand Up @@ -159,6 +159,19 @@ def signed_download_url(self, path: str, version: Optional[str] = None, max_age:
url = url.copy_add_param('v', version)
return str(url)

async def download(self, file: Union[str, S3File], version: Optional[str] = None) -> bytes:
if isinstance(file, str):
path = file
else:
path = file.key

url = self.signed_download_url(path, version=version)
r = await self._aws_client.client.get(url)
if r.status_code == 200:
return r.content
else:
raise RequestError(r)

def signed_upload_url(
self,
*,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _merge_url(self, url):

new_url = url.copy_with(scheme=self.scheme, host=self.host, port=self.port)
if 's3.' in url.host:
return new_url.copy_with(path='/s3/')
return new_url.copy_with(path=f'/s3{new_url.path}')
elif 'email.' in url.host:
return new_url.copy_with(path='/ses/')
elif url.host.startswith('sns.'):
Expand Down
13 changes: 3 additions & 10 deletions tests/dummy_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import re
from io import BytesIO
from typing import List
from xml.etree import ElementTree

from aiohttp import web
from aiohttp.web_response import Response
from PIL import Image, ImageDraw

from aioaws.testing import ses_email_data, ses_send_response

Expand Down Expand Up @@ -70,13 +68,8 @@ async def s3_root(request: web.Request):
return Response(body=body, content_type='text/xml')


async def s3_demo_image(request):
width, height = 2000, 1200
stream = BytesIO()
image = Image.new('RGB', (width, height), (50, 100, 150))
ImageDraw.Draw(image).line((0, 0) + image.size, fill=128)
image.save(stream, format='JPEG', optimize=True)
return Response(body=stream.getvalue())
async def s3_file(request: web.Request):
return Response(body='this is demo file content')


async def ses_send(request):
Expand Down Expand Up @@ -130,7 +123,7 @@ async def xml_error(request):

routes = [
web.route('*', '/s3/', s3_root),
web.get('/s3_demo_image_url/{image:.*}', s3_demo_image),
web.get('/s3/testing.txt', s3_file),
web.post('/ses/', ses_send),
web.get('/sns/certs/', aws_certs),
web.get('/xml-error/', xml_error),
Expand Down
1 change: 0 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
aiohttp==3.8.3
foxglove-web==0.0.36
pillow==9.4.0
coverage==7.0.4
dirty-equals==0.5.0
pytest==7.2.0
Expand Down
26 changes: 25 additions & 1 deletion tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime, timezone

import pytest
from dirty_equals import IsNow
from dirty_equals import IsNow, IsStr
from foxglove.test_server import DummyServer
from httpx import AsyncClient

Expand Down Expand Up @@ -135,6 +135,30 @@ async def test_list_delete_many(client: AsyncClient, aws: DummyServer):
]


@pytest.mark.asyncio
async def test_download_ok(client: AsyncClient, aws: DummyServer):
s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing'))
content = await s3.download('testing.txt')
assert content == b'this is demo file content'
assert aws.log == [IsStr(regex=r'GET /s3/testing\.txt\?.+ > 200')]


@pytest.mark.asyncio
async def test_download_ok_file(client: AsyncClient, aws: DummyServer):
s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing'))
content = await s3.download(S3File(Key='testing.txt', LastModified=0, Size=1, ETag='x', StorageClass='x'))
assert content == b'this is demo file content'
assert aws.log == [IsStr(regex=r'GET /s3/testing\.txt\?.+ > 200')]


@pytest.mark.asyncio
async def test_download_error(client: AsyncClient, aws: DummyServer):
s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing'))
with pytest.raises(RequestError):
await s3.download('missing.txt')
assert aws.log == [IsStr(regex=r'GET /s3/missing\.txt\?.+ > 404')]


@pytest.mark.asyncio
async def test_list_bad(client: AsyncClient):
s3 = S3Client(client, S3Config('testing', 'testing', 'testing', 'testing'))
Expand Down

0 comments on commit 402c3f9

Please sign in to comment.