diff --git a/custom_ops/xpu_ops/test/test_moe_expert_ffn.py b/custom_ops/xpu_ops/test/test_moe_expert_ffn.py index be99cdceef0..fa17d8fe967 100644 --- a/custom_ops/xpu_ops/test/test_moe_expert_ffn.py +++ b/custom_ops/xpu_ops/test/test_moe_expert_ffn.py @@ -81,16 +81,17 @@ def weight_quant_wint4(w_fp32): return w_int4, w_max.reshape([-1]) -def weight_quant(w_fp32, algo="weight_only_int8"): - if algo == "weight_only_int8": +def weight_quant(w_fp32, algo="w_channelwise_int8_a_float32"): + if algo == "w_channelwise_int8_a_float32": return weight_quant_wint8(w_fp32) - elif algo == "weight_only_int4": + elif algo == "w_channelwise_int4_a_tokenwise_int15": return weight_quant_wint4(w_fp32) else: return None, None -quant_method = "weight_only_int4" +quant_method = "w_channelwise_int4_a_tokenwise_int15" +# quant_method = "w_channelwise_int8_a_float32" print(f"quant_method={quant_method}, used_in_ep_low_latency={used_in_ep_low_latency}") ffn1_quant_w, ffn1_w_scale = weight_quant(ffn1_w, quant_method) ffn2_quant_w, ffn2_w_scale = weight_quant(ffn2_w, quant_method) @@ -127,10 +128,10 @@ def weight_dequant_wint4(w_int, w_scale): return w_fp32 -def weight_dequant(w_int, w_scale, algo="weight_only_int8"): - if algo == "weight_only_int8": +def weight_dequant(w_int, w_scale, algo="w_channelwise_int8_a_float32"): + if algo == "w_channelwise_int8_a_float32": return weight_dequant_wint8(w_int, w_scale) - elif algo == "weight_only_int4": + elif algo == "w_channelwise_int4_a_tokenwise_int15": return weight_dequant_wint4(w_int, w_scale) else: return None, None diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 9170dfb4c8a..e8c7fef782b 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1779,15 +1779,6 @@ def postprocess(self): if not current_platform.is_cuda() and not current_platform.is_maca(): self.graph_opt_config.use_cudagraph = False logger.info("CUDAGraph currently only support on GPU!") - if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: - if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size: - self.parallel_config.use_sequence_parallel_moe = False - logger.info( - "Warning: sequence parallel moe do not support max_num_seqs < tensor_parallel_size when cudagraph enabled. We set use_sequence_parallel_moe to False." - ) - else: - # It will hang when real batch_size < tp_size - self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size) # adjust speculative config if self.speculative_config is not None and self.speculative_config.method == "mtp": @@ -1806,6 +1797,16 @@ def postprocess(self): else: raise NotImplementedError + if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: + if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size: + self.parallel_config.use_sequence_parallel_moe = False + logger.info( + "Warning: sequence parallel moe do not support max_num_seqs < tensor_parallel_size when cudagraph enabled. We set use_sequence_parallel_moe to False." 
+ ) + else: + # It will hang when real batch_size < tp_size + self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size) + self.postprocess_devices_and_ports() def postprocess_devices_and_ports(self): diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 8fb0ebf3881..352ab042042 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -56,8 +56,8 @@ def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) - load_weights_from_cache(model, weights_iterator) else: model.load_weights(weights_iterator) - if fd_config.speculative_config.model_type != "mtp": - process_final_after_loading(model, fd_config) + + process_final_after_loading(model, fd_config) self.clean_memory_fragments() diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 99af90ca583..a418fd11241 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -609,8 +609,7 @@ def load_weights(self, weights_iterator) -> None: r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name ) process_weights_after_loading_fn(model_sublayer_name, param) - - if self.tie_word_embeddings: + if getattr(self, "tie_word_embeddings", False): self.lm_head.linear.weight.set_value( self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) ) diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index db8499444b2..13203684d53 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -16,7 +16,6 @@ from __future__ import annotations -import re from functools import partial from typing import Dict, Union @@ -354,7 +353,6 @@ def __init__(self, fd_config: FDConfig): self.ori_vocab_size = fd_config.model_config.ori_vocab_size self.lm_head = fd_config.speculative_config.sharing_model.lm_head - self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings @classmethod def name(self): @@ -372,11 +370,6 @@ def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]] and values are NumPy arrays or PaddlePaddle tensors. """ self.ernie.load_state_dict(state_dict) - # if self.tie_word_embeddings: - # self.lm_head.linear.weight.set_value( - # self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) - # else: - # self.lm_head.load_state_dict(state_dict) @paddle.no_grad() def load_weights(self, weights_iterator) -> None: @@ -386,45 +379,22 @@ def load_weights(self, weights_iterator) -> None: Args: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. 
""" - - from fastdeploy.model_executor.utils import ( - default_weight_loader, - process_weights_after_loading, + from fastdeploy.model_executor.models.ernie4_5_moe import ( + Ernie4_5_MoeForCausalLM, + ) + from fastdeploy.model_executor.utils import remap_weight_keys + + Ernie4_5_MoeForCausalLM.load_weights( + self, + remap_weight_keys( + weights_iterator, + { + "mtp_emb_norm.0": "enorm", + "mtp_hidden_norm.0": "hnorm", + "mtp_linear_proj.0": "eh_proj.linear", + }, + ), ) - - all_param_mapping = [ - # (param_name, weight_name, expert_id, shard_id) - ("embed_tokens.embeddings", "embed_tokens", None, None), - ("lm_head.linear", "lm_head", None, None), - ("enorm", "mtp_emb_norm.0", None, None), - ("hnorm", "mtp_hidden_norm.0", None, None), - ("eh_proj.linear", "mtp_linear_proj.0", None, None), - ] - - params_dict = dict(self.named_parameters()) - shard_id = None - process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) - for loaded_weight_name, loaded_weight in weights_iterator: - for param_name, weight_name, exp_id, shard_id in all_param_mapping: - if weight_name not in loaded_weight_name: - continue - model_param_name = loaded_weight_name.replace(weight_name, param_name) - param = params_dict[model_param_name] - shard_id = shard_id - break - else: - if loaded_weight_name not in params_dict.keys(): - continue - model_param_name = loaded_weight_name - param = params_dict[loaded_weight_name] - - # Get weight loader from parameter and set weight - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight) - model_sublayer_name = re.sub( - r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name - ) - process_weights_after_loading_fn(model_sublayer_name, param) def compute_logits(self, hidden_states: paddle.Tensor): """ diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 43e0505b077..ac1d3f8b26d 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -735,7 +735,7 @@ def load_weights(self, weights_iterator) -> None: r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name ) process_weights_after_loading_fn(model_sublayer_name, param) - if self.tie_word_embeddings: + if getattr(self, "tie_word_embeddings", False): self.lm_head.linear.weight.set_value( self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) ) diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 3064791bcc5..b1e02aaccaa 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -375,7 +375,7 @@ def load_weights(self, weights_iterator) -> None: weight_loader(param, loaded_weight) model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) process_weights_after_loading_fn(model_sublayer_name, param) - if self.tie_word_embeddings: + if getattr(self, "tie_word_embeddings", False): self.lm_head.linear.weight.set_value( self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) ) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 8b7224eb20a..fe0fa421daa 100644 --- a/fastdeploy/model_executor/utils.py +++ 
b/fastdeploy/model_executor/utils.py @@ -209,6 +209,13 @@ def apply(self, weight_name): return self._map_name(weight_name) +def remap_weight_keys(weights_iterator, mapper: dict): + return ( + (next((key.replace(k, v) for k, v in mapper.items() if k in key), key), value) + for key, value in weights_iterator + ) + + def process_weights_before_loading( *, skip_prefixes: Optional[List[str]] = None, mapper: Optional[WeightsMapper] = None ): diff --git a/scripts/run_ci_metax.sh b/scripts/run_ci_metax.sh index f4331829b6f..8f9aeff2939 100644 --- a/scripts/run_ci_metax.sh +++ b/scripts/run_ci_metax.sh @@ -232,6 +232,8 @@ while IFS=, read -r file exit_code cost_time; do FAIL_COUNT=$((FAIL_COUNT + 1)) FAIL_FILES+=$(basename "$file") echo "$file" >> ${FAIL_FILE_LIST} + echo -e "\n\n+++++++++++++++++++++++++ [ $(basename "$file") ] Fail Info +++++++++++++++++++++++++\n\n" + cat ${LOG_SUBDIR}/$(basename "$file").log fi done < "$LOG_RESULT_TMP"
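
The new remap_weight_keys helper in fastdeploy/model_executor/utils.py is a lazy generator: for each (name, tensor) pair it rewrites the first matching substring from the mapper and passes unmatched names through unchanged, which is what lets Ernie4_5_MTPForCausalLM delegate to Ernie4_5_MoeForCausalLM.load_weights even though its MTP-specific layers are stored under mtp_* checkpoint keys. A minimal self-contained sketch of that behavior, using placeholder key strings (the weight names below are illustrative, not real ERNIE checkpoint keys):

# Copy of the helper added above; the logic is unchanged.
def remap_weight_keys(weights_iterator, mapper: dict):
    return (
        (next((key.replace(k, v) for k, v in mapper.items() if k in key), key), value)
        for key, value in weights_iterator
    )

# Placeholder (name, weight) pairs; a real iterator yields tensors, not floats.
fake_weights = [
    ("ernie.mtp_emb_norm.0.weight", 0.1),
    ("ernie.mtp_linear_proj.0.weight", 0.2),
    ("ernie.layers.0.self_attn.qkv_proj.weight", 0.3),  # no mapper key matches -> passed through
]
mtp_mapping = {
    "mtp_emb_norm.0": "enorm",
    "mtp_hidden_norm.0": "hnorm",
    "mtp_linear_proj.0": "eh_proj.linear",
}
print(list(remap_weight_keys(fake_weights, mtp_mapping)))
# [('ernie.enorm.weight', 0.1), ('ernie.eh_proj.linear.weight', 0.2),
#  ('ernie.layers.0.self_attn.qkv_proj.weight', 0.3)]

Because the helper returns a generator expression rather than a list, the remapping stays streaming: weights are still consumed one at a time and never materialized in full.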
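
On the test_moe_expert_ffn.py hunk: the renamed quant_method strings appear to encode both the weight and the activation scheme in one identifier (w_channelwise_int8_a_float32 = channelwise int8 weights with float32 activations; w_channelwise_int4_a_tokenwise_int15 = channelwise int4 weights with tokenwise int15 activations). The body of weight_quant_wint8 is outside this hunk, so the following is only a hedged reference for what channelwise int8 weight quantization typically means, not a description of the XPU op or of the test helper itself:

import numpy as np

def channelwise_int8_quant(w_fp32: np.ndarray):
    # One scale per output channel (last axis); the real test helper may use a
    # different layout and, like weight_quant_wint4 above, return the raw absmax.
    w_absmax = np.abs(w_fp32).max(axis=0, keepdims=True)
    scale = np.maximum(w_absmax, 1e-8) / 127.0
    w_int8 = np.clip(np.round(w_fp32 / scale), -127, 127).astype(np.int8)
    return w_int8, scale.reshape([-1])

def channelwise_int8_dequant(w_int8: np.ndarray, scale: np.ndarray):
    # Inverse of the sketch above; self-consistent, but not necessarily
    # bit-identical to weight_dequant_wint8 in the test.
    return w_int8.astype(np.float32) * scale[None, :]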