Hi, have there been any tests with FA-3 and the low-bit optimizers from torchao, such as FP8 Adam or 8-bit Adam? I see divergence in training when resuming an FA-2 checkpoint with FA-3, or when using 8-bit AdamW.
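For concreteness, the optimizer is set up roughly like this — a minimal sketch, assuming a torchao version where the low-bit optimizers import from `torchao.optim` (older releases expose them under `torchao.prototype.low_bit_optim`); the model, shapes, and hyperparameters below are placeholders:

```python
import torch
import torch.nn as nn

# Import path differs across torchao versions.
try:
    from torchao.optim import AdamW8bit
except ImportError:
    from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

# block_size controls the quantization group size of the 8-bit optimizer states;
# smaller blocks track the statistics more closely at some memory cost.
optimizer = AdamW8bit(model.parameters(), lr=1e-4, weight_decay=0.1, block_size=256)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # toy loss just to exercise a step
loss.backward()
optimizer.step()
optimizer.zero_grad()
```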
1. Baseline: an FA-2 checkpoint trained with AdamW.
2. Switching directly to FA-3 for inference (single-GPU and multi-GPU TP) on the model trained with FA-2 gives broken results. Finetuning from scratch with FA-3, however, seems to work and gives around a 30% speedup depending on the parallel config (see the comparison sketch after this list).
3. With 8-bit Adam the loss diverges after some iterations. I have tried various block_size values and am using the torchao implementation. Any suggestions on how to debug this? Could it be an error caused by the TP/DP config?
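One way to isolate whether the broken results and the divergence come from the attention kernel itself, rather than the optimizer or the TP/DP sharding, is to compare FA-2 and FA-3 outputs on identical inputs. A minimal sketch, assuming FA-2 is importable as `flash_attn` and the FA-3 beta exposes `flash_attn_func` via `flash_attn_interface` (package names and return conventions may differ between builds):

```python
import torch
from flash_attn import flash_attn_func as fa2_attn
from flash_attn_interface import flash_attn_func as fa3_attn

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out2 = fa2_attn(q, k, v, causal=True)
out3 = fa3_attn(q, k, v, causal=True)
# Some FA-3 builds return (out, softmax_lse); keep only the output tensor.
if isinstance(out3, tuple):
    out3 = out3[0]

# bf16 kernels should agree to within roughly 1e-2; a much larger gap points at
# the attention call (layout, softmax scale, causal flag) rather than the optimizer.
print("max abs diff:", (out2.float() - out3.float()).abs().max().item())
```

If the two kernels agree here, the next suspects would be the optimizer-state quantization and how it interacts with the TP/DP sharding of the parameters.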