Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union, cast

import torch
from torch import Tensor
Expand All @@ -27,6 +27,26 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException


class _AutocastClearCacheOnExit:
"""Proxy a grad-disabling context manager and clear the autocast cache when it exits."""

def __init__(self, context_manager: Any, *, clear_cache: bool) -> None:
self._context_manager = context_manager
self._clear_cache = clear_cache

def __enter__(self) -> Any:
return self._context_manager.__enter__()

def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Any:
out = self._context_manager.__exit__(exc_type, exc, tb)
if self._clear_cache:
torch.clear_autocast_cache()
return out

def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
return self._context_manager(func)


class MixedPrecision(Precision):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

Expand Down Expand Up @@ -118,9 +138,39 @@ def autocast_context_manager(self) -> torch.autocast:
@override
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
    """Enable autocast and clear cached casts after nested grad-disabling contexts exit.

    Temporarily replaces ``torch.no_grad`` and ``torch.inference_mode`` with proxies that
    call ``torch.clear_autocast_cache()`` on exit, then opens a ``torch.autocast`` region
    for the duration of the step. The originals are restored in ``finally`` even if the
    wrapped step raises.

    NOTE(review): the patch is process-global, not thread-local -- any other thread entering
    ``torch.no_grad`` while a step is in flight sees the wrapped factory. TODO confirm this
    is acceptable for the trainer's threading model.
    """
    # Save the genuine factories so they can be restored (and so nested forward_contexts
    # unwind correctly: each level restores what it saw on entry).
    original_no_grad = torch.no_grad
    original_inference_mode = torch.inference_mode

    def _clear_cache_on_exit(
        context_factory: Callable[..., Any], *, clear_cache: Callable[..., bool]
    ) -> Callable[..., Any]:
        # Factory wrapper: build the real context manager, then proxy it so the autocast
        # cache is cleared on exit whenever ``clear_cache(*args, **kwargs)`` is truthy.
        def wrapper(*args: Any, **kwargs: Any) -> _AutocastClearCacheOnExit:
            return _AutocastClearCacheOnExit(
                context_factory(*args, **kwargs),
                clear_cache=clear_cache(*args, **kwargs),
            )

        return wrapper

    try:
        # Lightning wraps the whole step in a persistent autocast context. If a nested `no_grad` or
        # `inference_mode` block creates cached casts there, later grad-enabled forwards in the same step can
        # incorrectly reuse them. Clear the autocast cache when such nested contexts exit, while keeping the
        # default cached path for normal training.
        torch_module = cast(Any, torch)  # cast appeases type checkers when assigning module attributes
        # ``no_grad`` takes no arguments, so always clear on exit.
        torch_module.no_grad = _clear_cache_on_exit(original_no_grad, clear_cache=lambda *args, **kwargs: True)
        torch_module.inference_mode = _clear_cache_on_exit(
            original_inference_mode,
            # ``inference_mode(mode=...)`` can be entered with ``mode=False``; only clear the
            # cache when inference mode is actually enabled (default ``True``).
            clear_cache=lambda *args, **kwargs: bool(args[0] if args else kwargs.get("mode", True)),
        )
        # "bf16-mixed" autocasts to bfloat16; the other mixed precisions here use float16.
        dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
        with torch.autocast(self.device, dtype=dtype):
            yield
    finally:
        # Restore the untouched torch API regardless of how the step exited.
        torch_module = cast(Any, torch)
        torch_module.no_grad = original_no_grad
        torch_module.inference_mode = original_inference_mode

@override
def state_dict(self) -> dict[str, Any]:
Expand Down
104 changes: 103 additions & 1 deletion tests/tests_pytorch/plugins/precision/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from lightning.pytorch.plugins import MixedPrecision
from lightning.pytorch.utilities import GradClipAlgorithmType
from tests_pytorch.helpers.runif import RunIf


def test_clip_gradients():
Expand Down Expand Up @@ -62,10 +63,111 @@ def test_amp_with_no_grad():
x = torch.randn(1, 2)
amp = MixedPrecision(precision="bf16-mixed", device="cpu")

