3 changes: 3 additions & 0 deletions examples/intel_hpu/offline_demo.py
@@ -26,6 +26,7 @@
max_out_tokens = 128
server_max_bs = 128
TP = 1
EP = True

# num_gpu_blocks_override = ceil((input_seq + max_out_tokens) / 128) * server_max_bs
num_gpu_blocks_override = 2000
@@ -34,12 +35,14 @@
llm = LLM(
model=model_name_or_path,
tensor_parallel_size=TP,
enable_expert_parallel=EP,
engine_worker_queue_port=8602,
num_gpu_blocks_override=num_gpu_blocks_override,
block_size=128,
max_model_len=32768,
max_num_seqs=server_max_bs,
graph_optimization_config=graph_optimization_config,
disable_sequence_parallel_moe=True,
)

if input_seq is None:
@@ -24,6 +24,7 @@
UnquantizedFusedMoEMethod,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs


class HpuMoEMethod(UnquantizedFusedMoEMethod):
@@ -53,6 +54,24 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
)
self.down_proj_expert_act_scale_key = down_proj_expert_weight_key.replace("weight", "activation_scale")

def init_ep(self, layer: nn.Layer) -> None:
"""
Initialize EP (Expert Parallel) related modules.
"""
return

def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the TP prefill method.
"""
raise NotImplementedError

def apply_ep_prefill(
self,
layer: nn.Layer,
@@ -77,7 +96,7 @@ def apply_ep_decode(
"""
raise NotImplementedError

def apply_tp(
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
@@ -190,17 +209,18 @@ def _extract_descale_tensor(key_template, logical_expert_ids):
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)

name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)
setattr(layer, "down_proj_in_scale", down_proj_in_scale)

def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
@@ -247,6 +267,15 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
"down_proj_in_scale",
layer.create_parameter(
shape=[layer.num_local_experts, 1],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)

# weight_scales
setattr(
@@ -267,6 +296,19 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1},
}
set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
extra_scale_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None},
}
set_weight_attrs(layer.down_proj_in_scale, extra_scale_attrs)
set_weight_attrs(layer.up_gate_proj_weight_scale, extra_scale_attrs)
set_weight_attrs(layer.down_proj_weight_scale, extra_scale_attrs)

def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -296,6 +338,27 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
setattr(layer, weights_name, weights_list)
setattr(layer, scales_name, scales_list)

def process_weights_after_loading(self, layer):
return

def init_ep(self, layer: nn.Layer) -> None:
"""
Initialize EP (Expert Parallel) related modules.
"""
return

def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the TP decoder method.
"""
raise NotImplementedError

def apply_ep_prefill(
self,
layer: nn.Layer,
@@ -320,7 +383,7 @@ def apply_ep_decode(
"""
raise NotImplementedError

def apply_tp(
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
@@ -131,6 +131,11 @@ def load_scale(self, layer: nn.Layer, state_dict):
layer.s_scale.set_value(s_scale)
layer.s_out_scale.set_value(s_out_scale)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).cast("float32")
loaded_weight = self.cache_quant_config.max_bound / loaded_weight
param.copy_(loaded_weight, False)

def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create_weights
@@ -158,12 +163,14 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
layer.cache_k_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
layer.cache_v_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.cache_k_out_scale = layer.create_parameter(
@@ -182,6 +189,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.q_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.q_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
@@ -202,6 +216,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.s_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.s_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
@@ -226,9 +247,16 @@ def process_weights_after_loading(self, layer: nn.Layer):
"""
# cache_k_out_scale is the reciprocal of cache_k_scale
if layer.cache_k_scale._is_initialized():
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
layer.cache_k_out_scale.set_value(1.0 / layer.cache_k_scale)
if layer.cache_v_scale._is_initialized():
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
layer.cache_v_out_scale.set_value(1.0 / layer.cache_v_scale)
if layer.q_scale._is_initialized():
scaling_factor = layer.head_dim**-0.5
layer.q_scaling_scale.set_value(layer.q_scale / scaling_factor)
layer.q_scaling_out_scale.set_value(scaling_factor / layer.q_scale)
layer.q_out_scale.set_value(1.0 / layer.q_scale)
if layer.s_scale._is_initialized():
layer.s_out_scale.set_value(1.0 / layer.s_scale)

def apply(self, layer):
"""
@@ -23,6 +23,7 @@
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.intel_hpu import fused_quant
from fastdeploy.model_executor.utils import set_weight_attrs


class HpuTensorWiseFP8LinearMethod(TensorWiseFP8LinearMethod):
@@ -92,10 +93,35 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
is_bias=False,
)

self.model_format = extra_weight_attrs.get("model_format")
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
set_weight_attrs(
layer.weight,
extra_weight_attrs,
)

def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
"""
loaded_weights using HPU specific quantization
"""
quanted_weight_tensor, weight_scale_tensor = fused_quant(weight)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor)

def process_weights_after_loading(self, layer: nn.Layer):
"""
use for loader v1
"""
# these activation_scale will fall in, but only quant for self_attn
# mlp.shared_experts.up_gate_proj / down_proj
# self_attn.qkv_proj / o_proj
if layer.act_scale._is_initialized():
if "self_attn" in layer.act_scale_key:
act_scale_inv = layer.act_scale / self.max_bound
act_scale = self.max_bound / layer.act_scale
else:
act_scale_inv = layer.act_scale
act_scale = 1.0 / layer.act_scale
layer.act_scale.set_value(act_scale.astype(paddle.get_default_dtype()))
layer.act_scale_inv.set_value(act_scale_inv.astype(paddle.get_default_dtype()))
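For readers new to the two buffers above, here is a minimal standalone sketch of the self_attn branch. It assumes max_bound is the float8_e4m3 limit (448.0) and that the checkpoint stores an amax-style activation_scale; both are assumptions, not values taken from this PR.

```python
# Standalone sketch, not the FastDeploy/HPU implementation.
import paddle

max_bound = 448.0  # assumed float8_e4m3 limit
ckpt_act_scale = paddle.to_tensor([0.07], dtype="float32")  # hypothetical amax-style scale

act_scale = max_bound / ckpt_act_scale      # multiplier used to quantize activations
act_scale_inv = ckpt_act_scale / max_bound  # multiplier used to undo it after the matmul

x = 0.05 * paddle.randn([4, 16], dtype="float32")
x_q = (x * act_scale).clip(-max_bound, max_bound)  # stand-in for the float8 cast
x_back = x_q * act_scale_inv                       # recovers x up to rounding error
```

The non-self_attn branch keeps the checkpoint value unchanged as act_scale_inv, which is why the two cases are split.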
11 changes: 7 additions & 4 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -657,6 +657,12 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta =
tp_group=self.fd_config.parallel_config.tp_group,
)

if current_platform.is_intel_hpu():
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
tensor_model_parallel_all_reduce_custom(out)
return out

token_num = x.shape[0]
if (
self.ep_size > 1
@@ -676,10 +682,7 @@
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)

if self.reduce_results and self.tp_size > 1:
if current_platform.is_intel_hpu():
tensor_model_parallel_all_reduce_custom(out)
else:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out

def forward_chunked_moe(
3 changes: 3 additions & 0 deletions fastdeploy/model_executor/load_weight_utils.py
@@ -269,13 +269,16 @@ def get_expert_ranges(fd_config):
down_proj_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight_scale"

down_proj_in_scale_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.activation_scale"
# single up_gate_proj.activation_scale for all mlp.experts
up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.up_gate_proj.activation_scale"
num_local_ffn_keys.append(up_gate_proj_key)
num_local_ffn_keys.append(down_proj_key)
num_local_ffn_keys.append(up_gate_proj_quant_key)
num_local_ffn_keys.append(down_proj_quant_key)
num_local_ffn_keys.append(up_gate_proj_scale_key)
num_local_ffn_keys.append(down_proj_scale_key)
num_local_ffn_keys.append(down_proj_in_scale_key)
num_local_ffn_keys.append(up_gate_proj_in_scale_key)

# for EP w4a8, we need all experts' activation_scale for up_gate_proj
num_experts = fd_config.model_config.moe_num_experts
11 changes: 10 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -565,6 +565,12 @@ def load_weights(self, weights_iterator) -> None:
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
("act_scale", "in_scale", None, None),
Contributor:
What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale stand for? FastDeploy currently names all of these as weight_scale / activation_scale plus the layer name.

Contributor Author (@yanfeich, Jan 14, 2026):
act_scale covers mlp. and mlp.shared_experts.:
down_proj.activation_scale --> down_proj.act_scale
up_gate_proj.activation_scale --> up_gate_proj.act_scale

attn.q_scale / attn.s_scale follow the same pattern as attn.cache_k_scale / attn.cache_v_scale.

up_gate_proj_in_scale covers mlp.experts.*:
experts.{exp_id}.up_gate_proj.activation_scale --> experts.up_gate_proj_in_scale
All experts share this one activation_scale, which is why it is not handled in make_expert_params_mapping.
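For illustration, the same renames written out with made-up layer indices (the keys below are hypothetical; only the suffix mapping matters):

```python
# Hypothetical checkpoint key -> FD parameter name; indices are made up.
rename_examples = {
    # plain MLP / shared experts: activation_scale -> act_scale
    "ernie.layers.1.mlp.shared_experts.down_proj.activation_scale":
        "ernie.layers.1.mlp.shared_experts.down_proj.act_scale",
    "ernie.layers.1.mlp.shared_experts.up_gate_proj.activation_scale":
        "ernie.layers.1.mlp.shared_experts.up_gate_proj.act_scale",
    # routed experts: one activation_scale shared by every expert's up_gate_proj
    "ernie.layers.1.mlp.experts.up_gate_proj.activation_scale":
        "ernie.layers.1.mlp.experts.up_gate_proj_in_scale",
}
```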

("attn.q_scale", "q_matmul.in_scale", None, None),
Contributor:
What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale stand for? FastDeploy currently names these as weight_scale / activation_scale plus the layer name; we need to agree on a standard format.

Contributor Author:
The SDPA in attention and the up/gate/down projections in the MLP / MoE all run their matmuls in tensor_wise_fp8, so each of them needs its own activation_scale.

FastDeploy currently only provides activation_scale for K and V, used by the KV cache. When our SDPA computes the QK^T and SV matmuls, all four of Q, K, V and S are needed, but Q and S cannot be called cache_{q/s}_scale, so we kept only attn.q_scale / attn.s_scale.

For up/gate/down in the plain MLP and shared_experts, FastDeploy simply renames activation_scale to act_scale.

For the MoE experts, down_proj.activation_scale drops the expert id and becomes down_proj_in_scale (the dot folded into an underscore), matching FastDeploy's existing naming rule.

In our MoE up_gate path, all experts share one activation_scale, so up_gate_proj.activation_scale is listed separately above as up_gate_proj_in_scale.

The MoE naming follows fused_moe_backend_base.py and the other vendors; no new names are introduced. This renaming rule was simply missing from the V1 loader.
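A back-of-the-envelope sketch of why SDPA ends up with four activation scales: each matmul input is quantized per tensor, so QK^T needs the q/k scales and SV needs the s/v scales. The fused HPU kernel itself is not shown; the scale values and the 448.0 bound are assumptions.

```python
# Rough per-tensor FP8 sketch; scale values and the 448.0 bound are assumptions.
import paddle

def quant(x, scale):                        # scale is roughly max_bound / amax(x)
    return (x * scale).clip(-448.0, 448.0)  # stand-in for the float8 cast

def matmul_dequant(a, a_scale, b, b_scale):
    return paddle.matmul(a, b) / (a_scale * b_scale)

q, k, v = (paddle.randn([8, 64]) for _ in range(3))
q_scale, k_scale, v_scale, s_scale = 10.0, 12.0, 9.0, 448.0  # hypothetical

# QK^T consumes q_scale and (cache_)k_scale ...
logits = matmul_dequant(quant(q, q_scale), q_scale, quant(k, k_scale).T, k_scale) / 64**0.5
probs = paddle.nn.functional.softmax(logits, axis=-1)
# ... and SV consumes s_scale and (cache_)v_scale.
out = matmul_dequant(quant(probs, s_scale), s_scale, quant(v, v_scale), v_scale)
```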

("attn.s_scale", "s_matmul.in_scale", None, None),
("attn.cache_k_scale", "cachek_matmul.in_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.in_scale", None, None),
("up_gate_proj_in_scale", "up_gate_proj.in_scale", None, None),
]

expert_params_mapping = []
@@ -590,7 +596,10 @@ def load_weights(self, weights_iterator) -> None:
(param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]
checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
fd_config=self.fd_config,
ckpt_weight_suffix="quant_weight",
ckpt_scale_suffix="weight_scale",
ckpt_act_suffix="activation_scale",
)
params_dict = dict(self.named_parameters())

11 changes: 10 additions & 1 deletion fastdeploy/model_executor/utils.py
@@ -468,7 +468,10 @@ def multi_switch_config_context(*changes):


def rename_offline_ckpt_suffix_to_fd_suffix(
fd_config, ckpt_weight_suffix: str = "quant_weight", ckpt_scale_suffix="weight_scale"
fd_config,
ckpt_weight_suffix: str = "quant_weight",
ckpt_scale_suffix="weight_scale",
ckpt_act_suffix="activation_scale",
):
"""
Create a function to rename checkpoint key suffixes for FastDeploy.
Expand All @@ -489,6 +492,10 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
ckpt_weight_suffix: "weight",
ckpt_scale_suffix: "weight_scale_inv",
}
tensor_wise_fp8_suffix_map = {
ckpt_weight_suffix: "weight",
ckpt_act_suffix: "in_scale",
}
moe_quant_type = ""
dense_quant_type = ""
if fd_config.quant_config is not None:
@@ -505,6 +512,8 @@ def fn(loaded_weight_name, is_moe):
# Can be extended to other offline quantization suffixes if needed.
if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"):
fd_suffix_map = fp8_suffix_map
if (is_moe and moe_quant_type == "tensor_wise_fp8") or (not is_moe and dense_quant_type == "tensor_wise_fp8"):
fd_suffix_map = tensor_wise_fp8_suffix_map
for ckpt_suffix, fd_suffix in fd_suffix_map.items():
if re.search(rf"{ckpt_suffix}$", loaded_weight_name):
loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix)
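A standalone approximation of the suffix rewrite the new tensor_wise_fp8 branch performs; the real fn() also gates on the configured quant types and is_moe, and the example keys below are hypothetical.

```python
import re

# Mirrors tensor_wise_fp8_suffix_map above, with the default suffixes.
tensor_wise_fp8_suffix_map = {
    "quant_weight": "weight",
    "activation_scale": "in_scale",
}

def rename(loaded_weight_name: str) -> str:
    for ckpt_suffix, fd_suffix in tensor_wise_fp8_suffix_map.items():
        if re.search(rf"{ckpt_suffix}$", loaded_weight_name):
            return loaded_weight_name.replace(ckpt_suffix, fd_suffix)
    return loaded_weight_name

print(rename("ernie.layers.0.self_attn.qkv_proj.quant_weight"))      # ...qkv_proj.weight
print(rename("ernie.layers.0.self_attn.qkv_proj.activation_scale"))  # ...qkv_proj.in_scale
```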