+import json
+import os
+
+import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+
+from fastvideo.v1.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def save_checkpoint(transformer, rank, output_dir, step):
+    """Save the full (unsharded) transformer weights and config on rank 0."""
+    # Configure FSDP to save full state dict
+    FSDP.set_state_dict_type(
+        transformer,
+        state_dict_type=StateDictType.FULL_STATE_DICT,
+        state_dict_config=FullStateDictConfig(offload_to_cpu=True,
+                                              rank0_only=True),
+    )
+
+    # Now get the state dict
+    cpu_state = transformer.state_dict()
+
+    # Save it (only on rank 0 since we used rank0_only=True)
+    if rank <= 0:
+        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
+        os.makedirs(save_dir, exist_ok=True)
+        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.pt")
+        torch.save(cpu_state, weight_path)
+        # Copy the config so deleting keys does not mutate transformer.hf_config
+        config_dict = dict(transformer.hf_config)
+        if "dtype" in config_dict:
+            del config_dict["dtype"]  # TODO: dtype may not be JSON-serializable
+        config_path = os.path.join(save_dir, "config.json")
+        # save dict as json
+        with open(config_path, "w") as f:
+            json.dump(config_dict, f, indent=4)
+        logger.info("--> checkpoint saved at step %s to %s", step, weight_path)
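For context, a minimal sketch of how save_checkpoint might be driven from a training loop. The FSDP wrapping is assumed to have happened earlier; the maybe_save_checkpoint helper, the args.checkpointing_steps / args.output_dir names, and the barrier are illustrative assumptions, not part of this commit. Note that every rank must enter save_checkpoint, because gathering the FULL_STATE_DICT is a collective operation even though only rank 0 writes to disk.

    import torch.distributed as dist

    def maybe_save_checkpoint(transformer, step, args):
        # Hypothetical trainer hook: args.checkpointing_steps and args.output_dir
        # are assumed CLI options, not names defined in this diff.
        if step % args.checkpointing_steps == 0:
            rank = dist.get_rank() if dist.is_initialized() else 0
            save_checkpoint(transformer, rank, args.output_dir, step)
            if dist.is_initialized():
                # All ranks participated in state_dict() above; sync them
                # before resuming training.
                dist.barrier()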