We're encountering an out-of-memory (OOM) error on a machine with 2 TB of RAM when attempting to load Qwen3-235B-A22B.
The issue seems to be that torchtune loads the full, unsharded checkpoint with `self._checkpoint_client.load_base_checkpoint()`. For a bf16 model of this size the raw weights are roughly 470 GB, and since each rank appears to load its own full copy, the aggregate (~3.7 TB with eight ranks on one node) exceeds the available host memory. Loading the checkpoint shard by shard would likely resolve this; a rough sketch is below.
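As an illustration of what I mean (this is not torchtune's actual API; `iter_checkpoint_tensors` and the Hugging Face-style checkpoint layout are my assumptions), a loader could stream tensors one at a time from the sharded safetensors files, so the caller can shard or offload each tensor and free it before the next one is read:

```python
import json
import os
from typing import Iterator

import torch
from safetensors import safe_open


def iter_checkpoint_tensors(
    checkpoint_dir: str,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs lazily from a Hugging Face-style
    sharded safetensors checkpoint. Peak host memory stays near a
    single tensor instead of the full ~470 GB state dict, provided
    the caller drops each tensor before requesting the next."""
    # HF convention: the index file maps each parameter to its shard.
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]

    # Visit each shard file once; safe_open does not load the whole
    # file, and get_tensor reads only that tensor's bytes from disk.
    for shard in sorted(set(weight_map.values())):
        with safe_open(
            os.path.join(checkpoint_dir, shard), framework="pt", device="cpu"
        ) as f:
            for name in f.keys():
                yield name, f.get_tensor(name)
```

Each rank could then copy its slice of every tensor into the FSDP-sharded model as the tensors stream past, rather than materializing the entire unsharded state dict up front.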
Has anyone else encountered this?
Originally posted by @leng-yue in #2867 (comment)