Conversation

@yanfeich (Contributor) commented Jan 4, 2026

Motivation

Enable MoE Expert Parallelism (EP) for HPU with loader_v1.

Modifications

  • fastdeploy/model_executor/layers/moe/moe.py
    HPU always calls forward_normal, whether EP or TP is enabled, and never falls into forward_split_allgather or forward_chunked_moe.

  • fused_moe_hpu_backend.py
    Change down_proj_in_scale from a list to a tensor.

  • hpu_model_runner.py
    Convert scale lists to tensors and add a padding dimension to satisfy the 0x80-byte alignment requirement (see the padding sketch after this list).

  • fastdeploy/model_executor/load_weight_utils.py
    Loader v0 needs up_gate_proj.activation_scale for EP.

  • fastdeploy/model_executor/models/ernie4_5_moe.py
    Add attention-related activation_scale name conversions:

    - self_attn.

      | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
      | --- | --- | --- |
      | qkv_proj.activation_scale | qkv_proj.in_scale | qkv_proj.act_scale |
      | o_proj.activation_scale | o_proj.in_scale | o_proj.act_scale |
      | cachek_matmul.activation_scale | cachek_matmul.in_scale | attn.cache_k_scale |
      | cachev_matmul.activation_scale | cachev_matmul.in_scale | attn.cache_v_scale |
      | q_matmul.activation_scale | q_matmul.in_scale | attn.q_scale |
      | s_matmul.activation_scale | s_matmul.in_scale | attn.s_scale |

    - mlp. & mlp.shared_experts.

      | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
      | --- | --- | --- |
      | down_proj.activation_scale | down_proj.in_scale | down_proj.act_scale |
      | up_gate_proj.activation_scale | up_gate_proj.in_scale | up_gate_proj.act_scale |

    - mlp.experts. (all experts share the same activation_scale)

      | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
      | --- | --- | --- |
      | up_gate_proj.activation_scale | up_gate_proj.in_scale | up_gate_proj_in_scale |

    - mlp.experts.{exp_id}.

      | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
      | --- | --- | --- |
      | experts.{exp_id}.down_proj.activation_scale | experts.{exp_id}.down_proj.in_scale | experts.down_proj_in_scale |
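The 0x80-byte alignment padding referenced in the hpu_model_runner.py item above can be pictured with the following minimal sketch. It is illustrative only: the actual runner works on paddle tensors, and the helper name and shapes here are made up for the example; only the 0x80 (128-byte) alignment figure comes from this PR.

```python
import numpy as np

ALIGN_BYTES = 0x80  # HPU scale buffers are padded to a 128-byte boundary (per this PR)


def pad_scale_for_alignment(scale: np.ndarray) -> np.ndarray:
    """Hypothetical helper: pad the last dimension of a stacked scale tensor so that
    each row's byte size becomes a multiple of ALIGN_BYTES."""
    elems_per_block = ALIGN_BYTES // scale.dtype.itemsize
    last = scale.shape[-1]
    padded = ((last + elems_per_block - 1) // elems_per_block) * elems_per_block
    if padded == last:
        return scale
    pad_width = [(0, 0)] * (scale.ndim - 1) + [(0, padded - last)]
    return np.pad(scale, pad_width, mode="constant")


# e.g. 64 experts, one scalar down_proj_in_scale per expert, stacked from a Python
# list into a (64, 1) tensor and then padded to (64, 32) for float32 (32 * 4 B = 0x80 B).
scales = np.full((64, 1), 0.02, dtype=np.float32)
print(pad_scale_for_alignment(scales).shape)  # (64, 32)
```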

Usage or Command

Set enable_expert_parallel=True and disable_sequence_parallel_moe=True to enable HPU MoE EP.
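For reference, a minimal offline sketch loosely based on examples/intel_hpu/offline_demo.py (the demo itself is not reproduced here). The model path and parallel size are placeholders, and passing these two flags straight through the LLM constructor is an assumption, not a confirmed API.

```python
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="/path/to/ERNIE-4.5-MoE",      # placeholder checkpoint path
    tensor_parallel_size=8,              # placeholder; match the number of HPU cards
    enable_expert_parallel=True,         # run MoE layers with expert parallelism
    disable_sequence_parallel_moe=True,  # keeps HPU on the forward_normal path
)

outputs = llm.generate(
    ["Briefly explain mixture-of-experts models."],
    SamplingParams(temperature=0.8, top_p=0.95),
)
print(outputs)
```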

Accuracy Tests

Checklist

  • [ Done ] Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • [ Done ] Format your code, run pre-commit before commit.
  • [ Done ] Add unit tests. Please write the reason in this PR if no unit tests.
    Conducted by local tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings January 4, 2026 06:34
@paddle-bot bot commented Jan 4, 2026

Thanks for your contribution!

@paddle-bot bot added the contributor (External developers) label on Jan 4, 2026
Copilot AI (Contributor) left a comment

Pull request overview

This PR enables MoE (Mixture of Experts) Expert Parallelism (EP) for Intel HPU by modifying the execution path and weight handling to accommodate HPU-specific requirements.

Key changes:

  • Modified MoE forward logic to route HPU through forward_normal regardless of EP/TP configuration
  • Converted down_proj_in_scale from list to tensor and added padding alignment for HPU's 0x80 byte alignment requirement
  • Added up_gate_proj.activation_scale weight loading support for EP mode

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| fastdeploy/model_executor/layers/moe/moe.py | Routes HPU platform to use the forward_normal path for both EP and TP modes |
| fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py | Changes down_proj_in_scale handling from list to tensor and renames apply_tp to apply |
| fastdeploy/worker/hpu_model_runner.py | Adds an alignment padding function for scales and implements an early return for EP mode |
| fastdeploy/model_executor/load_weight_utils.py | Adds up_gate_proj_in_scale_key to weight loading for EP support |
| examples/intel_hpu/offline_demo.py | Enables EP configuration in the demo script |

@codecov-commenter commented Jan 4, 2026

Codecov Report

❌ Patch coverage is 18.18182% with 9 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@1aa7e82). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| fastdeploy/model_executor/layers/moe/moe.py | 16.66% | 4 Missing and 1 partial ⚠️ |
| fastdeploy/model_executor/load_weight_utils.py | 0.00% | 2 Missing ⚠️ |
| fastdeploy/model_executor/utils.py | 33.33% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5855   +/-   ##
==========================================
  Coverage           ?   67.32%           
==========================================
  Files              ?      347           
  Lines              ?    44642           
  Branches           ?     6879           
==========================================
  Hits               ?    30055           
  Misses             ?    12368           
  Partials           ?     2219           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 67.32% <18.18%> (?) |

@yanfeich (Contributor Author) commented Jan 8, 2026

Adding @LeoZhao-Intel @fmiao2372. Please help review this patch, thanks!

@LeoZhao-Intel left a comment

LGTM

@bukejiyu (Collaborator) left a comment

This is the quantization naming convention currently used in the open-source ERNIE 4.5 series models. We are not certain whether it fully covers your use case. If it does, we recommend aligning the naming of the quantized weights with this convention.

  1. Quantized weights: Append the suffix quant_weight to the layer name (i.e., replace weight with quant_weight), for example:
    ernie.layers.1.mlp.down_proj.quant_weight

  2. Scale for quantized weights: Append the suffix .weight_scale, for example:
    ernie.layers.1.mlp.down_proj.weight_scale

  3. Activation scale after quantization: Append the suffix .activation_scale, for example:
    ernie.layers.1.mlp.down_proj.activation_scale

  4. Smooth scale (not applicable to ERNIE 4.5T at the moment): Append the suffix smooth_scale, for example:
    ernie.layers.1.mlp.down_proj.smooth_scale

  5. Shift bias (not applicable to ERNIE 4.5T at the moment): Append the suffix shift_bias, for example:
    ernie.layers.1.mlp.down_proj.shift_bias

  6. Cache KV scale: Append the suffixes .cachek_matmul.activation_scale and .cachev_matmul.activation_scale to K and V, respectively, for example:
    ernie.layers.0.self_attn.cachek_matmul.activation_scale
    ernie.layers.0.self_attn.cachev_matmul.activation_scale

  7. Cache KV zero point: Based on item 6, replace scale with zero_point, for example:
    ernie.layers.0.self_attn.cachek_matmul.activation_zero_point

@yanfeich (Contributor Author) commented Jan 12, 2026 via email

("up_gate_proj", "up_proj", None, "up"),
("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_scale", "cachek_matmul.in_scale", None, None),
Contributor:

This change would break loading of the open-source ERNIE models; I suggest leaving this part unchanged.

Contributor Author:

In the previous revision, checkpoint_to_fd_key_fn first rewrote activation_scale to in_scale, so the subsequent loaded_weight_name.replace could match against the new cachek_matmul.in_scale.
To avoid changing the current implementation, that part has been reverted: the current revision removes the replacement inside checkpoint_to_fd_key_fn and keeps using cachek_matmul.activation_scale.

("attn.cache_v_scale", "cachev_matmul.in_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
("act_scale", "in_scale", None, None),
Contributor:

What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale each stand for? FD currently names everything as weight_scale / activation_scale plus the layer name.

@yanfeich (Contributor Author) commented Jan 14, 2026

act_scale corresponds to mlp. & mlp.shared_experts.:
down_proj.activation_scale --> down_proj.act_scale
up_gate_proj.activation_scale --> up_gate_proj.act_scale

attn.q_scale / attn.s_scale are analogous to attn.cache_k_scale / attn.cache_v_scale.

up_gate_proj_in_scale corresponds to mlp.experts.:
experts.{exp_id}.up_gate_proj.activation_scale --> experts.up_gate_proj_in_scale
This last one is a single activation_scale shared by all experts, so it is not placed in make_expert_params_mapping.

("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
("act_scale", "in_scale", None, None),
("attn.q_scale", "q_matmul.in_scale", None, None),
Contributor:

What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale each stand for? FD currently names everything as weight_scale / activation_scale plus the layer name; we need to discuss a standard format.

Contributor Author:

The SDPA in attention and the up/gate/down proj matmuls in the MLP / MoE all use tensor_wise_fp8, so each of them needs its own activation_scale.

FD currently only provides activation_scale for K and V, used for the KV cache. When our SDPA performs the QK^T and SV matmuls, all four of Q, K, V, and S are needed, but Q and S cannot be called cache_{q/s}_scale, so only attn.q_scale / attn.s_scale were kept.

For the up/gate/down parts in the plain MLP and shared_experts, FD only renames activation_scale to act_scale.

For the MoE experts, down_proj.activation_scale has the expert id stripped and, underscore included, becomes down_proj_in_scale, consistent with FD's current naming rules.

For our MoE up_gate part, all experts share a single activation_scale, so up_gate_proj.activation_scale is listed separately above as up_gate_proj_in_scale.

The MoE naming rules are consistent with fused_moe_backend_base.py and other vendors; no new names were introduced. These renaming rules are simply missing in loader V1.
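To make the expert-scale handling above concrete, here is a tiny numpy illustration (the key names and the number of experts are made up; this is not the loader code):

```python
import numpy as np

num_experts = 4

# Hypothetical checkpoint entries: one down_proj activation scale per expert,
# plus a single up_gate activation scale shared by every expert.
ckpt = {
    f"mlp.experts.{i}.down_proj.activation_scale": np.float32(0.01 * (i + 1))
    for i in range(num_experts)
}
ckpt["mlp.experts.up_gate_proj.activation_scale"] = np.float32(0.02)

# Per-expert scales are gathered into one fused tensor (experts.down_proj_in_scale) ...
down_proj_in_scale = np.stack(
    [ckpt[f"mlp.experts.{i}.down_proj.activation_scale"] for i in range(num_experts)]
)
# ... while the shared scale maps to a single parameter (up_gate_proj_in_scale).
up_gate_proj_in_scale = ckpt["mlp.experts.up_gate_proj.activation_scale"]

print(down_proj_in_scale.shape, up_gate_proj_in_scale)  # (4,) 0.02
```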

@xiaoluomi (Contributor) left a comment

What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale each stand for? FD currently names everything as weight_scale / activation_scale plus the layer name; we need to discuss a standard format.

@xiaoluomi (Contributor) commented Jan 14, 2026

> What do act_scale, attn.q_scale, attn.s_scale, and up_gate_proj_in_scale each stand for? FD currently names everything as weight_scale / activation_scale plus the layer name; we need to discuss a standard format.

> ... Q and S cannot be called cache_{q/s}_scale, so only attn.q_scale / attn.s_scale were kept. ...

I've noticed that many quantized-attention write-ups describe what you call the S (attention weights) here as P, and then compute PV. So, to document the scales added in this PR: attn.q_scale is the quantization scale for the Query output by q_proj (the post-RoPE Q if RoPE is applied in between), and attn.s_scale is the quantization scale for the attention weights (it serves the quantized computation inside attention).

@EmmonsCurse (Collaborator) left a comment

LGTM for HPU

@zoooo0820 merged commit fbcccaa into PaddlePaddle:develop on Jan 15, 2026
27 of 31 checks passed