Commit 4409c13

[Qwen3] StateDictAdapter support for MoE model (#1766)
Reuse the DeepSeek V3 StateDictAdapter support to implement a Qwen3 StateDictAdapter. Move the features shared by MoE state dict adapters into MoEStateDictAdapter, a StateDictAdapter subclass in torchtitan/models/utils.py.
1 parent 5217163 commit 4409c13
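
The resulting layout keeps the per-model adapters thin: Qwen3StateDictAdapter only supplies the Qwen3 key map, while the expert split/concatenate helpers live in the shared base. The sketch below is a rough reconstruction from this diff; the helper names match the calls made in Qwen3StateDictAdapter, but their exact signatures, the attribute declarations, and the DeepSeek V3 adapter's class name are assumptions rather than quotes from torchtitan/models/utils.py.

```python
# Hypothetical sketch of the class hierarchy after this commit (signatures assumed).
class StateDictAdapter:  # torchtitan/protocols/state_dict_adapter.py
    def to_hf(self, state_dict): ...
    def from_hf(self, hf_state_dict): ...


class MoEStateDictAdapter(StateDictAdapter):  # shared MoE helpers, torchtitan/models/utils.py
    # GroupedExperts metadata recorded in to_hf() and reused in from_hf()
    grouped_expert_weight_placements: dict
    grouped_expert_weight_shape: dict

    def _split_experts_weights(self, weight, num_experts): ...
    def _get_local_experts_weights(self, hf_key, titan_key, layer_num, dtensor): ...
    def _concatenate_expert_weights(self, by_layer, titan_key, layer_num, num_experts): ...
    def _concatenate_expert_weights_dtensor(self, by_layer, titan_key, layer_num, mesh): ...


class DeepSeekV3StateDictAdapter(MoEStateDictAdapter):  # existing adapter (class name assumed)
    ...


class Qwen3StateDictAdapter(MoEStateDictAdapter):  # added by this commit
    ...
```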

File tree

3 files changed: +451 −349 lines

torchtitan/experiments/qwen3/model/state_dict_adapter.py

Lines changed: 98 additions & 18 deletions
@@ -15,72 +15,152 @@
 import re
 from typing import Any
 
-from torchtitan.protocols.state_dict_adapter import StateDictAdapter
+from torch.distributed.tensor import DTensor
+from torchtitan.models.utils import MoEStateDictAdapter
 
 from .args import Qwen3ModelArgs
 
 
-class Qwen3StateDictAdapter(StateDictAdapter):
+class Qwen3StateDictAdapter(MoEStateDictAdapter):
     def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
         super().__init__(model_args, hf_assets_path)
-
-        self.model_args = model_args
-        self.hf_assets_path = hf_assets_path
-
         self.from_hf_map = {
             "model.embed_tokens.weight": "tok_embeddings.weight",
+            # Attention module
             "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
             "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
             "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
             "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
             "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
             "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
             "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+            # MLP module (non-MoE layers)
             "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
             "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
             "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            # Transformer layer norms
             "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
             "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            # MoE
+            "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1",
+            "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3",
+            "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2",
+            "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
             "model.norm.weight": "norm.weight",
             "lm_head.weight": "output.weight",
         }
 
     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
-
+        """
+        1. Convert between the HF shape and the torchtitan shape.
+        2. Split the GroupedExperts weight into separate per-expert weights.
+        """
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}
         hf_state_dict = {}
 
         for key, value in state_dict.items():
-            if "layers" in key:
+            if "moe.experts" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                if abstract_key not in to_hf_map:
+                    continue
                 layer_num = re.search(r"\d+", key).group(0)
-                new_key = to_hf_map[abstract_key]
-
-                if new_key is None:
+                new_abstract_key = to_hf_map[abstract_key]
+
+                # Store the GroupedExperts weight metadata for from_hf()
+                if isinstance(value, DTensor):
+                    self.grouped_expert_weight_placements[
+                        abstract_key
+                    ] = value.placements
+                    self.grouped_expert_weight_shape[abstract_key] = value.shape
+
+                    # Split the GroupedExperts weight into local per-expert weights
+                    local_expert_fqn = self._get_local_experts_weights(
+                        new_abstract_key,
+                        abstract_key,
+                        layer_num,
+                        value,
+                    )
+                    hf_state_dict.update(local_expert_fqn)
+
+                else:
+                    # keep this path for offline conversion
+                    split_values = self._split_experts_weights(
+                        value, self.model_args.moe_args.num_experts
+                    )
+
+                    for expert_num in range(self.model_args.moe_args.num_experts):
+                        new_key = new_abstract_key.format(layer_num, expert_num)
+                        hf_state_dict[new_key] = split_values[expert_num].squeeze()
+
+            elif "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                if abstract_key not in to_hf_map:
                     continue
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = to_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
+                hf_state_dict[new_key] = value
+
             else:
+                if key not in to_hf_map:
+                    continue
                 new_key = to_hf_map[key]
-
-            hf_state_dict[new_key] = value
+                hf_state_dict[new_key] = value
 
         return hf_state_dict
 
     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        1. Convert between the HF shape and the torchtitan shape.
+        2. Concatenate separate per-expert weights into the GroupedExperts weight.
+        """
 
         state_dict = {}
+        expert_weights_by_layer = {}  # {layer: {abstract_key: {expert_id: tensor}}}
 
         for key, value in hf_state_dict.items():
-            if "layers" in key:
+            if "mlp.experts" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=2)
+                layer_num, expert_num = re.findall(r"\d+", key)
+                titan_abstract_key = self.from_hf_map[abstract_key]
+                new_key = titan_abstract_key.format(layer_num)
+
+                # Store the expert's weight in expert_weights_by_layer for later concatenation.
+                if layer_num not in expert_weights_by_layer:
+                    expert_weights_by_layer[layer_num] = {}
+                if titan_abstract_key not in expert_weights_by_layer[layer_num]:
+                    expert_weights_by_layer[layer_num][titan_abstract_key] = {}
+                expert_weights_by_layer[layer_num][titan_abstract_key][
+                    expert_num
+                ] = value
+
+                if isinstance(value, DTensor):
+                    stacked_value = self._concatenate_expert_weights_dtensor(
+                        expert_weights_by_layer,
+                        titan_abstract_key,
+                        layer_num,
+                        value.device_mesh,
+                    )
+                else:  # keep this path to be compatible with offline conversion
+                    stacked_value = self._concatenate_expert_weights(
+                        expert_weights_by_layer,
+                        titan_abstract_key,
+                        layer_num,
+                        self.model_args.moe_args.num_experts,
+                    )
+
+                if stacked_value is not None:
+                    state_dict[new_key] = stacked_value
+
+            elif "layers" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = self.from_hf_map[abstract_key]
-
-                if new_key is None:
-                    continue
                 new_key = new_key.format(layer_num)
+                state_dict[new_key] = value
+
             else:
                 new_key = self.from_hf_map[key]
+                state_dict[new_key] = value
 
-            state_dict[new_key] = value
         return state_dict
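
Both methods lean on the same abstract-key trick: strip the layer (and expert) indices out of a fully qualified name with re.sub, look the template up in the map, then format the indices back in. A standalone toy, independent of torchtitan (the example keys are taken from the map above):

```python
import re

# Toy illustration of the abstract-key lookup used by to_hf()/from_hf().
from_hf_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1",
}

# Dense (non-MoE) key: abstract away the single layer index.
hf_key = "model.layers.3.self_attn.q_proj.weight"
abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1)
layer_num = re.search(r"\d+", hf_key).group(0)
print(from_hf_map[abstract_key].format(layer_num))  # layers.3.attention.wq.weight

# MoE expert key: abstract away both the layer and the expert index.
hf_key = "model.layers.3.mlp.experts.7.gate_proj.weight"
abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=2)
layer_num, expert_num = re.findall(r"\d+", hf_key)
print(from_hf_map[abstract_key].format(layer_num))  # layers.3.moe.experts.w1
```

The offline (non-DTensor) expert handling amounts to splitting a grouped weight along its expert dimension on the way out and stacking the per-expert tensors back on the way in. The snippet below is a minimal sketch of that round trip, assuming the GroupedExperts weight is laid out as [num_experts, dim_out, dim_in]; it is not the implementation of MoEStateDictAdapter's _split_experts_weights/_concatenate_expert_weights helpers.

```python
import torch

num_experts, dim_out, dim_in = 4, 8, 16
grouped = torch.randn(num_experts, dim_out, dim_in)  # e.g. layers.N.moe.experts.w1

# to_hf() direction: one [dim_out, dim_in] tensor per
# "model.layers.N.mlp.experts.E.gate_proj.weight" entry.
per_expert = [chunk.squeeze(0) for chunk in grouped.split(1, dim=0)]

# from_hf() direction: collect every expert of the layer, then stack in expert order.
restacked = torch.stack(per_expert, dim=0)
assert torch.equal(grouped, restacked)
```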
