
reduce memory requirements for FSDP #9734

@samuele-bortolato-mag

Description


🚀 Feature

The ability to run FSDP with parameters in bf16 precision, and to shard the model without first moving the whole model to the TPU.

Motivation

When trying to fit a big model on a small slice (e.g. a single host, as on Google Colab or Kaggle), the model has to be sharded if it doesn't fit in a single TPU's HBM.

I tried using FSDP, but it apparently only allows float32 for the parameters, with bf16 permitted only for compute_dtype and buffer_dtype. This substantially increases the amount of HBM needed to fit the model: in my case, a model that required 40GB of VRAM on GPU didn't fit in 4x16GB = 64GB of HBM.
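For reference, this is roughly the pattern I mean (a minimal sketch with a toy model standing in for the real architecture; the real model is much larger):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


# Toy stand-in for the real model (placeholder, not the actual architecture).
class ToyModel(nn.Module):
    def __init__(self, d=4096, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d, d) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


device = xm.xla_device()

# compute_dtype / buffer_dtype can be lowered to bf16, but the sharded
# parameters themselves stay in float32, which is what drives up HBM usage.
model = FSDP(
    ToyModel().to(device),
    compute_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```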

I also tried using the new experimental FSDP_v2, which doesn't complain about the dtype, but in this case:

  • if auto_wrap_policy is not passed, it first tries to move the whole model to the TPU and goes out of memory (I'm not sure this is a good default behavior)
  • if I pass an auto_wrap_policy, the wrapping succeeds, but the first compilation crashes with a SIGSEGV (see the sketch after this list)
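This is roughly what the second attempt looks like (again a sketch with a toy model; the 1D mesh setup is my assumption of a minimal SPMD configuration, not necessarily what the docs recommend):

```python
import functools

import numpy as np
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)


# Toy block/model standing in for the real architecture.
class Block(nn.Module):
    def __init__(self, d=4096):
        super().__init__()
        self.ff = nn.Linear(d, d)

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, d=4096, n_layers=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(d) for _ in range(n_layers))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


xr.use_spmd()  # FSDPv2 runs on top of SPMD

# 1D mesh over all devices along an "fsdp" axis (assumed minimal setup).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))
xs.set_global_mesh(mesh)

# Wrap each Block separately so parameters are sharded before the whole
# model is materialized on a single device.
policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
model = FSDPv2(ToyModel(), auto_wrap_policy=policy)
```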

I'm not sure whether I made mistakes on my side; if so, more documentation on how to lower per-device memory usage would be helpful.

Thank you in advance!
