11 changes: 6 additions & 5 deletions slime/backends/megatron_utils/actor.py
@@ -199,12 +199,13 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
        ]
        if "multimodal_train_inputs" in rollout_data:
            # Move multimodal training tensors to GPU in advance
            def _to_device(tensor):
                if torch.is_tensor(tensor):
                    return tensor.to(device=torch.cuda.current_device())
                return torch.as_tensor(tensor, device=torch.cuda.current_device())

            rollout_data["multimodal_train_inputs"] = [
                (
                    {key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()}
                    if mm_dict is not None
                    else None
                )
                ({key: _to_device(tensor) for key, tensor in mm_dict.items()} if mm_dict is not None else None)
                for mm_dict in rollout_data["multimodal_train_inputs"]
            ]

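The new `_to_device` helper covers entries that are not already tensors (e.g. lists of grid sizes), which the old dict comprehension would have crashed on with `AttributeError`. A minimal sketch of the behavior; the device argument is made explicit here so the snippet also runs without a GPU, whereas the actual code always targets `torch.cuda.current_device()`:

```python
import torch


def _to_device(tensor, device="cpu"):
    # Tensors are moved; anything else (lists, scalars) is first
    # converted with torch.as_tensor so it also lands on the device.
    if torch.is_tensor(tensor):
        return tensor.to(device=device)
    return torch.as_tensor(tensor, device=device)


# A multimodal dict mixing a tensor with a plain nested list.
mm_dict = {"pixel_values": torch.randn(2, 3), "image_grid_thw": [[1, 14, 14]]}
moved = {key: _to_device(value) for key, value in mm_dict.items()}
assert all(torch.is_tensor(value) for value in moved.values())
```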
13 changes: 10 additions & 3 deletions slime/backends/megatron_utils/data.py
@@ -27,6 +27,7 @@ def get_batch(
    keys: Sequence[str],
    pad_multiplier: int = 128,
    qkv_format: str = "thd",
    seq_length: int | None = None,
) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]:
    """
    Generate a CP-ready micro-batch with packed sequence parameters.
@@ -79,13 +80,19 @@

    tokens = torch.cat(tokens)

    # Always pad to reduce memory fragmentation and potentially speed up computation
    pad = (pad_size - tokens.size(0) % pad_size) % pad_size
    if seq_length is not None:
        assert tokens.size(0) <= seq_length, f"packed tokens length {tokens.size(0)} > seq_length {seq_length}"
        assert (
            seq_length % pad_size == 0
        ), f"seq_length {seq_length} must be divisible by pad_size {pad_size} for THD padding"
        pad = seq_length - tokens.size(0)
    else:
        pad = (pad_size - tokens.size(0) % pad_size) % pad_size

    if pad != 0:
        tokens = F.pad(tokens, (0, pad), value=pad_token_id)
        cu_seqlens.append(cu_seqlens[-1] + pad)

    # thd requires cu_seqlens to be of the original length
    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

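When `seq_length` is given, every micro-batch is padded to that one fixed length rather than to the next multiple of `pad_size`, and the two asserts check that the target is reachable and aligned. A small self-checking sketch of the branch logic (the numbers are illustrative):

```python
def compute_pad(packed_len: int, pad_size: int, seq_length: int | None) -> int:
    # Same branch structure as get_batch above.
    if seq_length is not None:
        assert packed_len <= seq_length
        assert seq_length % pad_size == 0
        return seq_length - packed_len
    return (pad_size - packed_len % pad_size) % pad_size


assert compute_pad(300, 128, None) == 84      # round up to 384
assert compute_pad(300, 128, 1024) == 724     # pad to the fixed 1024
assert compute_pad(384, 128, None) == 0       # already aligned
```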
7 changes: 6 additions & 1 deletion slime/backends/megatron_utils/megatron_to_hf/__init__.py
@@ -1,6 +1,7 @@
from .deepseekv3 import convert_deepseekv3_to_hf
from .glm4 import convert_glm4_to_hf
from .glm4moe import convert_glm4moe_to_hf
from .kimi_vl import convert_kimi_k25_to_hf, convert_kimivl_to_hf
from .llama import convert_llama_to_hf
from .mimo import convert_mimo_to_hf
from .processors import quantize_params, remove_padding
@@ -50,6 +51,10 @@ def _convert_to_hf_core(args, model_name, name, param)
        converted_named_tensors = convert_llama_to_hf(args, name, param)
    elif "mimo" in model_name:
        converted_named_tensors = convert_mimo_to_hf(args, name, param)
    elif "kimivl" in model_name:
        converted_named_tensors = convert_kimivl_to_hf(args, name, param)
    elif "kimi_k25" in model_name:
        converted_named_tensors = convert_kimi_k25_to_hf(args, name, param)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

@@ -80,4 +85,4 @@ def _convert_to_hf_core(args, model_name, name, param):
            _cached_tensors[converted_name] = converted_param
        else:
            converted_named_tensors.append((converted_name, converted_param))
    return converted_named_tensors
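The two new branches dispatch on substrings of `model_name`; since neither `"kimivl"` nor `"kimi_k25"` contains the other, the `elif` order cannot shadow either converter. A reduced sketch of the dispatch (the model names in the asserts are hypothetical):

```python
def pick_converter(model_name: str) -> str:
    # Illustrative reduction of the elif chain in _convert_to_hf_core.
    if "kimivl" in model_name:
        return "convert_kimivl_to_hf"
    if "kimi_k25" in model_name:
        return "convert_kimi_k25_to_hf"
    raise ValueError(f"Unsupported model: {model_name}")


assert pick_converter("kimivl-a3b-instruct") == "convert_kimivl_to_hf"
assert pick_converter("kimi_k25-base") == "convert_kimi_k25_to_hf"
```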
138 changes: 138 additions & 0 deletions slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py
@@ -0,0 +1,138 @@
import re

import torch


def convert_kimivl_to_hf(args, name, param):
    if name.startswith("module.module.vision_model."):
        hf_name = "vision_tower." + name[len("module.module.vision_model.") :]
        return [(hf_name, param)]

    if name.startswith("module.module.multi_modal_projector."):
        hf_name = "multi_modal_projector." + name[len("module.module.multi_modal_projector.") :]
        return [(hf_name, param)]

    return convert_language_model_to_hf(args, name, param)


def convert_kimi_k25_to_hf(args, name, param):
    if name.startswith("module.module.vision_tower."):
        hf_name = "vision_tower." + name[len("module.module.vision_tower.") :]
        return [(hf_name, param)]

    if name.startswith("module.module.mm_projector."):
        hf_name = "mm_projector." + name[len("module.module.mm_projector.") :]
        return [(hf_name, param)]

    return convert_language_model_to_hf(args, name, param)


def convert_language_model_to_hf(args, name, param):
    if name == "module.module.language_model.embedding.word_embeddings.weight":
        return [("language_model.model.embed_tokens.weight", param)]
    if name == "module.module.language_model.output_layer.weight":
        return [("language_model.lm_head.weight", param)]
    if name == "module.module.language_model.decoder.final_layernorm.weight":
        return [("language_model.model.norm.weight", param)]

    try:
        head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads
    except AttributeError:
        head_dim = args.hidden_size // args.num_attention_heads
    value_num_per_group = args.num_attention_heads // args.num_query_groups

    decoder_layers_pattern = r"module\.module\.language_model\.decoder\.layers\.(\d+)\.(.+)"
    match = re.match(decoder_layers_pattern, name)
    if match:
        layer_idx, rest = match.groups()

        # experts
        expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
        match = re.match(expert_pattern, rest)
        if match:
            rest, expert_idx = match.groups()
            if rest == "linear_fc1":
                gate_weight, up_weight = param.chunk(2, dim=0)
                outputs = [
                    (
                        f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight",
                        gate_weight,
                    ),
                    (f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight),
                ]
                return outputs
            elif rest == "linear_fc2":
                outputs = [
                    (f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param),
                ]
                return outputs
            else:
                raise ValueError(f"Unknown expert parameter name: {name}")

        # shared expert
        shared_expert_pattern = r"mlp.shared_experts\.(.+)"
        match = re.match(shared_expert_pattern, rest)
        if match:
            rest = match.groups()[0]
            if rest == "linear_fc1.weight":
                gate_weight, up_weight = param.chunk(2, dim=0)
                return [
                    (f"language_model.model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight),
                    (f"language_model.model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight),
                ]
            elif rest == "linear_fc2.weight":
                return [(f"language_model.model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)]
            else:
                raise ValueError(f"Unknown shared expert parameter name: {name}")

        if rest == "self_attention.linear_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
        elif rest == "self_attention.linear_q_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.q_proj.weight", param)]
        elif rest == "self_attention.linear_q_down_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)]
        elif rest == "self_attention.linear_q_up_proj.layer_norm_weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)]
        elif rest == "self_attention.linear_q_up_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)]
        elif rest == "self_attention.linear_qkv.bias":
            param = param.view(args.num_query_groups, -1)
            q_bias, k_bias, v_bias = torch.split(
                param,
                split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
                dim=1,
            )
            q_bias = q_bias.contiguous().flatten()
            k_bias = k_bias.contiguous().flatten()
            v_bias = v_bias.contiguous().flatten()
            return [
                (f"language_model.model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias),
                (f"language_model.model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias),
                (f"language_model.model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias),
            ]
        elif rest == "mlp.linear_fc1.weight":
            gate_weight, up_weight = param.chunk(2, dim=0)
            return [
                (f"language_model.model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight),
                (f"language_model.model.layers.{layer_idx}.mlp.up_proj.weight", up_weight),
            ]
        elif rest == "mlp.linear_fc2.weight":
            return [(f"language_model.model.layers.{layer_idx}.mlp.down_proj.weight", param)]
        elif rest == "self_attention.linear_qkv.layer_norm_weight" or rest == "input_layernorm.weight":
            return [(f"language_model.model.layers.{layer_idx}.input_layernorm.weight", param)]
        elif rest == "mlp.linear_fc1.layer_norm_weight":
            return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
        elif rest == "self_attention.linear_kv_down_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param)]
        elif rest == "self_attention.linear_kv_up_proj.layer_norm_weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param)]
        elif rest == "self_attention.linear_kv_up_proj.weight":
            return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)]
        elif rest == "pre_mlp_layernorm.weight":
            return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
        elif rest == "mlp.router.weight":
            return [(f"language_model.model.layers.{layer_idx}.mlp.gate.weight", param)]
        elif rest == "mlp.router.expert_bias":
            return [(f"language_model.model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)]

    raise ValueError(f"Unknown parameter name: {name}")
4 changes: 4 additions & 0 deletions slime/backends/megatron_utils/model.py
@@ -205,6 +205,7 @@ def forward_step(
    assert not return_schedule_plan, "forward_only step should never return schedule plan"

    # Get the batch.
    pad_to_seq_length = args.seq_length if (mpu.get_pipeline_model_parallel_world_size() > 1) else None
    batch = get_batch(
        data_iterator,
        [
@@ -217,6 +218,7 @@
        ],
        args.data_pad_size_multiplier,
        args.qkv_format,
        pad_to_seq_length,
    )
    unconcat_tokens = batch["unconcat_tokens"]
    tokens = batch["tokens"]
@@ -354,6 +356,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"""

# Get the batch.
pad_to_seq_length = args.seq_length if (mpu.get_pipeline_model_parallel_world_size() > 1) else None
batch = get_batch(
data_iterator,
[
@@ -374,6 +377,7 @@
        ],
        args.data_pad_size_multiplier,
        args.qkv_format,
        pad_to_seq_length,
    )

    if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
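Both `forward_step` variants now pass a fixed padding target whenever pipeline parallelism is active. The diff does not state the rationale, but a plausible reading is that activations sent between pipeline stages need identical shapes across micro-batches, so each one is padded to `args.seq_length`; with a single stage, rounding up to a multiple of the pad size suffices. A trivial sketch of the selection:

```python
def pick_pad_target(pp_world_size: int, seq_length: int) -> int | None:
    # Mirrors the forward_step logic above: fixed-length padding only
    # when activations must cross pipeline stages.
    return seq_length if pp_world_size > 1 else None


assert pick_pad_target(1, 4096) is None    # single stage: multiple-of-pad_size padding
assert pick_pad_target(4, 4096) == 4096    # PP > 1: every micro-batch padded to 4096
```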
@@ -27,18 +27,25 @@ def get_hf_weight_chunks(self, megatron_local_weights):

        named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks)

        named_weights = (
            (
                hf_param_name,
                postprocess_hf_param(
                    args=self.args,
                    megatron_param_name=megatron_param_name,
                    hf_param_name=hf_param_name,
                    param=weight,
                ),
            )
            for hf_param_name, weight, megatron_param_name in named_weights
        )
        base_named_weights = named_weights

        def _iter_named_weights():
            for item in base_named_weights:
                if len(item) == 3:
                    hf_param_name, weight, megatron_param_name = item
                    weight = postprocess_hf_param(
                        args=self.args,
                        megatron_param_name=megatron_param_name,
                        hf_param_name=hf_param_name,
                        param=weight,
                    )
                elif len(item) == 2:
                    hf_param_name, weight = item
                else:
                    raise ValueError(f"Unexpected weight tuple size: {len(item)} for {item}")
                yield (hf_param_name, weight)

        named_weights = _iter_named_weights()

        yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size)

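The rewritten iterator tolerates both tuple shapes that `export_hf_weights` may yield: 3-tuples carrying the Megatron parameter name needed by `postprocess_hf_param`, and plain `(name, weight)` 2-tuples from bridge versions that omit it. A standalone sketch of the pattern, with a stand-in `postprocess` callback in place of the real one:

```python
def iter_named_weights(items, postprocess=lambda megatron_name, weight: weight):
    # Accept both 2-tuples and 3-tuples; only 3-tuples carry the
    # Megatron name, so only they go through postprocessing.
    for item in items:
        if len(item) == 3:
            name, weight, megatron_name = item
            weight = postprocess(megatron_name, weight)
        elif len(item) == 2:
            name, weight = item
        else:
            raise ValueError(f"Unexpected weight tuple size: {len(item)}")
        yield name, weight


pairs = list(iter_named_weights([("a", 1.0), ("b", 2.0, "decoder.b")]))
assert pairs == [("a", 1.0), ("b", 2.0)]
```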
3 changes: 2 additions & 1 deletion slime/utils/reloadable_process_group.py
@@ -70,7 +70,8 @@ def new_function(*args, **kwargs):
dist.all_to_all = get_new_function(dist.all_to_all)
dist.all_to_all_single = get_new_function(dist.all_to_all_single)
dist.broadcast = get_new_function(dist.broadcast)
dist.reduce = get_new_function(dist.reduce)

dist.broadcast_object_list = get_new_function(dist.broadcast_object_list)
dist.reduce_scatter = get_new_function(dist.reduce_scatter)
dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor)
dist.scatter = get_new_function(dist.scatter)
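This hunk adds `dist.reduce` and `dist.broadcast_object_list` to the set of `torch.distributed` collectives wrapped by `get_new_function`, whose definition lies outside the diff. A plausible sketch of the wrapping pattern; the `resolve_group` hook is an assumption about how the reloadable group gets substituted at call time:

```python
import functools


def get_new_function(orig_collective, resolve_group=lambda group: group):
    # Wrap a collective so a reloaded/reconstructed process group can
    # be swapped in transparently on every call.
    @functools.wraps(orig_collective)
    def new_function(*args, group=None, **kwargs):
        return orig_collective(*args, group=resolve_group(group), **kwargs)

    return new_function
```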
2 changes: 2 additions & 0 deletions tools/convert_hf_to_int4_direct.py
@@ -292,6 +292,8 @@ def parse_args():
"re:.*norm.*",
"re:.*embed.*",
"re:.*self_attn.*",
"re:.*vision_tower.*",
"re:.*multi_modal_projector.*",
"re:.*shared_experts.*",
"re:.*mlp\\.(gate|up|gate_up|down)_proj.*",
"re:.*mlp\\.gate\\.*",
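The two new entries keep the vision tower and the multimodal projector out of INT4 quantization. Assuming the usual compressed-tensors convention, where `re:`-prefixed entries are regexes matched against module names and bare entries are exact names, the matching works roughly like this sketch:

```python
import re


def is_ignored(module_name: str, ignore: list[str]) -> bool:
    # Assumed semantics of the ignore list: "re:" entries are regexes,
    # everything else must match the module name exactly.
    for pattern in ignore:
        if pattern.startswith("re:"):
            if re.match(pattern[len("re:") :], module_name):
                return True
        elif pattern == module_name:
            return True
    return False


assert is_ignored("vision_tower.blocks.0.attn.qkv", ["re:.*vision_tower.*"])
assert is_ignored("multi_modal_projector.linear_1", ["re:.*multi_modal_projector.*"])
assert not is_ignored("model.layers.0.mlp.gate_proj", ["re:.*vision_tower.*"])
```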
16 changes: 9 additions & 7 deletions tools/convert_hf_to_torch_dist.py
@@ -17,7 +17,6 @@
from slime.utils.logging_utils import configure_logger
from slime.utils.memory_utils import print_memory


def add_convertion_args(parser):
    """Add conversion arguments to the parser"""
    parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path")
@@ -33,7 +32,6 @@ def add_convertion_args(parser):
    pass
    return parser


def get_args():
    args = parse_args(add_convertion_args)
    args = set_default_megatron_args(args)
@@ -76,7 +74,6 @@ def ceildiv(a, b):
    validate_args(args)
    return args


def main():
    if torch.version.hip:
        import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module
@@ -113,10 +110,16 @@ def main():

    model = get_model(get_model_provider_func(args), ModelType.encoder_or_decoder, wrap_with_ddp=False)

    # Load model
    hf_model_path = args.hf_checkpoint
    bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True)
    bridge.load_weights(model, hf_model_path, memory_efficient=True)

    # Load model
    hf_model_path = args.hf_checkpoint
    if args.megatron_to_hf_mode == "bridge":
        from megatron.bridge import AutoBridge

        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
        bridge.load_hf_weights(model, hf_model_path)
    else:
        bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True)
        bridge.load_weights(model, hf_model_path, memory_efficient=True)
    print(f"Model loaded: {hf_model_path}")

    if args.use_cpu_initialization:
@@ -140,6 +143,5 @@ def main():
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()