with amp.autocast_context_manager():
with amp.forward_context():
with torch.no_grad():
_ = layer(x)

loss = layer(x).mean()
loss.backward()
assert loss.grad_fn is not None


def test_amp_with_inference_mode():
    """A nested ``inference_mode`` block must also clear the autocast cache on exit."""
    model = nn.Linear(2, 1)
    sample = torch.randn(1, 2)
    plugin = MixedPrecision(precision="bf16-mixed", device="cpu")

    with plugin.forward_context():
        with torch.inference_mode():
            model(sample)

        # This forward must not reuse the grad-disconnected cached cast from above.
        loss = model(sample).mean()
    loss.backward()
    assert loss.grad_fn is not None


def test_amp_forward_context_restores_grad_mode_context_managers():
    """``forward_context`` swaps in patched grad-mode factories and restores the originals."""
    plugin = MixedPrecision(precision="bf16-mixed", device="cpu")
    saved_no_grad = torch.no_grad
    saved_inference_mode = torch.inference_mode

    with plugin.forward_context():
        # Inside the context both factories are replaced by the cache-clearing proxies.
        assert torch.no_grad is not saved_no_grad
        assert torch.inference_mode is not saved_inference_mode

    # On exit the untouched torch API is back.
    assert torch.no_grad is saved_no_grad
    assert torch.inference_mode is saved_inference_mode


@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
def test_torch_autocast_cache_behavior_with_no_grad(cache_enabled, expect_grad):
    """Document the underlying PyTorch autocast behavior that this plugin needs to handle."""
    model = nn.Linear(2, 1)
    sample = torch.randn(1, 2)

    with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled):
        with torch.no_grad():
            model(sample)

        # With the cache on, this forward reuses the grad-disconnected cast from the
        # no_grad block; with the cache off, the weight is re-cast with grad history.
        loss = model(sample).mean()

    if expect_grad:
        loss.backward()
        assert loss.grad_fn is not None
    else:
        assert loss.grad_fn is None
        with pytest.raises(RuntimeError, match="does not require grad"):
            loss.backward()


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
def test_torch_autocast_cache_behavior_with_no_grad_cuda(cache_enabled, expect_grad):
    """Document the same autocast cache behavior on CUDA, where the reported regression happens."""
    model = nn.Linear(2, 1, device="cuda")
    sample = torch.randn(1, 2, device="cuda")

    with torch.autocast("cuda", dtype=torch.float16, cache_enabled=cache_enabled):
        with torch.no_grad():
            model(sample)

        # Cache on: reuses the grad-disconnected cast; cache off: re-casts with grad history.
        loss = model(sample).mean()

    if expect_grad:
        loss.backward()
        assert loss.grad_fn is not None
    else:
        assert loss.grad_fn is None
        with pytest.raises(RuntimeError, match="does not require grad"):
            loss.backward()


@RunIf(min_cuda_gpus=1)
def test_amp_with_no_grad_cuda():
    """Exercise the Lightning workaround on the CUDA path used by the reported regression."""
    model = nn.Linear(2, 1, device="cuda")
    sample = torch.randn(1, 2, device="cuda")
    plugin = MixedPrecision(precision="16-mixed", device="cuda")

    with plugin.forward_context():
        with torch.no_grad():
            model(sample)

        # Must not reuse the grad-disconnected cached cast created inside no_grad.
        loss = model(sample).mean()
    loss.backward()
    assert loss.grad_fn is not None


def test_amp_autocast_context_manager_disables_cache():
    """The public autocast context manager keeps the existing no-cache workaround."""
    plugin = MixedPrecision(precision="bf16-mixed", device="cpu")

    with plugin.autocast_context_manager():
        assert not torch.is_autocast_cache_enabled()


def test_amp_forward_context_keeps_cache_enabled():
    """Lightning's internal step context leaves the cached autocast path enabled."""
    plugin = MixedPrecision(precision="bf16-mixed", device="cpu")

    with plugin.forward_context():
        assert torch.is_autocast_cache_enabled()
Loading