
cache_enabled=False in autocast causes OOM regression for iterative decoding workloads #21611

@docbeaker

Bug description

Lightning 2.6.0 introduced cache_enabled=False in MixedPrecision.autocast_context_manager (2.5.x used the PyTorch default,
cache_enabled=True):

2.5.x

  def autocast_context_manager(self) -> torch.autocast:
      return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))

2.6.x

  def autocast_context_manager(self) -> torch.autocast:
      dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
      return torch.autocast(self.device, dtype=dtype, cache_enabled=False)

This causes a severe memory regression (OOM) for training workloads that run the same model parameters through forward repeatedly within a
single training_step, such as the autoregressive/streaming decoding loops used in reinforcement learning.

What version are you seeing the problem on?

master

Error messages and logs

Uhhh

[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 158.00 MiB. GPU 0 has a total capacity of 79.11 GiB of which 38.88 MiB is free. Including non-PyTorch memory, this process has 79.06 GiB memory in use. Of the allocated memory 77.10 GiB is allocated by PyTorch, and 761.39 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)

Environment

  • Lightning 2.6.1
  • PyTorch 2.11
  • DDP with find_unused_parameters=True
  • precision="bf16-mixed"

More info

Why this happens: The autocast cache reuses bf16-cast weight tensors when the same parameter is used in multiple operations within one
autocast region. With cache_enabled=False, every forward call through the decoder creates a fresh bf16 copy of every weight tensor. In an
iterative decoding loop (hundreds of forward calls per training step), that means hundreds of redundant bf16 weight copies. Because these
copies are part of the autograd graph, they cannot be freed until backward, so memory grows linearly with the number of decode steps.
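The mechanism can be sketched with a minimal, self-contained loop (this is not the reporter's model; function and size choices are illustrative). On a CUDA machine it reports peak allocated memory, which should grow with num_steps when cache_enabled=False; on CPU it only exercises the code path:

```python
import torch
import torch.nn as nn

def peak_decode_memory(cache_enabled, num_steps=32):
    """Run an iterative decode-style loop under autocast and report peak CUDA
    memory, or None on CPU-only machines (the memory effect needs a GPU)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(1024, 1024).to(device)
    x = torch.randn(16, 1024, device=device)
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    with torch.autocast(device, dtype=torch.bfloat16, cache_enabled=cache_enabled):
        loss = x.new_zeros(())
        for _ in range(num_steps):
            # With cache_enabled=False, each call re-casts model.weight to
            # bf16, and every copy is retained by autograd until backward().
            loss = loss + model(x).float().sum()
    loss.backward()
    return torch.cuda.max_memory_allocated() if device == "cuda" else None
```

Comparing `peak_decode_memory(True)` against `peak_decode_memory(False)` on a GPU should show the linear growth described above.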

Reproduction: Any model that calls forward on the same submodule multiple times within training_step under precision="bf16-mixed" or
precision="16-mixed". For example:

  def training_step(self, batch):
      for step in range(num_steps):  # e.g. autoregressive decoding
          logits = self.decoder(...)  # same weights, recast every call
          log_prob = F.log_softmax(logits, dim=-1)
          ...
      loss = ...
      return loss

This works in Lightning 2.5.x but OOMs in 2.6.x with no other changes.

Workaround: Manually wrapping the loop with torch.autocast("cuda", dtype=torch.bfloat16, cache_enabled=True) restores the 2.5 behavior and
resolves the OOM.
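Concretely, the workaround looks like the sketch below. `self.decoder` and `self.num_steps` are placeholders, and the device type is taken from the batch rather than hard-coding "cuda" (as in the reported workaround) so the same shape also runs on CPU:

```python
import torch
import torch.nn.functional as F

def training_step(self, batch):
    # Lightning's bf16-mixed autocast is already active here with
    # cache_enabled=False; nesting an explicit context with
    # cache_enabled=True restores the 2.5.x caching behavior for the loop.
    with torch.autocast(batch.device.type, dtype=torch.bfloat16, cache_enabled=True):
        loss = batch.new_zeros(())
        for _ in range(self.num_steps):
            logits = self.decoder(batch)  # bf16 weight copies now cached across iterations
            loss = loss + F.log_softmax(logits.float(), dim=-1).mean()
    return loss
```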

cc @ethanwharris @justusschock @lantiga
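A plugin-level variant of the workaround is also possible: subclass the precision plugin so the framework-level autocast itself re-enables the weight cache. This is untested against the reporter's setup, and the import path is assumed from the Lightning 2.x layout:

```python
import torch

try:
    from lightning.pytorch.plugins.precision import MixedPrecision
except ImportError:  # layout may differ across Lightning versions
    MixedPrecision = None

if MixedPrecision is not None:
    class CachedMixedPrecision(MixedPrecision):
        def autocast_context_manager(self) -> torch.autocast:
            # Same as the 2.6.x method shown above, but with the cache on.
            dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
            return torch.autocast(self.device, dtype=dtype, cache_enabled=True)

    # usage sketch: Trainer(plugins=[CachedMixedPrecision("bf16-mixed", "cuda")])
```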
