🚀 Feature
The ability to perform FSDP with bf16 parameters, and to shard a model without first moving the whole thing to the TPU.
Motivation
When trying to fit a big model on a small slice (e.g. a single host, as on Google Colab or Kaggle), the model has to be sharded if it doesn't fit in a single TPU's HBM.
I tried using FSDP, but apparently it only allows float32 for the parameters, accepting bf16 only for compute_dtype and buffer_dtype. This substantially increases the amount of HBM needed to fit the model: in my case, a model that required 40 GB of VRAM on GPU did not fit in 4x16 GB = 64 GB of HBM.
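Roughly what I tried (a minimal sketch; the tiny placeholder model stands in for my much larger one, and bf16 only seems to be accepted for the compute/buffer dtypes):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()

# Placeholder model; in my case a much larger transformer.
model = nn.Sequential(
    nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
).to(device)

# bf16 is accepted for compute and buffers, but the sharded parameters
# themselves seem to stay in float32, which is what dominates HBM usage.
model = FSDP(
    model,
    compute_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```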
I also tried using the new experimental FSDP_v2, which doesn't complain about the dtype, but in this case (a sketch of what I tried follows this list):
- if auto_wrap_policy is not passed, it first tries to move the whole model to the TPU and goes out of memory (I'm not sure this is a good default behavior)
- if I pass an auto_wrap_policy the wrapping succeeds, but it crashes at the first compilation with a SIGSEGV
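This is approximately the FSDP_v2 path I tried (a sketch under my assumptions; the placeholder model and the class passed to the wrap policy stand in for my actual model's decoder block):

```python
import functools
import numpy as np
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))

# Placeholder model; in practice a large transformer whose blocks I pass to
# the wrap policy so submodules get sharded instead of the whole model being
# moved to one device first.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.Linear},  # placeholder; normally the decoder block class
)

# With the policy the wrapping succeeds, but for me the first compilation
# then crashes with a SIGSEGV.
model = FSDPv2(model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
```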
I'm not sure whether I made mistakes on my side; if so, I think more documentation on how to lower per-device memory usage would be helpful.
Thank you in advance!