From 88e2ec14ee4afa5a7282ffa1293e925a94f91703 Mon Sep 17 00:00:00 2001
From: conver334
Date: Wed, 11 Feb 2026 21:00:46 -0800
Subject: [PATCH] support mfsdp

Signed-off-by: conver334
---
 .../5.mfsdp_load_and_export_multiple_gpus.py  | 194 ++++++++++++++++++
 mbridge/core/bridge.py                        |  30 ++-
 mbridge/core/util.py                          | 127 +++++++++++-
 3 files changed, 339 insertions(+), 12 deletions(-)
 create mode 100644 example/5.mfsdp_load_and_export_multiple_gpus.py

diff --git a/example/5.mfsdp_load_and_export_multiple_gpus.py b/example/5.mfsdp_load_and_export_multiple_gpus.py
new file mode 100644
index 0000000..15b2599
--- /dev/null
+++ b/example/5.mfsdp_load_and_export_multiple_gpus.py
@@ -0,0 +1,194 @@
+# Example to load/export weights with Megatron FSDP (data parallel sharding).
+# Run: torchrun --nproc_per_node=8 5.mfsdp_load_and_export_multiple_gpus.py --model_path /path/to/model
+
+import argparse
+import os
+
+import torch
+from megatron.core import parallel_state as mpu
+from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from transformers import AutoTokenizer
+
+from mbridge import AutoBridge
+from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model
+
+
+def init_distributed(tp=1, pp=1, cp=1, vpp=1, ep=1, etp=None):
+    """Initialize the distributed environment."""
+    torch.distributed.init_process_group("nccl")
+    torch.cuda.set_device(torch.distributed.get_rank())
+    if pp <= 1:
+        vpp = None
+    mpu.initialize_model_parallel(
+        tensor_model_parallel_size=tp,
+        pipeline_model_parallel_size=pp,
+        virtual_pipeline_model_parallel_size=vpp,
+        context_parallel_size=cp,
+        expert_model_parallel_size=ep,
+        expert_tensor_parallel_size=etp,
+    )
+    model_parallel_cuda_manual_seed(0)
+
+
+def generate_sequence(
+    prompt, model, hf_model_path, max_new_tokens=100, trust_remote_code=False
+):
+    """Generate a text sequence with greedy decoding."""
+    try:
+        assert mpu.get_tensor_model_parallel_world_size() == 1
+        assert mpu.get_pipeline_model_parallel_world_size() == 1
+        assert mpu.get_context_parallel_world_size() == 1
+    except Exception as e:
+        print(e)
+        print("generation in this example only supports EP/DP parallelism, skipping")
+        return
+    tokenizer = AutoTokenizer.from_pretrained(
+        hf_model_path, trust_remote_code=trust_remote_code
+    )
+
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    input_ids = input_ids.cuda()
+    position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(
+        0
+    )
+    attention_mask = torch.ones_like(input_ids).to(input_ids.device)
+
+    generated_tokens = []
+    cur_input_ids = input_ids
+    cur_position_ids = position_ids
+    cur_attention_mask = attention_mask
+    from tqdm import trange
+
+    for _ in trange(max_new_tokens):
+        # Move inputs to GPU
+        cur_input_ids = cur_input_ids.cuda()
+        cur_position_ids = cur_position_ids.cuda()
+        cur_attention_mask = cur_attention_mask.cuda()
+
+        # Forward inference with the model
+        with torch.no_grad():
+            model[0].cuda()
+            output = model[0].module(
+                cur_input_ids, cur_position_ids, cur_attention_mask
+            )
+
+        # Get the next token
+        next_token = output.argmax(dim=-1)[:, -1]
+        generated_tokens.append(next_token.item())
+
+        # Stop if EOS token is generated
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+        # Update input sequence
+        cur_input_ids = torch.cat([cur_input_ids, next_token.unsqueeze(0)], dim=1)
+        cur_position_ids = torch.arange(
+            cur_input_ids.shape[1], device=cur_input_ids.device
+        ).unsqueeze(0)
+        cur_attention_mask = torch.ones_like(cur_input_ids)
+
+    # Decode the generated token sequence
+    generated_text = tokenizer.decode(generated_tokens)
+    if torch.distributed.get_rank() == 0:
+        print(f"Generated text:\n{generated_text}")
+
+    return generated_text
+
+
+def main():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="Load model and generate text")
+    parser.add_argument(
+        "--model_path", type=str, required=True, help="HuggingFace model path"
+    )
+    parser.add_argument("--tp", type=int, default=1, help="Tensor model parallel size")
+    parser.add_argument(
+        "--pp", type=int, default=1, help="Pipeline model parallel size"
+    )
+    parser.add_argument("--cp", type=int, default=1, help="Context parallel size")
+    parser.add_argument(
+        "--vpp", type=int, default=1, help="Virtual pipeline model parallel size"
+    )
+    parser.add_argument("--ep", type=int, default=1, help="Expert model parallel size")
+    parser.add_argument(
+        "--etp", type=int, default=None, help="Expert tensor parallel size"
+    )
+    parser.add_argument(
+        "--save_path", type=str, default=None, help="Path to save weights"
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=10,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--trust_remote_code", action="store_true", help="Trust remote code"
+    )
+    args = parser.parse_args()
+
+    # Initialize distributed environment
+    init_distributed(
+        tp=args.tp,
+        pp=args.pp,
+        cp=args.cp,
+        vpp=args.vpp,
+        ep=args.ep,
+        etp=args.etp,
+    )
+
+    # Load model
+    hf_model_path = args.model_path
+    print(f"rank{torch.distributed.get_rank()}: start loading model")
+    bridge = AutoBridge.from_pretrained(hf_model_path)
+    ddp_config = {
+        "use_distributed_optimizer": True,
+        "check_for_nan_in_grad": True,
+        "use_megatron_fsdp": True,
+        "data_parallel_sharding_strategy": "optim_grads_params",
+    }
+    model = bridge.get_model(
+        wrap_with_ddp=True,
+        use_megatron_fsdp=True,
+        ddp_config=ddp_config,
+        data_parallel_random_init=False,
+        post_model_creation_callbacks=[],
+    )
+    print(
+        f"rank{torch.distributed.get_rank()}: start loading weights from {hf_model_path}"
+    )
+    bridge.load_weights(model, hf_model_path, memory_efficient=True)
+
+    prompt = "A bubble sort in python is "
+    generate_sequence(
+        prompt, model, args.model_path, args.max_tokens, args.trust_remote_code
+    )
+
+    # Export weights and compare them against the HuggingFace checkpoint
+    keys = bridge.safetensor_io.load_hf_weight_names()
+    loaded_keys = set()
+    not_matched_keys = set()
+    for k, v in bridge.export_weights(model):
+        if torch.distributed.get_rank() != 0:
+            continue
+        gt = bridge.safetensor_io.load_one_hf_weight(k).to(
+            device=v.device, dtype=v.dtype
+        )
+        if k != "lm_head.weight":
+            assert v.shape == gt.shape, f"mismatch of {k} {v.shape=} {gt.shape=}"
+            if not torch.allclose(v.sum(), gt.sum(), atol=1e-5):
+                not_matched_keys.add(k)
+        else:
+            if v.shape[0] == 1:
+                print(f"this is a value model, {k} {v.shape=} {gt.shape=}")
+        loaded_keys.add(k)
+        print(k, "export ok")
+    if args.save_path:
+        bridge.save_weights(model, args.save_path, memory_efficient=False)
+
+    missing_keys = set(keys) - loaded_keys
+    missing_keys = sorted(list(missing_keys))
+    if torch.distributed.get_rank() == 0:
+        print(f"missing keys: {missing_keys}")
+        print(f"not_matched_keys: {not_matched_keys}")
+
+    # Wait for all ranks to finish saving
+    torch.distributed.barrier()
+    torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
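Note on the bridge.py changes that follow: when the model is wrapped by Megatron FSDP, each parameter is a DTensor and a rank holds only a flat slice of it, so load_weights copies just that slice of the converted full weight into the local shard instead of the whole tensor. Below is a minimal pure-PyTorch sketch of that flatten-and-slice copy pattern; the shapes and slice bounds are made up for illustration, and in the real code the slice comes from Megatron FSDP's sharding metadata (local_weights.megatron_fsdp_slice) while the destination is param._local_tensor.

import torch

# Full (unsharded) weight as produced by the HF -> mcore conversion step.
full_weight = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Hypothetical flat slice owned by this rank (illustrative values only).
local_slice = slice(4, 8)

# Local shard buffer of matching size, standing in for param._local_tensor.
local_shard = torch.empty(4)

# Copy only this rank's portion of the flattened full weight.
local_shard.copy_(full_weight.reshape(-1)[local_slice])
print(local_shard)  # tensor([4., 5., 6., 7.])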
diff --git a/mbridge/core/bridge.py b/mbridge/core/bridge.py
index 7a18420..45ff2df 100644
--- a/mbridge/core/bridge.py
+++ b/mbridge/core/bridge.py
@@ -7,11 +7,12 @@
 import torch
 
 from megatron.core import parallel_state as mpu
+from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from megatron.core.models.gpt.gpt_model import ModelType
 from transformers import AutoConfig
 from transformers.utils.hub import cached_file
 from safetensors import safe_open
-
+from torch.distributed._tensor import DTensor
 from .parallel_states import ParallelStates
 from .safetensor_io import SafeTensorIO
 from .util import (
@@ -19,6 +20,7 @@
     broadcast_str_from_megatron_pp,
     get_model,
     unwrap_model,
+    get_module_and_param_from_name,
 )
 
 
@@ -73,6 +75,7 @@ def get_model(
         fp16: bool = False,
         bf16: bool = True,
         encoder_pipeline_model_parallel_size: int = 0,
+        use_megatron_fsdp: bool = False,
         use_torch_fsdp2: bool = False,
         use_custom_fsdp: bool = False,
         use_precision_aware_optimizer: bool = False,
@@ -131,6 +134,7 @@
             bf16=bf16,
             virtual_pipeline_model_parallel_size=self.mpu.vpp_size,
             encoder_pipeline_model_parallel_size=encoder_pipeline_model_parallel_size,
+            use_megatron_fsdp=use_megatron_fsdp,
             use_torch_fsdp2=use_torch_fsdp2,
             use_custom_fsdp=use_custom_fsdp,
             use_precision_aware_optimizer=use_precision_aware_optimizer,
@@ -198,8 +202,10 @@ def load_weights(
         )
 
         # import mcore weights
+        use_megatron_fsdp = isinstance(model, FullyShardedDataParallel)
+        unwrapped_model = unwrap_model(model)
         for local_name, hf_names in local_to_hf_map.items():
-            param = model.state_dict()[local_name]
+            param = unwrapped_model.state_dict()[local_name]
             # hf format to mcore format
             if set(to_load_from_disk) & set(hf_names):
                 if not memory_efficient:
@@ -218,7 +224,7 @@
                 # skip lm_head.weight when the model is a value model
                 continue
 
-            param_to_load = torch.empty_like(param)
+            param_to_load = torch.empty(param.shape, device=param.device, dtype=param.dtype)
             if ".mlp.experts.linear_fc" in local_name:
                 # split mcore weights across etp
                 if self.mpu.etp_rank == 0:
@@ -258,7 +264,14 @@
                     group=self.mpu.tp_group,
                 )
             # load
+            if isinstance(param, DTensor):
+                _, local_weights = get_module_and_param_from_name(unwrapped_model, local_name)
+                sliced_converted_weights = param_to_load.reshape(-1)[local_weights.megatron_fsdp_slice]
+                param._local_tensor.reshape(-1).copy_(sliced_converted_weights)
+                continue
             param.copy_(param_to_load)
+        if use_megatron_fsdp:
+            model.module.install_optimized_model_weights()
 
     def _save_weights_fast(
         self,
@@ -527,7 +540,16 @@ def get_model_chunk_generator():
                     name, param = None, None
                 name = broadcast_str_from_megatron_pp(name)
-                broad_pp_param = broadcast_from_megatron_pp(param)
+                broad_pp_param = None
+                if isinstance(param, DTensor):
+                    from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
+                        gather_uneven_dtensor_to_full_tensor,
+                    )
+                    _, local_weights = get_module_and_param_from_name(models, iter_name, iter_vpp_rank)
+                    full_tensor = gather_uneven_dtensor_to_full_tensor(local_weights)
+                    broad_pp_param = full_tensor.to_local()
+                else:
+                    broad_pp_param = broadcast_from_megatron_pp(param)
 
                 # EP
                 if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
diff --git a/mbridge/core/util.py b/mbridge/core/util.py
index 6f6fa61..188f9d2 100644
--- a/mbridge/core/util.py
+++ b/mbridge/core/util.py
@@ -7,13 +7,14 @@
 from functools import lru_cache
 
 import torch
+from typing import List, Optional, Tuple
 from megatron.core import mpu
 from megatron.core import parallel_state as mpu
 from megatron.core import tensor_parallel
 from megatron.core.fp8_utils import correct_amax_history_if_needed
 from megatron.core.models.gpt.gpt_model import ModelType
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.transformer.module import Float16Module
+from megatron.core.transformer.module import Float16Module, MegatronModule
 from megatron.core.utils import (
     StragglerDetector,
     check_param_hashes_across_dp_replicas,
@@ -30,6 +31,7 @@ def get_model(
     bf16: bool = True,
     virtual_pipeline_model_parallel_size: int = None,
     encoder_pipeline_model_parallel_size: int = 0,
+    use_megatron_fsdp: bool = False,
     use_torch_fsdp2: bool = False,
     use_custom_fsdp: bool = False,
     use_precision_aware_optimizer: bool = False,
@@ -164,9 +166,12 @@ def build_model():
     correct_amax_history_if_needed(model)
 
     if wrap_with_ddp:
-        from megatron.core.distributed import DistributedDataParallelConfig
-
-        if use_torch_fsdp2:
+        from megatron.core.distributed import DistributedDataParallelConfig, FullyShardedDataParallel
+        if use_megatron_fsdp:
+            DP = FullyShardedDataParallel
+            if use_torch_fsdp2:
+                raise ValueError("Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported.")
+        elif use_torch_fsdp2:
             try:
                 from megatron.core.distributed import (
                     TorchFullyShardedDataParallel as torch_FSDP,
@@ -193,7 +198,7 @@
         # default
         kwargs = {"grad_reduce_in_fp32": True, "use_distributed_optimizer": True}
         if ddp_config is not None:
-            kwargs.update(ddp_config)
+            kwargs.update(ddp_config)  # note: ddp_config must be a plain dict here; update() would fail on a DistributedDataParallelConfig instance
 
         if optimizer_config is not None:
             import warnings
@@ -205,9 +210,9 @@
         if use_custom_fsdp and use_precision_aware_optimizer:
             kwargs["preserve_fp32_weights"] = False
 
-        ddp_config = DistributedDataParallelConfig(**kwargs)
+        ddp_config = DistributedDataParallelConfig(**kwargs)  # the config object is only instantiated here, from the merged kwargs
 
-        if not use_torch_fsdp2:
+        if not use_torch_fsdp2 and not use_megatron_fsdp:
             # In the custom FSDP and DDP use path, we need to initialize the bucket size.
             # If bucket_size is not provided as an input, use sane default.
@@ -273,7 +278,11 @@ def build_model():
     ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module)
 except ImportError:
     ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module)
-
+try:
+    from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP
+    ALL_MODULE_WRAPPER_CLASSNAMES = ALL_MODULE_WRAPPER_CLASSNAMES + (MegatronFSDP,)
+except ImportError:
+    pass
 
 def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
     return_list = True
@@ -782,3 +791,105 @@ def postprocess_packed_seqs(
         output_new[i, attention_mask[i]] = tmp[:s_len]
 
     return output_new
+
+
+def get_module_and_param_from_name(
+    models: MegatronModule | List[MegatronModule],
+    param_name: str,
+    vp_stage: Optional[int] = None,
+) -> Tuple[torch.nn.Module, torch.Tensor]:
+    """
+    Get a parameter from a specific VP stage, ensuring that parameter
+    attributes are preserved. Supports both absolute and relative parameter names.
+
+    Args:
+        models: List of Megatron model instances or a submodule
+        param_name: Dot-separated parameter name (can be absolute or relative to ``models``)
+        vp_stage: Virtual pipeline stage index (None for single stage)
+
+    Returns:
+        Tuple of (module, parameter) where module owns the parameter
+
+    Raises:
+        ValueError: If vp_stage is out of range or the parameter does not exist
+
+    Examples:
+        Basic usage with full model:
+        >>> module, param = get_module_and_param_from_name(
+        ...     models=full_model,
+        ...     param_name="transformer.layers.0.attention.query.weight"
+        ... )
+
+        Usage with model list and VP stage:
+        >>> module, param = get_module_and_param_from_name(
+        ...     models=[model1, model2, model3],
+        ...     param_name="layers.0.mlp.dense.bias",
+        ...     vp_stage=1
+        ... )
+
+        Usage with submodule and relative path:
+        >>> linear_module = model.transformer.layers[0].mlp.dense
+        >>> module, param = get_module_and_param_from_name(
+        ...     models=linear_module,
+        ...     param_name="weight"
+        ... )
+
+        Usage with submodule and absolute path (automatic suffix matching):
+        >>> linear_module = model.transformer.layers[0].mlp.dense
+        >>> module, param = get_module_and_param_from_name(
+        ...     models=linear_module,
+        ...     param_name="transformer.layers.0.mlp.dense.weight"
+        ... )
+        # Automatically matches the "weight" suffix and returns the parameter
+
+        Edge case with partial path matching:
+        >>> attention_module = model.transformer.layers[0].attention
+        >>> module, param = get_module_and_param_from_name(
+        ...     models=attention_module,
+        ...     param_name="layers.0.attention.query.weight"
+        ... )
+        # Matches the "query.weight" suffix within the attention module
+    """
+    if isinstance(models, list):
+        if vp_stage is None:
+            model = models[0]
+        else:
+            if vp_stage >= len(models):
+                raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})")
+            model = models[vp_stage]
+    else:
+        model = models
+
+    module = unwrap_model(model)
+    splitted_name = param_name.split(".")
+
+    # Try to resolve the parameter using the given name parts
+    def try_get_param(parts):
+        param = module
+        temp_module = module
+
+        for i, part in enumerate(parts):
+            if not hasattr(param, part):
+                return None
+            param = getattr(param, part)
+            if i < len(parts) - 1:
+                temp_module = getattr(temp_module, part)
+
+        return temp_module, param
+
+    # First try the full parameter name (current behavior)
+    result = try_get_param(splitted_name)
+    if result is not None:
+        return result
+
+    # If the full name does not resolve, try suffixes of the parameter name.
+    # This handles cases where `models` is a submodule but `param_name` is absolute.
+    for start_idx in range(1, len(splitted_name)):
+        suffix_parts = splitted_name[start_idx:]
+        result = try_get_param(suffix_parts)
+        if result is not None:
+            return result
+
+    # If no approach works, raise an error
+    raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}")
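
Note on get_module_and_param_from_name added above: besides the exact lookup, it falls back to matching progressively shorter suffixes of the dotted name, so an absolute parameter name still resolves when only a submodule is passed in. A minimal sketch of both behaviors on a toy module follows; it assumes mbridge and its Megatron dependency are importable, and the Net class and the "decoder." prefix are illustrative only, not names from this repository.

import torch

from mbridge.core.util import get_module_and_param_from_name


class Net(torch.nn.Module):
    # Toy stand-in for a model with a "layers.0.weight" parameter.
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4)])


net = Net()

# A relative name resolves directly to the owning module and its parameter.
module, param = get_module_and_param_from_name(net, "layers.0.weight")
assert module is net.layers[0] and param is net.layers[0].weight

# An absolute name whose leading component does not exist on `net` falls back
# to suffix matching: prefixes are stripped until "layers.0.weight" resolves.
module, param = get_module_and_param_from_name(net, "decoder.layers.0.weight")
assert param is net.layers[0].weight
print(type(module).__name__, tuple(param.shape))  # Linear (4, 4)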