Skip to content

Commit

Permalink
add HttpxStreamingBody to reduce test changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Feb 28, 2025
1 parent d1c0d12 commit 9462503
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 69 deletions.
4 changes: 2 additions & 2 deletions aiobotocore/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from aiobotocore.httpchecksum import handle_checksum_body
from aiobotocore.httpsession import AIOHTTPSession
from aiobotocore.response import StreamingBody
from aiobotocore.response import HttpxStreamingBody, StreamingBody

try:
import httpx
Expand Down Expand Up @@ -55,7 +55,7 @@ async def convert_to_response_dict(http_response, operation_model):
response_dict['body'] = http_response.raw
elif operation_model.has_streaming_output:
if httpx and isinstance(http_response.raw, httpx.Response):
response_dict['body'] = http_response.raw
response_dict['body'] = HttpxStreamingBody(http_response.raw)
else:
length = response_dict['headers'].get('content-length')
response_dict['body'] = StreamingBody(http_response.raw, length)
Expand Down
26 changes: 26 additions & 0 deletions aiobotocore/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,32 @@ def _verify_content_length(self):
def tell(self):
    """Return the number of bytes read from the stream so far."""
    # The _self_ prefix is the wrapt.ObjectProxy convention for storing
    # state on the proxy itself instead of the wrapped stream.
    return self._self_amount_read

async def aclose(self):
    # Async-close shim so callers can uniformly `await body.aclose()`
    # regardless of backend; delegates to the wrapped stream's close().
    # NOTE(review): close() here appears to be synchronous (its result is
    # returned, not awaited) — confirm against the wrapped stream type.
    return self.__wrapped__.close()


class HttpxStreamingBody(wrapt.ObjectProxy):
    """Adapt an ``httpx.Response`` to the async StreamingBody interface.

    Built on :class:`wrapt.ObjectProxy`, so any attribute not defined here
    falls through to the wrapped ``httpx.Response``; this class only adds
    the ``read``/``__aenter__``/``__aexit__`` surface that botocore-style
    callers expect from a streaming body.
    """

    # The wrapped object is an httpx.Response (it is constructed from
    # http_response.raw in endpoint.convert_to_response_dict), NOT an
    # aiohttp.StreamReader as the previous annotation claimed. The
    # annotation is a string so it is never evaluated at class-definition
    # time, which keeps this module importable when httpx is not installed.
    def __init__(self, raw_stream: "httpx.Response"):
        super().__init__(raw_stream)

    async def read(self, amt=None):
        """Read and return the entire response body as bytes.

        :param amt: must be ``None``. ``httpx.Response.aread()`` can only
            read the full body; partial reads are not supported.
        :raises ValueError: if ``amt`` is not ``None``.
        """
        if amt is not None:
            # We could do a fancy thing here and start doing calls to
            # aiter_bytes()/aiter_raw() and keep state
            raise ValueError(
                "httpx.Response.aread does not support reading a specific number of bytes"
            )
        return await self.__wrapped__.aread()

    async def __aenter__(self):
        # use AsyncClient.stream somehow?
        # See "manual mode" at https://www.python-httpx.org/async/#streaming-responses
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Close the underlying response. This method implicitly returns
        # None (falsy), so any in-flight exception propagates normally —
        # exceptions are NOT suppressed here.
        await self.__wrapped__.aclose()


async def get_response(operation_model, http_response):
protocol = operation_model.metadata['protocol']
Expand Down
6 changes: 1 addition & 5 deletions tests/python3.8/boto_tests/unit/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def parametrize(cases):
return pytest.mark.parametrize(
"test_case",
cases,
deepcopy(cases),
ids=[c["documentation"] for c in cases],
)

Expand Down Expand Up @@ -287,10 +287,6 @@ async def test_sso_token_provider_refresh(test_case):
cache_key = "d033e22ae348aeb5660fc2140aec35850c4da997"
token_cache = {}

# deepcopy the test case so the test can be parametrized against the same
# test case w/ aiohttp & httpx
test_case = deepcopy(test_case)

