 import re
 from typing import Any

-from torchtitan.protocols.state_dict_adapter import StateDictAdapter
+from torch.distributed.tensor import DTensor
+from torchtitan.models.utils import MoEStateDictAdapter

 from .args import Qwen3ModelArgs


-class Qwen3StateDictAdapter(StateDictAdapter):
+class Qwen3StateDictAdapter(MoEStateDictAdapter):
     def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
         super().__init__(model_args, hf_assets_path)
-
-        self.model_args = model_args
-        self.hf_assets_path = hf_assets_path
-
         self.from_hf_map = {
             "model.embed_tokens.weight": "tok_embeddings.weight",
+            # Attention module
             "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
             "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
             "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
             "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
             "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
             "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
             "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+            # MLP module for non-MoE
             "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
             "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
             "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            # Transformer layer
             "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
             "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            # MoE
+            "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1",
+            "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3",
+            "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2",
+            "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
             "model.norm.weight": "norm.weight",
             "lm_head.weight": "output.weight",
         }

     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
-
+        """
+        1. Convert between the HF shape and the torchtitan shape.
+        2. Split the GroupedExperts' weight into separate experts' weights.
+        """
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}
         hf_state_dict = {}

         for key, value in state_dict.items():
-            if "layers" in key:
+            if "moe.experts" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                if abstract_key not in to_hf_map:
+                    continue
                 layer_num = re.search(r"\d+", key).group(0)
-                new_key = to_hf_map[abstract_key]
-
-                if new_key is None:
+                new_abstract_key = to_hf_map[abstract_key]
+
+                # Store the GroupedExperts weight metadata for from_hf()
+                if isinstance(value, DTensor):
+                    self.grouped_expert_weight_placements[
+                        abstract_key
+                    ] = value.placements
+                    self.grouped_expert_weight_shape[abstract_key] = value.shape
+
+                    # Split GroupedExperts weight to local individual expert weights
+                    local_expert_fqn = self._get_local_experts_weights(
+                        new_abstract_key,
+                        abstract_key,
+                        layer_num,
+                        value,
+                    )
+                    hf_state_dict.update(local_expert_fqn)
+
+                else:
+                    # keep this path for offline conversion
+                    split_values = self._split_experts_weights(
+                        value, self.model_args.moe_args.num_experts
+                    )
+
+                    for expert_num in range(self.model_args.moe_args.num_experts):
+                        new_key = new_abstract_key.format(layer_num, expert_num)
+                        hf_state_dict[new_key] = split_values[expert_num].squeeze()
+
+            elif "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                if abstract_key not in to_hf_map:
                     continue
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = to_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
+                hf_state_dict[new_key] = value
+
             else:
+                if key not in to_hf_map:
+                    continue
                 new_key = to_hf_map[key]
-
-            hf_state_dict[new_key] = value
+                hf_state_dict[new_key] = value

         return hf_state_dict

     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        1. Convert between the HF shape and the torchtitan shape.
+        2. Concatenate separate experts' weights into the GroupedExperts' weight.
+        """

         state_dict = {}
+        expert_weights_by_layer = {}  # {layer: {abstract_key: {expert_id: tensor}}}

         for key, value in hf_state_dict.items():
-            if "layers" in key:
+            if "mlp.experts" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=2)
+                layer_num, expert_num = re.findall(r"\d+", key)
+                titan_abstract_key = self.from_hf_map[abstract_key]
+                new_key = titan_abstract_key.format(layer_num)
+
+                # Store the expert's weight in expert_weights_by_layer for concatenating later.
+                if layer_num not in expert_weights_by_layer:
+                    expert_weights_by_layer[layer_num] = {}
+                if titan_abstract_key not in expert_weights_by_layer[layer_num]:
+                    expert_weights_by_layer[layer_num][titan_abstract_key] = {}
+                expert_weights_by_layer[layer_num][titan_abstract_key][
+                    expert_num
+                ] = value
+
+                if isinstance(value, DTensor):
+                    stacked_value = self._concatenate_expert_weights_dtensor(
+                        expert_weights_by_layer,
+                        titan_abstract_key,
+                        layer_num,
+                        value.device_mesh,
+                    )
+                else:  # keep this path to be compatible with offline conversion
+                    stacked_value = self._concatenate_expert_weights(
+                        expert_weights_by_layer,
+                        titan_abstract_key,
+                        layer_num,
+                        self.model_args.moe_args.num_experts,
+                    )
+
+                if stacked_value is not None:
+                    state_dict[new_key] = stacked_value
+
+            elif "layers" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = self.from_hf_map[abstract_key]
-
-                if new_key is None:
-                    continue
                 new_key = new_key.format(layer_num)
+                state_dict[new_key] = value
+
             else:
                 new_key = self.from_hf_map[key]
+                state_dict[new_key] = value

-            state_dict[new_key] = value
         return state_dict
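
Not part of the diff, just a reviewer-facing sketch of the key-renaming and expert-stacking idea on plain tensors (no DTensor, no sharding). `FROM_HF_MAP`, `demo_from_hf`, and the toy shapes are hypothetical illustrations; the real adapter delegates the grouped handling to `MoEStateDictAdapter` helpers such as `_concatenate_expert_weights`.

```python
import re

import torch

# Hypothetical toy version of the MoE entry in from_hf_map above (illustration only).
FROM_HF_MAP = {
    "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1",
}


def demo_from_hf(hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Group per-expert HF weights and stack them into one grouped tensor per layer."""
    grouped: dict[str, dict[int, torch.Tensor]] = {}
    for key, value in hf_state_dict.items():
        # Blank out the first two integers (layer id, expert id) to look up the mapping.
        abstract_key = re.sub(r"(\d+)", "{}", key, count=2)
        layer_num, expert_num = re.findall(r"\d+", key)
        titan_key = FROM_HF_MAP[abstract_key].format(layer_num)
        grouped.setdefault(titan_key, {})[int(expert_num)] = value
    # Stack experts along a new leading dim (the grouped layout assumed here).
    return {
        k: torch.stack([v[e] for e in sorted(v)], dim=0) for k, v in grouped.items()
    }


# Two fake experts for layer 0; a real checkpoint would hold many more keys.
hf_sd = {
    f"model.layers.0.mlp.experts.{e}.gate_proj.weight": torch.randn(8, 4)
    for e in range(2)
}
titan_sd = demo_from_hf(hf_sd)
print(titan_sd["layers.0.moe.experts.w1"].shape)  # torch.Size([2, 8, 4])

# The to_hf direction is the inverse: split the grouped tensor back per expert.
per_expert = torch.unbind(titan_sd["layers.0.moe.experts.w1"], dim=0)
print(len(per_expert), per_expert[0].shape)  # 2 torch.Size([8, 4])
```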