diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py
index ad43608246..2c61febda6 100644
--- a/slime/backends/megatron_utils/actor.py
+++ b/slime/backends/megatron_utils/actor.py
@@ -199,12 +199,13 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
         ]
         if "multimodal_train_inputs" in rollout_data:
             # Move multimodal training tensors to GPU in advance
+            def _to_device(tensor):
+                if torch.is_tensor(tensor):
+                    return tensor.to(device=torch.cuda.current_device())
+                return torch.as_tensor(tensor, device=torch.cuda.current_device())
+
             rollout_data["multimodal_train_inputs"] = [
-                (
-                    {key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()}
-                    if mm_dict is not None
-                    else None
-                )
+                ({key: _to_device(tensor) for key, tensor in mm_dict.items()} if mm_dict is not None else None)
                 for mm_dict in rollout_data["multimodal_train_inputs"]
             ]
 
diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index 64ccbef499..38c463cc8b 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -27,6 +27,7 @@ def get_batch(
     keys: Sequence[str],
     pad_multiplier: int = 128,
     qkv_format: str = "thd",
+    seq_length: int | None = None,
 ) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]:
     """
     Generate a CP-ready micro-batch with packed sequence parameters.
@@ -79,13 +80,19 @@ def get_batch(
 
     tokens = torch.cat(tokens)
 
-    # Always pad to reduce memory fragmentation and maybe make the computation faster
-    pad = (pad_size - tokens.size(0) % pad_size) % pad_size
+    if seq_length is not None:
+        assert tokens.size(0) <= seq_length, f"packed tokens length {tokens.size(0)} > seq_length {seq_length}"
+        assert (
+            seq_length % pad_size == 0
+        ), f"seq_length {seq_length} must be divisible by pad_size {pad_size} for THD padding"
+        pad = seq_length - tokens.size(0)
+    else:
+        pad = (pad_size - tokens.size(0) % pad_size) % pad_size
+
     if pad != 0:
         tokens = F.pad(tokens, (0, pad), value=pad_token_id)
         cu_seqlens.append(cu_seqlens[-1] + pad)
 
-    # thd requires the cu_seqlens to be of the origin length
     cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size
     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
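Note on the get_batch change above: when seq_length is given, every packed micro-batch
is padded to one fixed length instead of just the next multiple of pad_size, so all
micro-batches share a single tensor shape (model.py below only passes it when pipeline
parallelism is enabled). A minimal standalone sketch of the padding rule, using a
hypothetical pad_packed_tokens helper that is not part of this patch:

    import torch
    import torch.nn.functional as F

    def pad_packed_tokens(tokens, cu_seqlens, pad_size, pad_token_id, seq_length=None):
        # Fixed-length padding when seq_length is given, otherwise round up to
        # the next multiple of pad_size.
        if seq_length is not None:
            assert tokens.size(0) <= seq_length and seq_length % pad_size == 0
            pad = seq_length - tokens.size(0)
        else:
            pad = (pad_size - tokens.size(0) % pad_size) % pad_size
        if pad != 0:
            tokens = F.pad(tokens, (0, pad), value=pad_token_id)
            cu_seqlens = cu_seqlens + [cu_seqlens[-1] + pad]
        return tokens, cu_seqlens

    tokens, cu = pad_packed_tokens(torch.arange(10), [0, 4, 10], pad_size=8, pad_token_id=0, seq_length=16)
    assert tokens.numel() == 16 and cu == [0, 4, 10, 16]
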
diff --git a/slime/backends/megatron_utils/megatron_to_hf/__init__.py b/slime/backends/megatron_utils/megatron_to_hf/__init__.py
index 28af98ca48..a224cc5ea9 100644
--- a/slime/backends/megatron_utils/megatron_to_hf/__init__.py
+++ b/slime/backends/megatron_utils/megatron_to_hf/__init__.py
@@ -1,6 +1,7 @@
 from .deepseekv3 import convert_deepseekv3_to_hf
 from .glm4 import convert_glm4_to_hf
 from .glm4moe import convert_glm4moe_to_hf
+from .kimi_vl import convert_kimi_k25_to_hf, convert_kimivl_to_hf
 from .llama import convert_llama_to_hf
 from .mimo import convert_mimo_to_hf
 from .processors import quantize_params, remove_padding
@@ -50,6 +51,10 @@ def _convert_to_hf_core(args, model_name, name, param):
         converted_named_tensors = convert_llama_to_hf(args, name, param)
     elif "mimo" in model_name:
         converted_named_tensors = convert_mimo_to_hf(args, name, param)
+    elif "kimivl" in model_name:
+        converted_named_tensors = convert_kimivl_to_hf(args, name, param)
+    elif "kimi_k25" in model_name:
+        converted_named_tensors = convert_kimi_k25_to_hf(args, name, param)
     else:
         raise ValueError(f"Unsupported model: {model_name}")
 
@@ -80,4 +85,4 @@ def _convert_to_hf_core(args, model_name, name, param):
             _cached_tensors[converted_name] = converted_param
         else:
             converted_named_tensors.append((converted_name, converted_param))
-    return converted_named_tensors
+    return converted_named_tensors
\ No newline at end of file
diff --git a/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py b/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py
new file mode 100644
index 0000000000..81329d45f6
--- /dev/null
+++ b/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py
@@ -0,0 +1,145 @@
+import re
+
+import torch
+
+
+def convert_kimivl_to_hf(args, name, param):
+    if name.startswith("module.module.vision_model."):
+        hf_name = "vision_tower." + name[len("module.module.vision_model.") :]
+        return [(hf_name, param)]
+
+    if name.startswith("module.module.multi_modal_projector."):
+        hf_name = "multi_modal_projector." + name[len("module.module.multi_modal_projector.") :]
+        return [(hf_name, param)]
+
+    return convert_language_model_to_hf(args, name, param)
+
+
+def convert_kimi_k25_to_hf(args, name, param):
+    if name.startswith("module.module.vision_tower."):
+        hf_name = "vision_tower." + name[len("module.module.vision_tower.") :]
+        return [(hf_name, param)]
+
+    if name.startswith("module.module.mm_projector."):
+        hf_name = "mm_projector." + name[len("module.module.mm_projector.") :]
+        return [(hf_name, param)]
+
+    return convert_language_model_to_hf(args, name, param)
+
+
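+# The decoder mapping below translates Megatron-Core parameter names into HF
+# names. An illustrative pair, taken from the router branch further down:
+#     module.module.language_model.decoder.layers.0.mlp.router.weight
+#         -> language_model.model.layers.0.mlp.gate.weight
+# Most tensors pass through unchanged; fused ones (the linear_fc1 gate/up
+# projection and the packed QKV bias) are split into separate HF tensors
+# before renaming.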
ValueError(f"Unknown shared expert parameter name: {name}") + + if rest == "self_attention.linear_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_q_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_proj.weight", param)] + elif rest == "self_attention.linear_q_down_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)] + elif rest == "self_attention.linear_q_up_proj.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)] + elif rest == "self_attention.linear_q_up_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)] + elif rest == "self_attention.linear_qkv.bias": + param = param.view(args.num_query_groups, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + return [ + (f"language_model.model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), + (f"language_model.model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), + (f"language_model.model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), + ] + elif rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"language_model.model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"language_model.model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + elif rest == "mlp.linear_fc2.weight": + return [(f"language_model.model.layers.{layer_idx}.mlp.down_proj.weight", param)] + elif rest == "self_attention.linear_qkv.layer_norm_weight" or rest == "input_layernorm.weight": + return [(f"language_model.model.layers.{layer_idx}.input_layernorm.weight", param)] + elif rest == "mlp.linear_fc1.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_down_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param)] + elif rest == "self_attention.linear_kv_up_proj.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_up_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "mlp.router.weight": + return [(f"language_model.model.layers.{layer_idx}.mlp.gate.weight", param)] + elif rest == "mlp.router.expert_bias": + return [(f"language_model.model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + + raise ValueError(f"Unknown parameter name: {name}") \ No newline at end of file diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index fc497046eb..56b782012a 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -205,6 +205,7 @@ def forward_step( assert not return_schedule_plan, "forward_only step should never return schedule plan" # Get the batch. 
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index fc497046eb..56b782012a 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -205,6 +205,7 @@ def forward_step(
     assert not return_schedule_plan, "forward_only step should never return schedule plan"
 
     # Get the batch.
+    pad_to_seq_length = args.seq_length if (mpu.get_pipeline_model_parallel_world_size() > 1) else None
     batch = get_batch(
         data_iterator,
         [
@@ -217,6 +218,7 @@ def forward_step(
         ],
         args.data_pad_size_multiplier,
         args.qkv_format,
+        pad_to_seq_length,
     )
     unconcat_tokens = batch["unconcat_tokens"]
     tokens = batch["tokens"]
@@ -354,6 +356,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
     """
 
     # Get the batch.
+    pad_to_seq_length = args.seq_length if (mpu.get_pipeline_model_parallel_world_size() > 1) else None
     batch = get_batch(
         data_iterator,
         [
@@ -374,6 +377,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
         ],
         args.data_pad_size_multiplier,
         args.qkv_format,
+        pad_to_seq_length,
     )
 
     if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
diff --git a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
index 7c6ac64010..e5effc846c 100644
--- a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
+++ b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
@@ -27,18 +27,25 @@ def get_hf_weight_chunks(self, megatron_local_weights):
 
         named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks)
 
-        named_weights = (
-            (
-                hf_param_name,
-                postprocess_hf_param(
-                    args=self.args,
-                    megatron_param_name=megatron_param_name,
-                    hf_param_name=hf_param_name,
-                    param=weight,
-                ),
-            )
-            for hf_param_name, weight, megatron_param_name in named_weights
-        )
+        base_named_weights = named_weights
+
+        def _iter_named_weights():
+            for item in base_named_weights:
+                if len(item) == 3:
+                    hf_param_name, weight, megatron_param_name = item
+                    weight = postprocess_hf_param(
+                        args=self.args,
+                        megatron_param_name=megatron_param_name,
+                        hf_param_name=hf_param_name,
+                        param=weight,
+                    )
+                elif len(item) == 2:
+                    hf_param_name, weight = item
+                else:
+                    raise ValueError(f"Unexpected weight tuple size: {len(item)} for {item}")
+                yield (hf_param_name, weight)
+
+        named_weights = _iter_named_weights()
 
         yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size)
diff --git a/slime/utils/reloadable_process_group.py b/slime/utils/reloadable_process_group.py
index b4d325cf0b..7eb8dd251d 100644
--- a/slime/utils/reloadable_process_group.py
+++ b/slime/utils/reloadable_process_group.py
@@ -70,7 +70,8 @@ def new_function(*args, **kwargs):
     dist.all_to_all = get_new_function(dist.all_to_all)
     dist.all_to_all_single = get_new_function(dist.all_to_all_single)
     dist.broadcast = get_new_function(dist.broadcast)
-    dist.reduce = get_new_function(dist.reduce)
+
+    dist.broadcast_object_list = get_new_function(dist.broadcast_object_list)
     dist.reduce_scatter = get_new_function(dist.reduce_scatter)
     dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor)
     dist.scatter = get_new_function(dist.scatter)
diff --git a/tools/convert_hf_to_int4_direct.py b/tools/convert_hf_to_int4_direct.py
index e741b802d3..3e52e1e5dc 100644
--- a/tools/convert_hf_to_int4_direct.py
+++ b/tools/convert_hf_to_int4_direct.py
@@ -292,6 +292,8 @@ def parse_args():
         "re:.*norm.*",
         "re:.*embed.*",
         "re:.*self_attn.*",
+        "re:.*vision_tower.*",
+        "re:.*multi_modal_projector.*",
         "re:.*shared_experts.*",
         "re:.*mlp\\.(gate|up|gate_up|down)_proj.*",
         "re:.*mlp\\.gate\\.*",
diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py
index 8d2758947e..b45af58933 100644
--- a/tools/convert_hf_to_torch_dist.py
+++ b/tools/convert_hf_to_torch_dist.py
@@ -17,7 +17,6 @@
 from slime.utils.logging_utils import configure_logger
 from slime.utils.memory_utils import print_memory
 
-
 def add_convertion_args(parser):
     """Add conversion arguments to the parser"""
     parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path")
@@ -33,7 +32,6 @@ def add_convertion_args(parser):
         pass
     return parser
 
-
 def get_args():
     args = parse_args(add_convertion_args)
     args = set_default_megatron_args(args)
@@ -76,7 +74,6 @@ def ceildiv(a, b):
     validate_args(args)
     return args
 
-
 def main():
     if torch.version.hip:
         import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module
@@ -113,10 +110,16 @@ def main():
 
     model = get_model(get_model_provider_func(args), ModelType.encoder_or_decoder, wrap_with_ddp=False)
 
-    # Load model
     hf_model_path = args.hf_checkpoint
-    bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True)
-    bridge.load_weights(model, hf_model_path, memory_efficient=True)
+
+    # Load model
+    if args.megatron_to_hf_mode == "bridge":
+        from megatron.bridge import AutoBridge
+        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        bridge.load_hf_weights(model, hf_model_path)
+    else:
+        bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True)
+        bridge.load_weights(model, hf_model_path, memory_efficient=True)
     print(f"Model loaded: {hf_model_path}")
 
     if args.use_cpu_initialization:
@@ -140,6 +143,5 @@ def main():
     dist.barrier()
     dist.destroy_process_group()
 
-
 if __name__ == "__main__":
     main()
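The hf_weight_iterator_bridge.py change above makes the export path tolerate bridges
that yield either (hf_name, weight, megatron_name) triples or plain (hf_name, weight)
pairs; only triples carry the Megatron name needed for postprocessing. The same
normalization in isolation, with hypothetical data and an identity postprocess
standing in for postprocess_hf_param:

    def normalize(named_weights, postprocess=lambda megatron_name, weight: weight):
        # Accept 2-tuples or 3-tuples; postprocess only when the Megatron name
        # is available, and always yield uniform (hf_name, weight) pairs.
        for item in named_weights:
            if len(item) == 3:
                hf_name, weight, megatron_name = item
                weight = postprocess(megatron_name, weight)
            elif len(item) == 2:
                hf_name, weight = item
            else:
                raise ValueError(f"Unexpected weight tuple size: {len(item)}")
            yield hf_name, weight

    assert list(normalize([("a", 1), ("b", 2, "decoder.b")])) == [("a", 1), ("b", 2)]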