# Prepopulate the token cache
cached_token = test_case.pop("cachedToken", None)
if cached_token:
Expand Down
78 changes: 21 additions & 57 deletions tests/test_basic_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,9 @@ async def test_can_get_and_put_object(
await create_object(key_name, body='body contents')

resp = await s3_client.get_object(Bucket=bucket_name, Key=key_name)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
# note that calling `aclose()` is redundant, httpx will auto-close when the
# data is fully read
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
# TODO: think about better api and make behavior like in aiohttp
resp['Body'].close()
data = await resp['Body'].read()
# TODO: think about better api and make behavior like in aiohttp
await resp['Body'].aclose()
assert data == b'body contents'

# now test checksum'd file
Expand All @@ -214,10 +208,7 @@ async def test_can_get_and_put_object(
resp = await s3_client.get_object(
Bucket=bucket_name, Key=key_name, ChecksumMode="ENABLED"
)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
else:
data = await resp['Body'].read()
data = await resp['Body'].read()
assert data == b'abcd'


Expand Down Expand Up @@ -280,31 +271,27 @@ async def test_get_object_stream_wrapper(
response = await s3_client.get_object(Bucket=bucket_name, Key='foobarbaz')
body = response['Body']
if httpx and isinstance(body, httpx.Response):
# httpx does not support `.aread(1)`
byte_iterator = body.aiter_raw(1)
chunk1 = await byte_iterator.__anext__()
chunk2 = b""
async for b in byte_iterator:
chunk2 += b
await body.aclose()
else:
chunk1 = await body.read(1)
chunk2 = await body.read()
body.close()
assert chunk1 == b'b'
assert chunk2 == b'ody contents'
await body.aclose()


async def test_get_object_stream_context(
s3_client, create_object, bucket_name
):
await create_object('foobarbaz', body='body contents')
response = await s3_client.get_object(Bucket=bucket_name, Key='foobarbaz')
# httpx does not support context manager
if httpx and isinstance(response['Body'], httpx.Response):
data = await response['Body'].aread()
else:
async with response['Body'] as stream:
data = await stream.read()
async with response['Body'] as stream:
data = await stream.read()
assert data == b'body contents'


Expand Down Expand Up @@ -399,12 +386,8 @@ async def test_unicode_key_put_list(s3_client, bucket_name, create_object):
assert len(parsed['Contents']) == 1
assert parsed['Contents'][0]['Key'] == key_name
parsed = await s3_client.get_object(Bucket=bucket_name, Key=key_name)
if httpx and isinstance(parsed['Body'], httpx.Response):
data = await parsed['Body'].aread()
await parsed['Body'].aclose()
else:
data = await parsed['Body'].read()
parsed['Body'].close()
data = await parsed['Body'].read()
await parsed['Body'].aclose()
assert data == b'foo'


Expand Down Expand Up @@ -456,12 +439,8 @@ async def test_copy_with_quoted_char(s3_client, create_object, bucket_name):

# Now verify we can retrieve the copied object.
resp = await s3_client.get_object(Bucket=bucket_name, Key=key_name2)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
resp['Body'].close()
data = await resp['Body'].read()
await resp['Body'].aclose()
assert data == b'foo'


Expand All @@ -478,12 +457,8 @@ async def test_copy_with_query_string(s3_client, create_object, bucket_name):

# Now verify we can retrieve the copied object.
resp = await s3_client.get_object(Bucket=bucket_name, Key=key_name2)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
resp['Body'].close()
data = await resp['Body'].read()
await resp['Body'].aclose()
assert data == b'foo'


Expand All @@ -500,12 +475,8 @@ async def test_can_copy_with_dict_form(s3_client, create_object, bucket_name):

# Now verify we can retrieve the copied object.
resp = await s3_client.get_object(Bucket=bucket_name, Key=key_name2)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
resp['Body'].close()
data = await resp['Body'].read()
await resp['Body'].aclose()
assert data == b'foo'


Expand All @@ -527,12 +498,8 @@ async def test_can_copy_with_dict_form_with_version(

# Now verify we can retrieve the copied object.
resp = await s3_client.get_object(Bucket=bucket_name, Key=key_name2)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
resp['Body'].close()
data = await resp['Body'].read()
await resp['Body'].aclose()
assert data == b'foo'


Expand Down Expand Up @@ -570,6 +537,7 @@ async def test_presign_with_existing_query_string_values(
'get_object', Params=params
)
# Try to retrieve the object using the presigned url.
# TODO: compatibility layer between httpx.AsyncClient and aiohttp.ClientSession?
if httpx and isinstance(aio_session, httpx.AsyncClient):
async with aio_session.stream("GET", presigned_url) as resp:
data = await resp.aread()
Expand Down Expand Up @@ -625,12 +593,8 @@ async def test_can_follow_signed_url_redirect(
resp = await alternative_s3_client.get_object(
Bucket=bucket_name, Key='foobarbaz'
)
if httpx and isinstance(resp['Body'], httpx.Response):
data = await resp['Body'].aread()
await resp['Body'].aclose()
else:
data = await resp['Body'].read()
resp['Body'].close()
data = await resp['Body'].read()
await resp['Body'].aclose()
assert data == b'foo'


Expand Down
7 changes: 2 additions & 5 deletions tests/test_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ async def test_run_lambda(
Payload=json.dumps({"hello": "world"}),
)

if httpx and isinstance(invoke_response['Payload'], httpx.Response):
data = await invoke_response['Payload'].aread()
else:
async with invoke_response['Payload'] as stream:
data = await stream.read()
async with invoke_response['Payload'] as stream:
data = await stream.read()

log_result = base64.b64decode(invoke_response["LogResult"])

Expand Down

0 comments on commit 9462503

Please sign in to comment.