194 changes: 194 additions & 0 deletions example/5.mfsdp_load_and_export_multiple_gpus.py
@@ -0,0 +1,194 @@
# Example to load/export weights with Megatron FSDP (data parallel sharding).
# Run: torchrun --nproc_per_node=8 5.mfsdp_load_and_export_multiple_gpus.py --model_path /path/to/model
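# The checkpoint is expected to be a HuggingFace directory with safetensors weights;
# the export step below re-reads it to verify the load/export round trip.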

import argparse

import torch
from megatron.core import parallel_state as mpu
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoTokenizer

from mbridge import AutoBridge


def init_distributed(tp=1, pp=1, cp=1, vpp=1, ep=1, etp=None):
"""Initialize distributed environment"""
torch.distributed.init_process_group("nccl")
torch.cuda.set_device(torch.distributed.get_rank())
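    # Virtual pipeline parallelism is only meaningful when pipeline parallelism is enabled.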
if pp <= 1:
vpp = None
mpu.initialize_model_parallel(
tensor_model_parallel_size=tp,
pipeline_model_parallel_size=pp,
virtual_pipeline_model_parallel_size=vpp,
context_parallel_size=cp,
expert_model_parallel_size=ep,
expert_tensor_parallel_size=etp,
)
model_parallel_cuda_manual_seed(0)


def generate_sequence(
    prompt, model, hf_model_path, max_new_tokens=100, trust_remote_code=False
):
    """Generate a text sequence with greedy decoding."""
    try:
        assert mpu.get_tensor_model_parallel_world_size() == 1
        assert mpu.get_pipeline_model_parallel_world_size() == 1
        assert mpu.get_context_parallel_world_size() == 1
    except AssertionError as e:
        print(e)
        print("only EP is supported in this example's generation step, skipping")
        return
tokenizer = AutoTokenizer.from_pretrained(
hf_model_path, trust_remote_code=trust_remote_code
)

input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids = input_ids.cuda()
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(
0
)
attention_mask = torch.ones_like(input_ids).to(input_ids.device)

generated_tokens = []
cur_input_ids = input_ids
cur_position_ids = position_ids
cur_attention_mask = attention_mask
from tqdm import trange

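    # Greedy decoding: take the argmax token at every step and append it to the input.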
for _ in trange(max_new_tokens):
# Move inputs to GPU
cur_input_ids = cur_input_ids.cuda()
cur_position_ids = cur_position_ids.cuda()
cur_attention_mask = cur_attention_mask.cuda()

# Forward inference with the model
with torch.no_grad():
model[0].cuda()
output = model[0].module(
cur_input_ids, cur_position_ids, cur_attention_mask
)

# Get the next token
next_token = output.argmax(dim=-1)[:, -1]
generated_tokens.append(next_token.item())

# Stop if EOS token is generated
if next_token.item() == tokenizer.eos_token_id:
break

# Update input sequence
cur_input_ids = torch.cat([cur_input_ids, next_token.unsqueeze(0)], dim=1)
cur_position_ids = torch.arange(
cur_input_ids.shape[1], device=cur_input_ids.device
).unsqueeze(0)
cur_attention_mask = torch.ones_like(cur_input_ids)

# Decode the generated token sequence
generated_text = tokenizer.decode(generated_tokens)
if torch.distributed.get_rank() == 0:
print(f"Generated text:\n{generated_text}")

return generated_text


def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description="Load model and generate text")
parser.add_argument(
"--model_path", type=str, required=True, help="HuggingFace model path"
)
parser.add_argument("--tp", type=int, default=1, help="Tensor model parallel size")
parser.add_argument(
"--pp", type=int, default=1, help="Pipeline model parallel size"
)
parser.add_argument("--cp", type=int, default=1, help="Context parallel size")
parser.add_argument(
"--vpp", type=int, default=1, help="Virtual pipeline model parallel size"
)
parser.add_argument("--ep", type=int, default=1, help="Expert model parallel size")
parser.add_argument(
"--etp", type=int, default=None, help="Expert tensor parallel size"
)
parser.add_argument(
"--save_path", type=str, default=None, help="Path to save weights"
)
parser.add_argument(
"--max_tokens",
type=int,
default=10,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--trust_remote_code", action="store_true", help="Trust remote code"
)
args = parser.parse_args()

# Initialize distributed environment
init_distributed(
tp=args.tp,
pp=args.pp,
cp=args.cp,
vpp=args.vpp,
ep=args.ep,
etp=args.etp,
)

# Load model
hf_model_path = args.model_path
print(f"rank{torch.distributed.get_rank()}: start loading model")
bridge = AutoBridge.from_pretrained(hf_model_path)
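    # Enable Megatron FSDP through the DDP config: optimizer state, gradients, and
    # parameters are all sharded across the data-parallel group.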
ddp_config = {
"use_distributed_optimizer": True,
"check_for_nan_in_grad": True,
"use_megatron_fsdp": True,
"data_parallel_sharding_strategy": "optim_grads_params",
}
    model = bridge.get_model(
        wrap_with_ddp=True,
        use_megatron_fsdp=True,
        ddp_config=ddp_config,
        data_parallel_random_init=False,
        post_model_creation_callbacks=[],
    )
print(
f"rank{torch.distributed.get_rank()}: start loading weights from {hf_model_path}"
)
bridge.load_weights(model, hf_model_path, memory_efficient=True)

prompt = "A bubble sort in python is "
generate_sequence(
prompt, model, args.model_path, args.max_tokens, args.trust_remote_code
)

# export weights
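    # Sanity-check each exported tensor against the original HuggingFace checkpoint
    # on disk; only rank 0 runs the comparison.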
keys = bridge.safetensor_io.load_hf_weight_names()
loaded_keys = set()
not_matched_keys = set()
for k, v in bridge.export_weights(model):
if torch.distributed.get_rank() != 0:
continue
        gt = bridge.safetensor_io.load_one_hf_weight(k).to(
            device=v.device, dtype=v.dtype
        )
if k != "lm_head.weight":
assert v.shape == gt.shape, f"mismatch of {k} {v.shape=} {gt.shape=}"
if not torch.allclose(v.sum(), gt.sum(), atol=1e-5):
not_matched_keys.add(k)
else:
if v.shape[0] == 1:
print(f"this is a value model, {k} {v.shape=} {gt.shape=}")
loaded_keys.add(k)
print(k, "export ok")
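    # Optionally save the model weights back to disk as well.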
if args.save_path:
bridge.save_weights(model, args.save_path, memory_efficient=False)

missing_keys = set(keys) - loaded_keys
    missing_keys = sorted(missing_keys)
if torch.distributed.get_rank() == 0:
print(f"missing keys: {missing_keys}")
print(f"not_matched_keys: {not_matched_keys}")

    # Wait for all ranks to finish saving before tearing down the process group.
torch.distributed.barrier()
torch.distributed.destroy_process_group()


if __name__ == "__main__":
main()
30 changes: 26 additions & 4 deletions mbridge/core/bridge.py
@@ -7,18 +7,20 @@

import torch
from megatron.core import parallel_state as mpu
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
from megatron.core.models.gpt.gpt_model import ModelType
from transformers import AutoConfig
from transformers.utils.hub import cached_file
from safetensors import safe_open

from torch.distributed._tensor import DTensor
from .parallel_states import ParallelStates
from .safetensor_io import SafeTensorIO
from .util import (
broadcast_from_megatron_pp,
broadcast_str_from_megatron_pp,
get_model,
unwrap_model,
get_module_and_param_from_name,
)


@@ -73,6 +75,7 @@ def get_model(
fp16: bool = False,
bf16: bool = True,
encoder_pipeline_model_parallel_size: int = 0,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
use_custom_fsdp: bool = False,
use_precision_aware_optimizer: bool = False,
@@ -131,6 +134,7 @@ def get_model(
bf16=bf16,
virtual_pipeline_model_parallel_size=self.mpu.vpp_size,
encoder_pipeline_model_parallel_size=encoder_pipeline_model_parallel_size,
use_megatron_fsdp=use_megatron_fsdp,
use_torch_fsdp2=use_torch_fsdp2,
use_custom_fsdp=use_custom_fsdp,
use_precision_aware_optimizer=use_precision_aware_optimizer,
@@ -198,8 +202,10 @@ def load_weights(
)

# import mcore weights
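        # With Megatron FSDP the model is wrapped in FullyShardedDataParallel and its
        # parameters are DTensor shards, so work on the unwrapped module's state_dict.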
use_megatron_fsdp = isinstance(model, FullyShardedDataParallel)
unwrapped_model = unwrap_model(model)
for local_name, hf_names in local_to_hf_map.items():
param = model.state_dict()[local_name]
param = unwrapped_model.state_dict()[local_name]
# hf format to mcore format
if set(to_load_from_disk) & set(hf_names):
if not memory_efficient:
@@ -218,7 +224,7 @@
# skip lm_head.weight when the model is a value model
continue

param_to_load = torch.empty_like(param)
param_to_load = torch.empty(param.shape, device=param.device, dtype=param.dtype)
if ".mlp.experts.linear_fc" in local_name:
# split mcore weights across etp
if self.mpu.etp_rank == 0:
@@ -258,7 +264,14 @@
group=self.mpu.tp_group,
)
# load
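            # Megatron FSDP parameters are flattened DTensor shards: copy only this
            # rank's megatron_fsdp_slice into the local tensor.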
if isinstance(param, DTensor):
_, local_weights = get_module_and_param_from_name(unwrapped_model, local_name)
sliced_converted_weights = param_to_load.reshape(-1)[local_weights.megatron_fsdp_slice]
param._local_tensor.reshape(-1).copy_(sliced_converted_weights)
continue
param.copy_(param_to_load)
if use_megatron_fsdp:
model.module.install_optimized_model_weights()

def _save_weights_fast(
self,
@@ -527,7 +540,16 @@ def get_model_chunk_generator():
name, param = None, None

name = broadcast_str_from_megatron_pp(name)
broad_pp_param = broadcast_from_megatron_pp(param)
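            # Megatron FSDP parameters are uneven DTensor shards; gather the full
            # tensor locally instead of broadcasting it across pipeline ranks.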
broad_pp_param = None
if isinstance(param, DTensor):
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
gather_uneven_dtensor_to_full_tensor,
)
_, local_weights = get_module_and_param_from_name(models, iter_name, iter_vpp_rank)
full_tensor = gather_uneven_dtensor_to_full_tensor(local_weights)
broad_pp_param = full_tensor.to_local()
else:
broad_pp_param = broadcast_from_megatron_pp(param)

# EP
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1: