Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overall timeout support #936

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## Unreleased

- Handle `SSLError` exception. (#918)
- Add overall timeout support. (#936)

## 1.0.5 (March 27th, 2024)

Expand Down
14 changes: 13 additions & 1 deletion docs/extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ print(r.extensions["http_version"])

A dictionary of `str: Optional[float]` timeout values.

May include values for `'connect'`, `'read'`, `'write'`, or `'pool'`.
May include values for `'connect'`, `'read'`, `'write'`, `'pool'` or `'total'`.
karpetrosyan marked this conversation as resolved.
Show resolved Hide resolved

For example:

Expand All @@ -86,6 +86,18 @@ r = httpcore.request(
)
```

or

```python
# Timeout if we are blocked waiting for the connection read for more
# than a second, or if the total time of the request exceeds 10 seconds.
r = httpcore.request(
"GET",
"https://www.example.com",
extensions={"timeout": {"read": 1.0, "total": 10.0}}
)
```

### `"trace"`

The trace extension allows a callback handler to be installed to monitor the internal
Expand Down
9 changes: 7 additions & 2 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectError, ConnectTimeout
Expand Down Expand Up @@ -105,6 +107,8 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)

overall_timeout = OverallTimeoutHandler(timeouts)

retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

Expand All @@ -115,11 +119,12 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
"socket_options": self._socket_options,
}
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
Expand Down
10 changes: 8 additions & 2 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
Expand Down Expand Up @@ -174,6 +176,7 @@ async def handle_async_request(self, request: Request) -> Response:

timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with self._optional_thread_lock:
# Add the incoming request to our request queue.
Expand All @@ -188,8 +191,11 @@ async def handle_async_request(self, request: Request) -> Response:
closing = self._assign_requests_to_connections()
await self._close_connections(closing)

# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(timeout=timeout)
with overall_timeout:
# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(
timeout=overall_timeout.get_minimum_timeout(timeout)
)

try:
# Send the request on the assigned connection.
Expand Down
31 changes: 26 additions & 5 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import h11

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
Expand Down Expand Up @@ -147,25 +149,37 @@ async def handle_async_request(self, request: Request) -> Response:
async def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
await self._send_event(event, timeout=timeout)
with overall_timeout:
await self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

async def _send_request_body(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

assert isinstance(request.stream, AsyncIterable)
async for chunk in request.stream:
event = h11.Data(data=chunk)
await self._send_event(event, timeout=timeout)

await self._send_event(h11.EndOfMessage(), timeout=timeout)
with overall_timeout:
await self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

with overall_timeout:
await self._send_event(
h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout)
)

async def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
Expand All @@ -181,9 +195,13 @@ async def _receive_response_headers(
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = await self._receive_event(timeout=timeout)
with overall_timeout:
event = await self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Response):
break
if (
Expand All @@ -205,9 +223,12 @@ async def _receive_response_headers(
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = await self._receive_event(timeout=timeout)
event = await self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
Expand Down
12 changes: 10 additions & 2 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import h2.exceptions
import h2.settings

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
Expand Down Expand Up @@ -430,12 +432,16 @@ async def _read_incoming_data(
) -> typing.List[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

if self._read_exception is not None:
raise self._read_exception # pragma: nocover

try:
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
with overall_timeout:
data = await self._network_stream.read(
self.READ_NUM_BYTES, overall_timeout.get_minimum_timeout(timeout)
)
if data == b"":
raise RemoteProtocolError("Server disconnected")
except Exception as exc:
Expand All @@ -458,6 +464,7 @@ async def _read_incoming_data(
async def _write_outgoing_data(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._write_lock:
data_to_send = self._h2_state.data_to_send()
Expand All @@ -466,7 +473,8 @@ async def _write_outgoing_data(self, request: Request) -> None:
raise self._write_exception # pragma: nocover

try:
await self._network_stream.write(data_to_send, timeout)
with overall_timeout:
await self._network_stream.write(data_to_send, timeout)
except Exception as exc: # pragma: nocover
# If we get a network error we should:
#
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ProxyError
from .._models import (
Expand Down Expand Up @@ -266,6 +268,7 @@ def __init__(
async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._connect_lock:
if not self._connected:
Expand Down Expand Up @@ -311,10 +314,11 @@ async def handle_async_request(self, request: Request) -> Response:
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
with overall_timeout:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_async/socks_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from socksio import socks5

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectionNotAvailable, ProxyError
Expand Down Expand Up @@ -218,6 +220,7 @@ async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._connect_lock:
if self._connection is None:
Expand All @@ -226,10 +229,11 @@ async def handle_async_request(self, request: Request) -> Response:
kwargs = {
"host": self._proxy_origin.host.decode("ascii"),
"port": self._proxy_origin.port,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
}
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream

# Connect to the remote host using socks5
Expand Down
9 changes: 7 additions & 2 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectTimeout
Expand Down Expand Up @@ -105,6 +107,8 @@ def _connect(self, request: Request) -> NetworkStream:
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)

overall_timeout = OverallTimeoutHandler(timeouts)

retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

Expand All @@ -115,11 +119,12 @@ def _connect(self, request: Request) -> NetworkStream:
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
"socket_options": self._socket_options,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
Expand Down
10 changes: 8 additions & 2 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Iterable, List, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
Expand Down Expand Up @@ -174,6 +176,7 @@ def handle_request(self, request: Request) -> Response:

timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with self._optional_thread_lock:
# Add the incoming request to our request queue.
Expand All @@ -188,8 +191,11 @@ def handle_request(self, request: Request) -> Response:
closing = self._assign_requests_to_connections()
self._close_connections(closing)

# Wait until this request has an assigned connection.
connection = pool_request.wait_for_connection(timeout=timeout)
with overall_timeout:
# Wait until this request has an assigned connection.
connection = pool_request.wait_for_connection(
timeout=overall_timeout.get_minimum_timeout(timeout)
)

try:
# Send the request on the assigned connection.
Expand Down
Loading
Loading