From 61d0f3e667ce14ee6c6369bca8d69353b60f6971 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Sun, 16 Nov 2025 13:09:21 +0000 Subject: [PATCH 01/13] [OpenVINO] Support Qwen3-next --- optimum/exporters/openvino/model_configs.py | 106 ++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index dea24d26a9..cc8aca4eab 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -4383,3 +4383,109 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs + + +class Qwen3NextDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + """ + Generates dummy cache_params inputs for Zamba2 architectures. + """ + + SUPPORTED_INPUT_NAMES = ("cache_params",) + + def __init__( + self, + task: str, + normalized_config, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + **kwargs, + ) + + config = normalized_config.config + self.num_full_attn_layers = config.layer_types.count("full_attention") + self.num_linear_attn_layers = config.layer_types.count("linear_attention") + self.conv_kernel_size = config.linear_conv_kernel_dim + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.num_key_value_heads = config.num_key_value_heads + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + cache_params = [] + + for idx in range(self.num_linear_attn_layers): + # (batch_size, d_inner, d_conv) + d_inner = self.num_k_heads * (2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads) + conv_state_shape = ( + self.batch_size, + d_inner, + self.conv_kernel_size, + ) + conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype) + cache_params.append(conv_state) + recurrent_state_shape = (self.batch_size, self.num_key_value_heads, self.head_k_dim, self.head_v_dim) + recurrent_state = self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype) + cache_params.append(recurrent_state) + + for idx in range(self.num_full_attn_layers): + kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim) + k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) + v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) + cache_params.append(k) + cache_params.append(v) + + return recurrent_state + + +@register_in_tasks_manager( + "qwen3_next", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3NextOpenVINOConfig(Qwen3OpenVINOConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3NextDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = Qwen3NextDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + MIN_TRANSFORMERS_VERSION = 
"4.57.0" + _MODEL_PATCHER = Zamba2ModelPatcher + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + cache_name_prefix = "cache_params.past" + else: + decoder_sequence_name = "past_sequence_length + sequence_length" + cache_name_prefix = "cache_params.present" + + for i in range(self._normalized_config.num_layers): + # [batch_size, conv_kernel_size - 1, d_model] + inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"} + # [batch_size, d_state, d_model] + inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"} + + for i in range(len(self._normalized_config.hybrid_layer_ids)): + inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name} + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + return common_inputs From ea6b4b3f0ce0a892a4595e9e18b85af4bbee51ab Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Sun, 16 Nov 2025 20:58:42 +0400 Subject: [PATCH 02/13] Fix config and add base patching --- optimum/exporters/openvino/model_configs.py | 41 ++++- optimum/exporters/openvino/model_patcher.py | 157 ++++++++++++++++++++ 2 files changed, 194 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index cc8aca4eab..3d1cf1813e 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -140,6 +140,7 @@ SanaTextEncoderModelPatcher, XverseModelPatcher, Zamba2ModelPatcher, + Qwen3NextModelPatcher, ) @@ -4444,7 +4445,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int cache_params.append(k) cache_params.append(v) - return recurrent_state + return cache_params @register_in_tasks_manager( @@ -4457,7 +4458,7 @@ class Qwen3NextOpenVINOConfig(Qwen3OpenVINOConfig): DUMMY_PKV_GENERATOR_CLASS = Qwen3NextDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig MIN_TRANSFORMERS_VERSION = "4.57.0" - _MODEL_PATCHER = Zamba2ModelPatcher + _MODEL_PATCHER = Qwen3NextModelPatcher def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): if direction not in ["inputs", "outputs"]: @@ -4470,13 +4471,16 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire decoder_sequence_name = "past_sequence_length + sequence_length" cache_name_prefix = "cache_params.present" - for i in range(self._normalized_config.num_layers): + self.num_full_attn_layers = self._normalized_config.layer_types.count("full_attention") + self.num_linear_attn_layers = self._normalized_config.layer_types.count("linear_attention") + + for i in range(self.num_linear_attn_layers): # [batch_size, conv_kernel_size - 1, d_model] inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"} # [batch_size, d_state, d_model] inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"} - for i in 
range(len(self._normalized_config.hybrid_layer_ids)): + for i in range(self.num_full_attn_layers): inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name} inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name} @@ -4489,3 +4493,32 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + # need to override `generate_dummy_inputs` since mamba model has other states: ssm_states and conv_states + # which we separate and call them as past_ssm_states and past_conv_states + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")] + if self.use_past_in_inputs: + input_names.extend(["cache_params"]) + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + return dummy_inputs diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 67efa441a0..7ba6989e4f 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -6954,3 +6954,160 @@ def __exit__(self, exc_type, exc_value, traceback): else: continue mamba_layer.forward = mamba_layer._orig_forward + + +class Qwen3NextModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Optional[Dict[str, Any]] = None, + ): + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache + + super().__init__(config, model, model_kwargs) + + class Qwen3NextDynamicCacheWrap(Qwen3NextDynamicCache): + def __init__(self, config, conv_states, recurrent_states, key_cache, value_cache): + # Call parent constructor with all required arguments + super().__init__(config=config) + + self.conv_states = conv_states + self.recurrent_states = recurrent_states + self.key_cache = key_cache + self.value_cache = value_cache + self.full_attn_mapping = {} + self.linear_attn_mapping = {} + full_attn_layer_idx = 0 + linear_attn_layer_idx = 0 + for i in range(len(config.layer_types)): + if self.layer_types[i] == "full_attention": + self.full_attn_mapping[i] = full_attn_layer_idx + full_attn_layer_idx += 1 + elif self.layer_types[i] == "linear_attention": + self.linear_attn_mapping[i] = linear_attn_layer_idx + linear_attn_layer_idx += 1 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # map layer_idx to key_cache (value_cache) idx + layer_idx = self.full_attn_mapping[layer_idx] + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + 
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + layer_idx = self.full_attn_mapping[layer_idx] + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + layer_idx = self.linear_attn_mapping[self.last_linear_layer] + return self.conv_states[layer_idx] is not None + + # the patch is needed to include KV-cache, Conv, and SSM states in the inputs and outputs. + def patched_forward( + input_ids, + attention_mask=None, + cache_params=None, + ): + num_full_attn_layers = self.real_config._config.layer_types.count("full_attention") + num_linear_attn_layers = self.real_config._config.layer_types.count("linear_attention") + + use_cache = False + wrapped_cache_params = None + if cache_params is not None: + use_cache = True + conv_states = [] + recurrent_states = [] + key_cache = [] + value_cache = [] + + # decouple ssm_states, conv_states, keys and values from cache_params + for idx in range(num_linear_attn_layers): + conv_states.append(cache_params[2 * idx]) + recurrent_states.append(cache_params[2 * idx + 1]) + + for idx in range(num_full_attn_layers): + key_cache.append(cache_params[2 * num_linear_attn_layers + 2 * idx]) + value_cache.append(cache_params[2 * num_linear_attn_layers + 2 * idx + 1]) + + wrapped_cache_params = Qwen3NextDynamicCacheWrap( + self.real_config._config, conv_states, recurrent_states, key_cache, value_cache + ) + + causal_lm_output = self.model_orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=wrapped_cache_params, + use_cache=use_cache, + ) + outputs = { + "logits": causal_lm_output.logits, + } + + if use_cache: + past_key_values = causal_lm_output.past_key_values + # unwrap Zamba2HybridDynamicCache object + present_key_values = [] + for idx in range(num_linear_attn_layers): + present_key_values.append(past_key_values.conv_states[idx]) + present_key_values.append(past_key_values.recurrent_states[idx]) + + for idx in range(num_full_attn_layers): + present_key_values.append(past_key_values.key_cache[idx]) + present_key_values.append(past_key_values.value_cache[idx]) + + outputs["present_key_values"] = present_key_values + + return outputs + + self.patched_forward = patched_forward + self.model_orig_forward = self.orig_forward + self.orig_forward = patched_forward + + def __enter__(self): + from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridLayer, Zamba2MambaDecoderLayer + + super().__enter__() + setattr(self._model, self.orig_forward_name, self.patched_forward) + + #for layer in self._model.model.layers: + # if isinstance(layer, Zamba2MambaDecoderLayer): + # mamba_layer = layer.mamba + # elif isinstance(layer, Zamba2HybridLayer): + # mamba_layer = layer.mamba_decoder.mamba + # else: + # continue + # mamba_layer._orig_forward = mamba_layer.forward + # mamba_layer.forward = types.MethodType(zamba2_mamba_mixer, mamba_layer) + + def __exit__(self, exc_type, exc_value, traceback): + from 
transformers.models.zamba2.modeling_zamba2 import Zamba2HybridLayer, Zamba2MambaDecoderLayer + + super().__exit__(exc_type, exc_value, traceback) + setattr(self._model, self.orig_forward_name, self.model_orig_forward) + #for layer in self._model.model.layers: + # if isinstance(layer, Zamba2MambaDecoderLayer): + # mamba_layer = layer.mamba + # elif isinstance(layer, Zamba2HybridLayer): + # mamba_layer = layer.mamba_decoder.mamba + # else: + # continue + # mamba_layer.forward = mamba_layer._orig_forward From 7e37aae5ad79c504ee66133945ce2dcf8587a36f Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 17 Nov 2025 15:21:14 +0400 Subject: [PATCH 03/13] Extend patching --- optimum/exporters/openvino/model_patcher.py | 171 +++++++++++++++++--- optimum/exporters/openvino/utils.py | 2 +- 2 files changed, 151 insertions(+), 22 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 7ba6989e4f..89592dc582 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -6956,6 +6956,141 @@ def __exit__(self, exc_type, exc_value, traceback): mamba_layer.forward = mamba_layer._orig_forward +def qwen3_next_gated_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + # distinguish prefill and decoding stage + dtype = hidden_states.dtype + # use_precomputed_states - is_decoding flag + use_precomputed_states = torch.tensor(seq_len == 1).to(dtype) + + # getting projected states from cache if it exists + layer_idx = None + if cache_params is not None: + layer_idx = cache_params.linear_attn_mapping[self.layer_idx] + conv_state = cache_params.conv_states[layer_idx] + recurrent_state = cache_params.recurrent_states[layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if cache_params is not None: + # 2. 
Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + conv_state_dec = conv_state.clone() + mixed_qkv_dec = mixed_qkv.clone() + mixed_qkv_dec = self.causal_conv1d_update( + mixed_qkv_dec, + conv_state_dec, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + conv_state_prefill = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + conv_state = conv_state_dec * use_precomputed_states + conv_state_prefill * (1.0 - use_precomputed_states) + mixed_qkv = mixed_qkv_dec * use_precomputed_states + mixed_qkv * (1.0 - use_precomputed_states) + cache_params.conv_states[layer_idx] = conv_state + + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + core_attn_out_prefill, last_recurrent_state_prefill = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out_dec, last_recurrent_state_dec = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out = core_attn_out_dec * use_precomputed_states + core_attn_out_prefill * (1.0 - use_precomputed_states) + last_recurrent_state = last_recurrent_state_dec * use_precomputed_states + last_recurrent_state_prefill * (1.0 - use_precomputed_states) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[layer_idx] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + output = self.out_proj(core_attn_out) + return output + + class Qwen3NextModelPatcher(ModelPatcher): def __init__( self, @@ -7083,31 +7218,25 @@ def patched_forward( self.orig_forward = patched_forward def __enter__(self): - from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridLayer, Zamba2MambaDecoderLayer - super().__enter__() setattr(self._model, self.orig_forward_name, self.patched_forward) - #for layer in self._model.model.layers: - # if isinstance(layer, Zamba2MambaDecoderLayer): - # mamba_layer = layer.mamba - # elif isinstance(layer, Zamba2HybridLayer): - # mamba_layer = 
layer.mamba_decoder.mamba - # else: - # continue - # mamba_layer._orig_forward = mamba_layer.forward - # mamba_layer.forward = types.MethodType(zamba2_mamba_mixer, mamba_layer) + for idx, decoder_layer in enumerate(self._model.model.layers): + layer_type = self._model.model.config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + else: + continue + linear_attn_layer._orig_forward = linear_attn_layer.forward + linear_attn_layer.forward = types.MethodType(qwen3_next_gated_delta_net_forward, linear_attn_layer) def __exit__(self, exc_type, exc_value, traceback): - from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridLayer, Zamba2MambaDecoderLayer - super().__exit__(exc_type, exc_value, traceback) setattr(self._model, self.orig_forward_name, self.model_orig_forward) - #for layer in self._model.model.layers: - # if isinstance(layer, Zamba2MambaDecoderLayer): - # mamba_layer = layer.mamba - # elif isinstance(layer, Zamba2HybridLayer): - # mamba_layer = layer.mamba_decoder.mamba - # else: - # continue - # mamba_layer.forward = mamba_layer._orig_forward + for idx, decoder_layer in enumerate(self._model.model.layers): + layer_type = self._model.model.config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + else: + continue + linear_attn_layer.forward = linear_attn_layer._orig_forward diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 80228bbca7..319bf9a2c6 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -237,7 +237,7 @@ def get_submodels(model): "minicpmo", ] -SSM_MODELS = ["mamba", "falcon_mamba", "zamba2"] +SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "qwen3_next"] def save_config(config, save_dir): From 8bc1c5a11501c34b69cffdd4503deaf8441a9a0a Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 18 Nov 2025 09:20:32 +0400 Subject: [PATCH 04/13] Initial patching for linear attention --- optimum/exporters/openvino/model_patcher.py | 149 ++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 89592dc582..56e8e2970f 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -6956,6 +6956,150 @@ def __exit__(self, exc_type, exc_value, traceback): mamba_layer.forward = mamba_layer._orig_forward +def patched_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, 
pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + + N = attn.size(-1) + i = torch.arange(N, device=attn.device).view(-1, 1) + j = torch.arange(N, device=attn.device).view(1, -1) + lower_mask = j < i # True where column < row + row = attn * lower_mask # broadcast over batch dims + idx = torch.arange(N, device=attn.device) + i_idx = idx.view(N, 1, 1) + j_idx = idx.view(1, N, 1) + k_idx = idx.view(1, 1, N) + mask = (k_idx < i_idx) & (k_idx < j_idx) + prod = row.unsqueeze(-1) * attn.unsqueeze(-2) + prod = prod * mask + upd = prod.sum(-1) + attn = torch.where(lower_mask, row + upd, attn) + + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def patched_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + 
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + def qwen3_next_gated_delta_net_forward( self, hidden_states: torch.Tensor, @@ -7229,6 +7373,9 @@ def __enter__(self): continue linear_attn_layer._orig_forward = linear_attn_layer.forward linear_attn_layer.forward = types.MethodType(qwen3_next_gated_delta_net_forward, linear_attn_layer) + linear_attn_layer._orig_chunk_gated_delta_rule = linear_attn_layer.chunk_gated_delta_rule + linear_attn_layer.chunk_gated_delta_rule = patched_chunk_gated_delta_rule + linear_attn_layer._orig_recurrent_gated_delta_rule = linear_attn_layer.recurrent_gated_delta_rule def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) @@ -7240,3 +7387,5 @@ def __exit__(self, exc_type, exc_value, traceback): else: continue linear_attn_layer.forward = linear_attn_layer._orig_forward + linear_attn_layer.chunk_gated_delta_rule = linear_attn_layer._orig_chunk_gated_delta_rule + linear_attn_layer.recurrent_gated_delta_rule = linear_attn_layer._orig_recurrent_gated_delta_rule From 26a4b658e2e59ce4096e3847198435738f262f87 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 18 Nov 2025 21:14:24 +0400 Subject: [PATCH 05/13] Patch recurrent gated delta rule --- optimum/exporters/openvino/model_patcher.py | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 56e8e2970f..ce9fb7beee 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7081,18 +7081,17 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): else initial_state.to(value) ) - for i in range(sequence_length): - q_t = query[:, :, i] - k_t = key[:, :, i] - v_t = value[:, :, i] - g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, i].unsqueeze(-1) - - last_recurrent_state = last_recurrent_state * g_t - kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + q_t = query[:, :, 0] + k_t = key[:, :, 0] + v_t = value[:, :, 0] + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, 0].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + 
core_attn_out[:, :, 0] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) if not output_final_state: last_recurrent_state = None @@ -7376,6 +7375,7 @@ def __enter__(self): linear_attn_layer._orig_chunk_gated_delta_rule = linear_attn_layer.chunk_gated_delta_rule linear_attn_layer.chunk_gated_delta_rule = patched_chunk_gated_delta_rule linear_attn_layer._orig_recurrent_gated_delta_rule = linear_attn_layer.recurrent_gated_delta_rule + linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) From a0e8d3c0f31323ac9788a7b432a04d32eca83c6a Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Thu, 20 Nov 2025 21:07:59 +0400 Subject: [PATCH 06/13] Use module extension for conversion of chunked_attention_cell --- optimum/exporters/openvino/convert.py | 5 + optimum/exporters/openvino/model_patcher.py | 232 +++++++++++++++++++- 2 files changed, 225 insertions(+), 12 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 2bb53334b7..89f3b8449a 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -415,11 +415,16 @@ def ts_patched_forward(*args, **kwargs): __make_16bit_traceable(model) check_dummy_inputs_are_allowed(model, dummy_inputs) input_info = _get_input_info(model, config, dummy_inputs) + conversion_extensions = getattr(patcher, "conversion_extensions", []) + module_extensions = getattr(patcher, "module_extensions", None) + if module_extensions is not None: + ts_decoder_kwargs["module_extensions"] = module_extensions ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs) ov_model = convert_model( ts_decoder, example_input=dummy_inputs, input=[(item.shape, item.type) for item in input_info], + extension=conversion_extensions, ) ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation? 
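
The convert.py hunk above threads two new patcher attributes into the PyTorch frontend: `module_extensions` (handed to `TorchScriptPythonDecoder` so that `ChunkedAttentionCell` is kept as a single opaque framework op while tracing) and `conversion_extensions` (handed to `convert_model` so that this op is later lowered to OpenVINO opset nodes). Below is a minimal, self-contained sketch of that mechanism; it assumes a recent OpenVINO release in which `ov.convert_model` accepts `ModuleExtension`/`ConversionExtension` through its `extension` argument, and the `ScaleShift`/`"ScaleShiftOp"` names are purely illustrative, not part of this PR.

import numpy as np
import torch
import openvino as ov
import openvino.opset14 as ops
from openvino.frontend.pytorch import ModuleExtension, ConversionExtension


class ScaleShift(torch.nn.Module):
    # Stand-in for ChunkedAttentionCell: a submodule we want to keep as one
    # custom op ("ScaleShiftOp") instead of tracing its body.
    def __init__(self, scale: float = 2.0, shift: float = 1.0):
        super().__init__()
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        return x * self.scale + self.shift


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.cell = ScaleShift()

    def forward(self, x):
        return self.cell(self.proj(x))


def convert_scale_shift(context):
    # ConversionExtension callback: receives a NodeContext for every
    # "ScaleShiftOp" node in the traced graph and returns its OpenVINO outputs.
    x = context.get_input(0)
    scale = ops.constant(2.0, dtype=np.float32)  # mirrors the module defaults;
    shift = ops.constant(1.0, dtype=np.float32)  # a real converter would capture them
    out = ops.add(ops.multiply(x, scale), shift)
    return [out.output(0)]


ov_model = ov.convert_model(
    TinyModel(),
    example_input=torch.randn(1, 4),
    extension=[
        ModuleExtension(ScaleShift, "ScaleShiftOp"),               # keep the module opaque during tracing
        ConversionExtension("ScaleShiftOp", convert_scale_shift),  # lower the opaque op to opset nodes
    ],
)

This is the same pattern the later patches use for `ChunkedAttentionCellOp`, except that there the converter builds an `ops.loop` body over sequence chunks instead of a couple of elementwise nodes.
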
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index ce9fb7beee..81d2884337 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -6957,6 +6957,7 @@ def __exit__(self, exc_type, exc_value, traceback): def patched_chunk_gated_delta_rule( + self, query, key, value, @@ -7029,21 +7030,33 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): if initial_state is None else initial_state.to(value) ) - core_attn_out = torch.zeros_like(value) + #core_attn_out = torch.zeros_like(value) mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) # for each chunk - for i in range(0, total_sequence_length // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() - + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new - ) + #for i in range(0, total_sequence_length // chunk_size): + # q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + # attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + # v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + # v_new = v_i - v_prime + # attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + # core_attn_out[:, :, i] = attn_inter + attn @ v_new + # last_recurrent_state = ( + # last_recurrent_state * g[:, :, i, -1, None, None].exp() + # + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + # ) + num_chunks = total_sequence_length // chunk_size + core_attn_out, last_recurrent_state = self.chunked_attention_cell( + query, + key, + value, + decay_mask, + mask, + k_cumdecay, + g, + last_recurrent_state, + num_chunks + ) if not output_final_state: last_recurrent_state = None @@ -7194,6 +7207,7 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) core_attn_out_prefill, last_recurrent_state_prefill = self.chunk_gated_delta_rule( + self, query, key, value, @@ -7234,6 +7248,190 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return output +class ChunkedAttentionCell(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + query, # (B, H, T, D) + key, # (B, H, T, D) + value, # (B, H, T, D) + decay_mask, # (B, H, T) + mask, # (B, H, D, D) or broadcastable + k_cumdecay, # (B, H, T, D, D) + g, # (B, H, T, G) + last_recurrent_state, # (B, H, D, D) + num_chunks # int + ): + core_attn_out = torch.zeros_like(value) + + # Loop over chunks using total_sequence_length instead of T + for i in range(0, num_chunks): + q_i = query[:, :, i] # (B, H, D) + k_i = key[:, :, i] + v_i = value[:, :, i] + dec_i = decay_mask[:, :, i] + + # attention + attn = (q_i @ k_i.transpose(-1, -2)) * dec_i + attn = attn.masked_fill(mask, 0) + + # recurrent decay contribution + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + # intermediate attention term + g_i = g[:, :, i] # (B, H, G) + attn_inter = (q_i * g_i[..., None].exp()) @ last_recurrent_state + + # 
final output + core_attn_out[:, :, i] = attn_inter + attn @ v_new + + # update recurrent state + g_last = g[:, :, i, -1] # (B, H) + decay_factor = g_last[..., None, None].exp() + + g_diff = (g_last[..., None] - g_i).exp() # (B, H, G) + update_term = (k_i * g_diff[..., None]).transpose(-1, -2) @ v_new + + last_recurrent_state = ( + last_recurrent_state * decay_factor + update_term + ) + + return core_attn_out, last_recurrent_state + + +def convert_chunked_attention_cell(context): + import openvino.opset14 as ops + + # context.get_input(0) + query = context.get_input(0) + key = context.get_input(1) + value = context.get_input(2) + decay_mask = context.get_input(3) + mask = context.get_input(4) + k_cumdecay = context.get_input(5) + g = context.get_input(6) + last_recurrent_state = context.get_input(7) + num_chunks = context.get_input(8) + + # ------------------------------------------------------------ + # Create Loop node + # ------------------------------------------------------------ + num_chunks = ops.convert(num_chunks, "i32") + loop = ops.loop( + num_chunks, + ops.constant(True, dtype="bool") + ) + + # ============================================================ + # 1) CREATE LOOP BODY + # ============================================================ + + body = loop.body + + # Loop body has two implicit parameters: + # body.get_iter_value(): iteration counter i + # body.get_cond_value(): loop condition + iter_i = body.iter_value + cond_in = body.cond_value + + # ------------------------------------------------------------ + # Create slicing for each input (slice index = iter_i) + # ------------------------------------------------------------ + + # shape info + # query: (B,H,T,D) + # slice at dim 2 => one vector per iteration + q_i = ops.gather_nd(query, ops.stack([ops.full_like(iter_i, 0), # batch dim idx = ':' + ops.full_like(iter_i, 0), # head dim idx = ':' + iter_i], axis=0), batch_dims=0) + + k_i = ops.gather_nd(key, ops.stack([ops.full_like(iter_i, 0), + ops.full_like(iter_i, 0), + iter_i], axis=0), batch_dims=0) + + v_i = ops.gather_nd(value, ops.stack([ops.full_like(iter_i, 0), + ops.full_like(iter_i, 0), + iter_i], axis=0), batch_dims=0) + + dec_i = ops.gather_nd(decay_mask, ops.stack([ops.full_like(iter_i, 0), + ops.full_like(iter_i, 0), + iter_i], axis=0), batch_dims=0) + + kcum_i = ops.gather_nd(k_cumdecay, ops.stack([ops.full_like(iter_i, 0), + ops.full_like(iter_i, 0), + iter_i], axis=0), batch_dims=0) + + g_i = ops.gather_nd(g, ops.stack([ops.full_like(iter_i, 0), + ops.full_like(iter_i, 0), + iter_i], axis=0), batch_dims=0) + + # Get last recurrent state (loop-carried variable) + last_state_in = loop.add_loop_carried_state(last_recurrent_state) + + # ============================================================ + # 2) BODY COMPUTATION (one iteration) + # ============================================================ + + # ---- attn = (q_i @ k_i.T) * dec_i ---- + k_i_T = ops.transpose(k_i, [0, 1, 3, 2]) + attn = ops.matmul(q_i, k_i_T) + attn = ops.multiply(attn, dec_i) + + # masking + attn = ops.select(mask, + ops.constant(0, attn.get_element_type()), + attn) + + # ---- v_prime = k_cumdecay @ last_state ---- + v_prime = ops.matmul(kcum_i, last_state_in) + + v_new = ops.subtract(v_i, v_prime) + + # ---- attn_inter = (q_i * exp(g_i[...,None])) @ last_state ---- + g_expand = ops.unsqueeze(g_i, -1) + g_exp = ops.exp(g_expand) + q_scaled = ops.multiply(q_i, g_exp) + attn_inter = ops.matmul(q_scaled, last_state_in) + + # ---- final output for this iteration ---- + final_attn = 
ops.add(attn_inter, ops.matmul(attn, v_new)) + + # output accumulation (via Loop's "concat outputs") + loop.add_concat_output(final_attn) + + # ---- update last recurrent state ---- + g_last = ops.gather(g_i, ops.constant([g_i.get_shape()[3] - 1], dtype="int64"), axis=3) + + # expand dims + g_last_e = ops.unsqueeze(ops.unsqueeze(g_last, -1), -1) + decay_factor = ops.exp(g_last_e) + + g_diff = ops.exp(ops.subtract(ops.unsqueeze(g_last, -1), g_i)) + update_term = ops.matmul( + ops.transpose(ops.multiply(k_i, ops.unsqueeze(g_diff, -1)), [0,1,3,2]), + v_new + ) + + next_state = ops.add(ops.multiply(last_state_in, decay_factor), update_term) + + # register updated state + loop.set_loop_carried_state(last_state_in, next_state) + + # condition for next iteration stays True + body.set_conditions(ops.constant(True, dtype="bool")) + + # ============================================================ + # 3) BUILD LOOP OUTPUTS + # ============================================================ + concat_attn_out = loop.get_iter_value(0) # concatenated final_attn + final_state_out = loop.get_iter_value(1) # final recurrent state + + return concat_attn_out, final_state_out + + class Qwen3NextModelPatcher(ModelPatcher): def __init__( self, @@ -7242,6 +7440,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache + from openvino.frontend.pytorch import ModuleExtension, ConversionExtension super().__init__(config, model, model_kwargs) @@ -7360,10 +7559,17 @@ def patched_forward( self.model_orig_forward = self.orig_forward self.orig_forward = patched_forward + self.module_extensions = { + ChunkedAttentionCell: ModuleExtension(ChunkedAttentionCell, "ChunkedAttentionCellOp"), + } + self.conversion_extensions = [ConversionExtension("ChunkedAttentionCellOp", convert_chunked_attention_cell)] + def __enter__(self): super().__enter__() setattr(self._model, self.orig_forward_name, self.patched_forward) + chunked_attention_cell_list = [] + for idx, decoder_layer in enumerate(self._model.model.layers): layer_type = self._model.model.config.layer_types[idx] if layer_type == "linear_attention": @@ -7376,6 +7582,8 @@ def __enter__(self): linear_attn_layer.chunk_gated_delta_rule = patched_chunk_gated_delta_rule linear_attn_layer._orig_recurrent_gated_delta_rule = linear_attn_layer.recurrent_gated_delta_rule linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule + linear_attn_layer.chunked_attention_cell = ChunkedAttentionCell() + chunked_attention_cell_list.append(linear_attn_layer.chunked_attention_cell) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) From 486a4f8d3d17516492b0a583a2172bac3dc0a459 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Sun, 23 Nov 2025 17:33:06 +0400 Subject: [PATCH 07/13] Implement conversion extension for chunked gated delta rule cell --- optimum/exporters/openvino/model_patcher.py | 209 ++++++++++---------- 1 file changed, 102 insertions(+), 107 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 81d2884337..9b67c45ca9 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7046,6 +7046,7 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): # + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new # ) num_chunks = total_sequence_length // 
chunk_size + core_attn_out, last_recurrent_state = self.chunked_attention_cell( query, key, @@ -7304,6 +7305,8 @@ def forward( def convert_chunked_attention_cell(context): import openvino.opset14 as ops + import openvino as ov + import numpy as np # context.get_input(0) query = context.get_input(0) @@ -7314,122 +7317,114 @@ def convert_chunked_attention_cell(context): k_cumdecay = context.get_input(5) g = context.get_input(6) last_recurrent_state = context.get_input(7) - num_chunks = context.get_input(8) - - # ------------------------------------------------------------ - # Create Loop node - # ------------------------------------------------------------ - num_chunks = ops.convert(num_chunks, "i32") - loop = ops.loop( - num_chunks, - ops.constant(True, dtype="bool") - ) - - # ============================================================ - # 1) CREATE LOOP BODY - # ============================================================ - - body = loop.body - - # Loop body has two implicit parameters: - # body.get_iter_value(): iteration counter i - # body.get_cond_value(): loop condition - iter_i = body.iter_value - cond_in = body.cond_value - - # ------------------------------------------------------------ - # Create slicing for each input (slice index = iter_i) - # ------------------------------------------------------------ - - # shape info - # query: (B,H,T,D) - # slice at dim 2 => one vector per iteration - q_i = ops.gather_nd(query, ops.stack([ops.full_like(iter_i, 0), # batch dim idx = ':' - ops.full_like(iter_i, 0), # head dim idx = ':' - iter_i], axis=0), batch_dims=0) - - k_i = ops.gather_nd(key, ops.stack([ops.full_like(iter_i, 0), - ops.full_like(iter_i, 0), - iter_i], axis=0), batch_dims=0) - - v_i = ops.gather_nd(value, ops.stack([ops.full_like(iter_i, 0), - ops.full_like(iter_i, 0), - iter_i], axis=0), batch_dims=0) - - dec_i = ops.gather_nd(decay_mask, ops.stack([ops.full_like(iter_i, 0), - ops.full_like(iter_i, 0), - iter_i], axis=0), batch_dims=0) - - kcum_i = ops.gather_nd(k_cumdecay, ops.stack([ops.full_like(iter_i, 0), - ops.full_like(iter_i, 0), - iter_i], axis=0), batch_dims=0) - - g_i = ops.gather_nd(g, ops.stack([ops.full_like(iter_i, 0), - ops.full_like(iter_i, 0), - iter_i], axis=0), batch_dims=0) - - # Get last recurrent state (loop-carried variable) - last_state_in = loop.add_loop_carried_state(last_recurrent_state) - - # ============================================================ - # 2) BODY COMPUTATION (one iteration) - # ============================================================ - - # ---- attn = (q_i @ k_i.T) * dec_i ---- - k_i_T = ops.transpose(k_i, [0, 1, 3, 2]) - attn = ops.matmul(q_i, k_i_T) - attn = ops.multiply(attn, dec_i) - - # masking - attn = ops.select(mask, - ops.constant(0, attn.get_element_type()), - attn) - - # ---- v_prime = k_cumdecay @ last_state ---- - v_prime = ops.matmul(kcum_i, last_state_in) + num_chunks_param = context.get_input(8) + # context.get_input(0) + #query = ops.parameter([-1, -1, -1, 64, -1], np.float32, "query") + #key = ops.parameter([-1, -1, -1, 64, -1], np.float32, "key") + #value = ops.parameter([-1, -1, -1, 64, -1], np.float32, "value") + #decay_mask = ops.parameter([-1, -1, -1, 64, -1], np.float32, "decay_mask") + #mask = ops.parameter([64, 64], bool, "mask") + #k_cumdecay = ops.parameter([-1, -1, -1, 64, -1], np.float32, "k_cumdecay") + #g = ops.parameter([-1, -1, -1, 64], np.float32, "g") + #last_recurrent_state = ops.parameter([-1, -1, 8, 8], np.float32, "last_recurrent_state") + #num_chunks_param = ops.parameter([], 
np.int64, "num_chunks") + + value_shape = ops.shape_of(value) + const_zero = ops.constant(0, dtype=np.float32) + core_attn_out = ops.broadcast(const_zero, value_shape) + + const_one_float = ops.constant(1, dtype=np.float32) + const_zero_float = ops.constant(0, dtype=np.float32) + mask_float = ops.select(mask, const_one_float, const_zero_float) + + # create a body graph + # q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + # attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + # v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + # v_new = v_i - v_prime + # attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + # core_attn_out[:, :, i] = attn_inter + attn @ v_new + # last_recurrent_state = ( + # last_recurrent_state * g[:, :, i, -1, None, None].exp() + # + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + # ) + timestep = ops.parameter([], np.int32, "timestep") + q_i_param = ops.parameter([-1, -1, 1, -1, -1], np.float32, "q_i") + k_i_param = ops.parameter([-1, -1, 1, -1, -1], np.float32, "k_i") + v_i_param = ops.parameter([-1, -1, 1, -1, -1], np.float32, "v_i") + decay_mask_i_param = ops.parameter([-1, -1, 1, -1, -1], np.float32, "decay_mask_i") + mask_i = ops.parameter([-1, -1], np.float32, "mask_i") + k_cumdecay_i_param = ops.parameter([-1, -1, 1, -1, -1], np.float32, "k_cumdecay_i") + last_recurrent_state_i = ops.parameter([-1, -1, -1, -1], np.float32, "last_recurrent_state_i") + g_i_param = ops.parameter([-1, -1, 1, -1], np.float32, "g_i") + core_attn_out_i = ops.parameter([-1, -1, -1, -1, -1], np.float32, "core_attn_out_i") + + const_two = ops.constant(2, dtype=np.int32) + q_i = ops.squeeze(q_i_param, const_two) + k_i = ops.squeeze(k_i_param, const_two) + v_i = ops.squeeze(v_i_param, const_two) + decay_mask_i = ops.squeeze(decay_mask_i_param, const_two) + k_cumdecay_i = ops.squeeze(k_cumdecay_i_param, const_two) + g_i = ops.squeeze(g_i_param, const_two) + + attn = ops.einsum([q_i, k_i], "bhwd,bhld->bhwl") + attn = ops.multiply(attn, decay_mask_i) + attn = ops.multiply(attn, mask_i) + + v_prime = ops.einsum([k_cumdecay_i, last_recurrent_state_i], "bhwd,bhdl->bhwl") v_new = ops.subtract(v_i, v_prime) - # ---- attn_inter = (q_i * exp(g_i[...,None])) @ last_state ---- - g_expand = ops.unsqueeze(g_i, -1) - g_exp = ops.exp(g_expand) - q_scaled = ops.multiply(q_i, g_exp) - attn_inter = ops.matmul(q_scaled, last_state_in) + const_three = ops.constant(3, dtype=np.int32) + attn_inter = ops.einsum([ops.multiply(q_i, ops.exp(ops.unsqueeze(g_i, const_three))), last_recurrent_state_i], + "bhwd,bhdl->bhwl") - # ---- final output for this iteration ---- - final_attn = ops.add(attn_inter, ops.matmul(attn, v_new)) + update_core_attn = ops.add(attn_inter, ops.einsum([attn, v_new], "bhwd,bhdl->bhwl")) + update_core_attn = ops.unsqueeze(update_core_attn, const_two) + const_zero_int = ops.constant(0, dtype=np.int32) + timestep_unsq = ops.unsqueeze(timestep, const_zero_int) + core_attn_out_res = ops.scatter_update(core_attn_out_i, timestep_unsq, update_core_attn, const_two) - # output accumulation (via Loop's "concat outputs") - loop.add_concat_output(final_attn) + const_minus1 = ops.constant(-1, dtype=np.int32) + g_i_minus1 = ops.gather(g_i, const_minus1, const_two) + subtract_g = ops.unsqueeze(ops.exp(ops.subtract(ops.unsqueeze(g_i_minus1, const_minus1), g_i)), const_minus1) + update_lrs = ops.einsum([ops.multiply(k_i, subtract_g), v_new], "bhdw,bhdl->bhwl") + last_recurrent_state_res = 
ops.unsqueeze(ops.unsqueeze(g_i_minus1, const_minus1), const_minus1) + last_recurrent_state_res = ops.multiply(last_recurrent_state_i, ops.exp(last_recurrent_state_res)) + last_recurrent_state_res = ops.add(last_recurrent_state_res, update_lrs) - # ---- update last recurrent state ---- - g_last = ops.gather(g_i, ops.constant([g_i.get_shape()[3] - 1], dtype="int64"), axis=3) + body_cond = ops.constant([True], dtype=bool) - # expand dims - g_last_e = ops.unsqueeze(ops.unsqueeze(g_last, -1), -1) - decay_factor = ops.exp(g_last_e) + body_model = ov.Model([body_cond, last_recurrent_state_res, core_attn_out_res], + [timestep, q_i_param, k_i_param, v_i_param, decay_mask_i_param, mask_i, k_cumdecay_i_param, + last_recurrent_state_i, + g_i_param, core_attn_out_i], "body_model") - g_diff = ops.exp(ops.subtract(ops.unsqueeze(g_last, -1), g_i)) - update_term = ops.matmul( - ops.transpose(ops.multiply(k_i, ops.unsqueeze(g_diff, -1)), [0,1,3,2]), - v_new + num_chunks = ops.convert(num_chunks_param, "i32") + loop = ops.loop( + num_chunks, + ops.constant(True, dtype="bool") ) - - next_state = ops.add(ops.multiply(last_state_in, decay_factor), update_term) - - # register updated state - loop.set_loop_carried_state(last_state_in, next_state) - - # condition for next iteration stays True - body.set_conditions(ops.constant(True, dtype="bool")) - - # ============================================================ - # 3) BUILD LOOP OUTPUTS - # ============================================================ - concat_attn_out = loop.get_iter_value(0) # concatenated final_attn - final_state_out = loop.get_iter_value(1) # final recurrent state - - return concat_attn_out, final_state_out + loop.set_function(body_model) + loop.set_sliced_input(q_i_param, query, 0, 1, 1, -1, 2) + loop.set_sliced_input(k_i_param, key, 0, 1, 1, -1, 2) + loop.set_sliced_input(v_i_param, value, 0, 1, 1, -1, 2) + loop.set_sliced_input(decay_mask_i_param, decay_mask, 0, 1, 1, -1, 2) + loop.set_invariant_input(mask_i, mask_float.output(0)) + loop.set_sliced_input(k_cumdecay_i_param, k_cumdecay, 0, 1, 1, -1, 2) + loop.set_merged_input(last_recurrent_state_i, last_recurrent_state, last_recurrent_state_res.output(0)) + loop.set_sliced_input(g_i_param, g, 0, 1, 1, -1, 2) + loop.set_merged_input(core_attn_out_i, core_attn_out.output(0), core_attn_out_res.output(0)) + + loop.set_special_body_ports([0, 0]) + + core_attn_out_new = loop.get_iter_value(core_attn_out_res.output(0), -1) + last_recurrent_state_new = loop.get_iter_value(last_recurrent_state_res.output(0), -1) + + #core_attn_out_new_res = ops.result(core_attn_out_new) + #last_recurrent_state_new_res = ops.result(last_recurrent_state_new) + + return [core_attn_out_new, last_recurrent_state_new] class Qwen3NextModelPatcher(ModelPatcher): From f623e573e0b720ea2d0ab5307baaf78841c9fc78 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Sun, 23 Nov 2025 20:05:21 +0400 Subject: [PATCH 08/13] Patch sparse moe block --- optimum/exporters/openvino/model_patcher.py | 106 +++++++++++++++----- 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 9b67c45ca9..ec79c8b667 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7047,7 +7047,9 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): # ) num_chunks = total_sequence_length // chunk_size - core_attn_out, last_recurrent_state = self.chunked_attention_cell( + core_attn_out = 
torch.zeros_like(value) + + last_recurrent_state = self.chunked_attention_cell( query, key, value, @@ -7220,11 +7222,11 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): ) core_attn_out_dec, last_recurrent_state_dec = self.recurrent_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, + query[:, :1], + key[:, :1], + value[:, :1], + g=g[:, :1], + beta=beta[:, :1], initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, @@ -7249,6 +7251,60 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return output +def patched_qwen3_next_sparse_moe_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + is_active_expert = torch.greater(expert_mask.sum(dim=(-1, -2)), 0) + + for expert_idx in range(self.num_experts): + is_activated = is_active_expert[expert_idx] + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + idx = is_activated.to(idx.dtype) * idx + torch.tensor([0]).to(idx.dtype) * (1 - is_activated.to(idx.dtype)) + top_x = is_activated.to(top_x.dtype) * top_x + torch.tensor([0]).to(top_x.dtype) * (1 - is_activated.to(top_x.dtype)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + current_hidden_states = (is_activated.to(current_hidden_states.dtype) * current_hidden_states + + torch.tensor([0]).to(current_hidden_states.dtype) * (1 - is_activated.to(current_hidden_states.dtype))) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + class ChunkedAttentionCell(torch.nn.Module): def __init__(self): super().__init__() @@ -7300,7 +7356,8 @@ def forward( last_recurrent_state * decay_factor + update_term ) - return core_attn_out, last_recurrent_state + #return core_attn_out, last_recurrent_state + return last_recurrent_state def convert_chunked_attention_cell(context): @@ -7424,7 +7481,7 @@ def convert_chunked_attention_cell(context): #core_attn_out_new_res = ops.result(core_attn_out_new) #last_recurrent_state_new_res = ops.result(last_recurrent_state_new) - return [core_attn_out_new, last_recurrent_state_new] + return [last_recurrent_state_new] class Qwen3NextModelPatcher(ModelPatcher): @@ -7560,6 +7617,7 @@ def patched_forward( self.conversion_extensions = [ConversionExtension("ChunkedAttentionCellOp", convert_chunked_attention_cell)] def __enter__(self): + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock super().__enter__() setattr(self._model, self.orig_forward_name, self.patched_forward) @@ -7569,26 +7627,28 @@ def __enter__(self): layer_type = self._model.model.config.layer_types[idx] if layer_type == "linear_attention": linear_attn_layer = decoder_layer.linear_attn - else: - continue - linear_attn_layer._orig_forward = linear_attn_layer.forward - linear_attn_layer.forward = types.MethodType(qwen3_next_gated_delta_net_forward, linear_attn_layer) - linear_attn_layer._orig_chunk_gated_delta_rule = linear_attn_layer.chunk_gated_delta_rule - linear_attn_layer.chunk_gated_delta_rule = patched_chunk_gated_delta_rule - linear_attn_layer._orig_recurrent_gated_delta_rule = linear_attn_layer.recurrent_gated_delta_rule - linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule - linear_attn_layer.chunked_attention_cell = ChunkedAttentionCell() - chunked_attention_cell_list.append(linear_attn_layer.chunked_attention_cell) + linear_attn_layer._orig_forward = linear_attn_layer.forward + linear_attn_layer.forward = types.MethodType(qwen3_next_gated_delta_net_forward, linear_attn_layer) + linear_attn_layer._orig_chunk_gated_delta_rule = linear_attn_layer.chunk_gated_delta_rule + linear_attn_layer.chunk_gated_delta_rule = patched_chunk_gated_delta_rule + linear_attn_layer._orig_recurrent_gated_delta_rule = linear_attn_layer.recurrent_gated_delta_rule + linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule + linear_attn_layer.chunked_attention_cell = ChunkedAttentionCell() + chunked_attention_cell_list.append(linear_attn_layer.chunked_attention_cell) + if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock): + decoder_layer.mlp._orig_forward = decoder_layer.mlp.forward + decoder_layer.mlp.forward = types.MethodType(patched_qwen3_next_sparse_moe_block, decoder_layer.mlp) def __exit__(self, exc_type, exc_value, traceback): + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock super().__exit__(exc_type, exc_value, traceback) setattr(self._model, self.orig_forward_name, self.model_orig_forward) for idx, decoder_layer in enumerate(self._model.model.layers): 
layer_type = self._model.model.config.layer_types[idx] if layer_type == "linear_attention": linear_attn_layer = decoder_layer.linear_attn - else: - continue - linear_attn_layer.forward = linear_attn_layer._orig_forward - linear_attn_layer.chunk_gated_delta_rule = linear_attn_layer._orig_chunk_gated_delta_rule - linear_attn_layer.recurrent_gated_delta_rule = linear_attn_layer._orig_recurrent_gated_delta_rule + linear_attn_layer.forward = linear_attn_layer._orig_forward + linear_attn_layer.chunk_gated_delta_rule = linear_attn_layer._orig_chunk_gated_delta_rule + linear_attn_layer.recurrent_gated_delta_rule = linear_attn_layer._orig_recurrent_gated_delta_rule + if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock): + decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward From e76f243f5af09fd46631cf41e44f7347ada4279c Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Sun, 23 Nov 2025 20:32:29 +0400 Subject: [PATCH 09/13] Use core_attn_out --- optimum/exporters/openvino/model_patcher.py | 32 +++++++++++++++++---- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index ec79c8b667..6debeec266 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7049,7 +7049,20 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): core_attn_out = torch.zeros_like(value) - last_recurrent_state = self.chunked_attention_cell( + #core_attn_out, last_recurrent_state = self.chunked_attention_cell( + # query, + # key, + # value, + # decay_mask, + # mask, + # k_cumdecay, + # g, + # last_recurrent_state, + # num_chunks + #) + + # final_output = ops.concat([core_attn_out_new, last_recurrent_state_new], 0) + output_cell = self.chunked_attention_cell( query, key, value, @@ -7061,6 +7074,10 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): num_chunks ) + num_elems = value.numel() + core_attn_out = output_cell[:num_elems].reshape(value.shape) + last_recurrent_state = output_cell[num_elems:].reshape(last_recurrent_state.shape) + if not output_final_state: last_recurrent_state = None core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) @@ -7356,8 +7373,10 @@ def forward( last_recurrent_state * decay_factor + update_term ) + output_cell = torch.cat([core_attn_out.flatten(), last_recurrent_state.flatten()], dim=0) + return output_cell #return core_attn_out, last_recurrent_state - return last_recurrent_state + #return last_recurrent_state def convert_chunked_attention_cell(context): @@ -7478,10 +7497,13 @@ def convert_chunked_attention_cell(context): core_attn_out_new = loop.get_iter_value(core_attn_out_res.output(0), -1) last_recurrent_state_new = loop.get_iter_value(last_recurrent_state_res.output(0), -1) - #core_attn_out_new_res = ops.result(core_attn_out_new) - #last_recurrent_state_new_res = ops.result(last_recurrent_state_new) + flatten_shape = ops.constant([-1], dtype=np.int32) + core_attn_out_new = ops.reshape(core_attn_out_new, flatten_shape, False) + last_recurrent_state_new = ops.reshape(last_recurrent_state_new, flatten_shape, False) + + final_output = ops.concat([core_attn_out_new, last_recurrent_state_new], 0) - return [last_recurrent_state_new] + return [final_output.output(0)] class Qwen3NextModelPatcher(ModelPatcher): From b191d5986dd2e06bbc63eae14527bc4d11199cb6 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 24 Nov 2025 12:57:07 
+0400 Subject: [PATCH 10/13] Fix use of mask --- optimum/exporters/openvino/model_patcher.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 6debeec266..c6a6ecd682 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7047,8 +7047,6 @@ def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): # ) num_chunks = total_sequence_length // chunk_size - core_attn_out = torch.zeros_like(value) - #core_attn_out, last_recurrent_state = self.chunked_attention_cell( # query, # key, @@ -7412,7 +7410,7 @@ def convert_chunked_attention_cell(context): const_one_float = ops.constant(1, dtype=np.float32) const_zero_float = ops.constant(0, dtype=np.float32) - mask_float = ops.select(mask, const_one_float, const_zero_float) + mask_float = ops.select(mask, const_zero_float, const_one_float) # create a body graph # q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] From 0b1bb212f4ad008aa1975f63d104378c398846f0 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 24 Nov 2025 14:05:44 +0400 Subject: [PATCH 11/13] Correct shape for recurrent_state in config file --- optimum/exporters/openvino/model_configs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 3d1cf1813e..4811f479fa 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -4434,7 +4434,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype) cache_params.append(conv_state) - recurrent_state_shape = (self.batch_size, self.num_key_value_heads, self.head_k_dim, self.head_v_dim) + #num_heads = self.num_key_value_heads * (self.num_v_heads // self.num_k_heads) + num_heads = self.num_v_heads + recurrent_state_shape = (self.batch_size, num_heads, self.head_k_dim, self.head_v_dim) recurrent_state = self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype) cache_params.append(recurrent_state) From 6a3d22fa6d64ea446ad55e931b053bbbfb755639 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 28 Nov 2025 17:30:04 +0000 Subject: [PATCH 12/13] Re-write patch for MoE --- optimum/exporters/openvino/model_patcher.py | 267 ++++++++++++++++++-- 1 file changed, 244 insertions(+), 23 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index c6a6ecd682..085be660b5 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7286,38 +7286,256 @@ def patched_qwen3_next_sparse_moe_block(self, hidden_states: torch.Tensor) -> to # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + num_experts = self.num_experts + #down_projs = [] + #gate_projs = [] + #up_projs = [] + #for idx in range(num_experts): + # down_projs.append(self.experts[idx].down_proj.weight) + # gate_projs.append(self.experts[idx].gate_proj.weight) + # up_projs.append(self.experts[idx].up_proj.weight) + + 
#down_projs = torch.stack(down_projs, dim=0) + #up_projs = torch.stack(up_projs, dim=0) + #gate_projs = torch.stack(gate_projs, dim=0) + + down_projs = torch.concat( + tuple(self.experts[i].down_proj.weight.unsqueeze(0) + for i in range(num_experts)), + dim=0 + ) + + gate_projs = torch.concat( + tuple(self.experts[i].gate_proj.weight.unsqueeze(0) + for i in range(num_experts)), + dim=0 + ) + + up_projs = torch.concat( + tuple(self.experts[i].up_proj.weight.unsqueeze(0) + for i in range(num_experts)), + dim=0 + ) + + #down_projs = torch.zeros((4, 32, 16), dtype=torch.float32) + #up_projs = torch.zeros((4, 16, 32), dtype=torch.float32) + #gate_projs = torch.zeros((4, 16, 32), dtype=torch.float32) + + final_hidden_states_res = self.moe_cell( + expert_hit, + expert_mask, + hidden_states, + down_projs, + up_projs, + gate_projs, + routing_weights, + final_hidden_states, + ) # Loop over all available experts in the model and perform the computation on each expert - is_active_expert = torch.greater(expert_mask.sum(dim=(-1, -2)), 0) + #is_active_expert = torch.greater(expert_mask.sum(dim=(-1, -2)), 0) - for expert_idx in range(self.num_experts): - is_activated = is_active_expert[expert_idx] - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + #for expert_idx in range(self.num_experts): + # is_activated = is_active_expert[expert_idx] + # expert_layer = self.experts[expert_idx] + # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - idx = is_activated.to(idx.dtype) * idx + torch.tensor([0]).to(idx.dtype) * (1 - is_activated.to(idx.dtype)) - top_x = is_activated.to(top_x.dtype) * top_x + torch.tensor([0]).to(top_x.dtype) * (1 - is_activated.to(top_x.dtype)) + # idx = is_activated.to(idx.dtype) * idx + torch.tensor([0]).to(idx.dtype) * (1 - is_activated.to(idx.dtype)) + # top_x = is_activated.to(top_x.dtype) * top_x + torch.tensor([0]).to(top_x.dtype) * (1 - is_activated.to(top_x.dtype)) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + # # Index the correct hidden states and compute the expert hidden state for + # # the current expert. We need to make sure to multiply the output hidden + # # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + # current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - current_hidden_states = (is_activated.to(current_hidden_states.dtype) * current_hidden_states + - torch.tensor([0]).to(current_hidden_states.dtype) * (1 - is_activated.to(current_hidden_states.dtype))) + # current_hidden_states = (is_activated.to(current_hidden_states.dtype) * current_hidden_states + + # torch.tensor([0]).to(current_hidden_states.dtype) * (1 - is_activated.to(current_hidden_states.dtype))) - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + # # However `index_add_` only support torch tensors for indexing so we'll use + # # the `top_x` tensor here. 
+ # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - final_hidden_states = final_hidden_states + shared_expert_output + final_hidden_states_res = final_hidden_states_res + shared_expert_output - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + final_hidden_states_res = final_hidden_states_res.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states_res, router_logits + + +class LoopBasedMoECell(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + expert_hit, + expert_mask, + hidden_states, + down_projs, + up_projs, + gate_projs, + routing_weights, + final_hidden_states, + ): + + _, hidden_dim = hidden_states.shape + final_hidden_states_res = final_hidden_states.clone() + + for expert_idx in expert_hit: + #expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + #current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + act_fn_res = torch.nn.functional.silu(torch.nn.functional.linear(current_state, gate_projs[expert_idx[0]])) + up_proj_res = torch.nn.functional.linear(current_state, up_projs[expert_idx[0]]) + current_hidden_states = act_fn_res * up_proj_res + current_hidden_states = torch.nn.functional.linear(current_hidden_states, down_projs[expert_idx[0]]) + current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states_res.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + return final_hidden_states_res + +def convert_moe_cell(context): + import openvino.opset15 as ops + import openvino as ov + import numpy as np + + idx_param = ops.parameter([], np.int32, "idx_param") + expert_hit_param = ops.parameter([-1], np.int32, "expert_hit_param") + expert_mask_param = ops.parameter([-1, -1, -1], np.int64, "expert_mask_param") + hidden_states_param = ops.parameter([-1, -1], np.float32, "hidden_states_param") + gate_projs_param = ops.parameter([-1, -1, -1], np.float32, "gate_projs_param") + up_projs_param = ops.parameter([-1, -1, -1], np.float32, "up_projs_param") + down_projs_param = ops.parameter([-1, -1, -1], np.float32, "down_projs_param") + routing_weights_param = ops.parameter([-1, -1], np.float32, "routing_weights_param") + final_hidden_states_param = ops.parameter([-1, -1], np.float32, "final_hidden_states_param") + + #batch_size, hidden_dim = hidden_states_param.shape + hidden_states_shape = ops.shape_of(hidden_states_param) + const_zero = ops.constant(0, dtype=np.int32) + const_one = ops.constant([1], dtype=np.int32) + hidden_dim = ops.gather(hidden_states_shape, const_one, const_zero) + + expert_idx = ops.gather(expert_hit_param, idx_param, const_zero) + shape1d = ops.constant([1], dtype=np.int32) + scalar_shape = ops.constant([], dtype=np.int32) + expert_idx = ops.reshape(expert_idx, shape1d, False) + expert_idx_scalar = ops.reshape(expert_idx, scalar_shape, False) + + expert_mask = ops.gather(expert_mask_param, expert_idx_scalar, const_zero) + non_zeros = ops.non_zero(expert_mask) + + split_axis = ops.constant(0, dtype=np.int32) + split_res = ops.split(non_zeros, split_axis, 2) + #idx = ops.squeeze(split_res.output(0), const_zero) + top_x = ops.squeeze(split_res.output(1), const_zero) + transpose_order = ops.constant([1, 0], dtype=np.int32) + top_x_idx = ops.transpose(non_zeros, transpose_order) + + # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_state = ops.gather(hidden_states_param, top_x, const_zero) + current_state = ops.unsqueeze(current_state, const_zero) + const_minus1 = ops.constant([-1], dtype=np.int64) + new_shape = ops.concat([const_minus1, hidden_dim], 0) + current_state = ops.reshape(current_state, new_shape, False) + + # #act_fn_res = torch.nn.functional.silu(torch.nn.functional.linear(current_state, gate_projs[expert_idx[0]])) + gate_proj = ops.gather(gate_projs_param, expert_idx, const_zero) + gate_proj = ops.squeeze(gate_proj, const_zero) + act_fn_res = ops.matmul(current_state, gate_proj, False, True) + act_fn_res = ops.multiply(ops.sigmoid(act_fn_res), act_fn_res) + + # up_proj_res = torch.nn.functional.linear(current_state, up_projs[expert_idx[0]]) + up_proj = ops.gather(up_projs_param, expert_idx, const_zero) + up_proj = ops.squeeze(up_proj, const_zero) + up_proj_res = ops.matmul(current_state, up_proj, False, True) + + # current_hidden_states = act_fn_res * up_proj_res + current_hidden_states = ops.multiply(act_fn_res, up_proj_res) + current_hidden_states = act_fn_res + + # current_hidden_states = torch.nn.functional.linear(current_hidden_states, down_projs[expert_idx[0]]) + down_proj = ops.gather(down_projs_param, expert_idx, const_zero) + down_proj = ops.squeeze(down_proj, const_zero) + current_hidden_states = ops.matmul(current_hidden_states, down_proj, False, True) + + # current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] + # routing_weights[top_x, idx, None] + 
#routing_weights = ops.gather(routing_weights_param, top_x, const_zero) + #routing_weights = ops.gather(routing_weights, idx, const_zero) + #routing_weights = ops.unsqueeze(routing_weights, const_zero) + routing_weights = ops.gather_nd(routing_weights_param, top_x_idx) + routing_weights = ops.unsqueeze(routing_weights, const_one) + current_hidden_states = ops.multiply(current_hidden_states, routing_weights) + + #current_hidden_states = current_state + + # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + top_x = ops.unsqueeze(top_x, const_one) + #current_hidden_states = ops.reshape(current_hidden_states, new_shape, False) + + final_hidden_states_res = ops.scatter_nd_update(final_hidden_states_param, + top_x, current_hidden_states, "sum") + + body_cond = ops.constant([True], dtype=bool) + + body_model = ov.Model([body_cond, final_hidden_states_res], + [idx_param, expert_hit_param, expert_mask_param, hidden_states_param, + gate_projs_param, up_projs_param, down_projs_param, + routing_weights_param, final_hidden_states_param], + "body_model") + + # context.get_input(0) + expert_hit = context.get_input(0) + expert_mask_main = context.get_input(1) + hidden_states = context.get_input(2) + down_projs = context.get_input(3) + up_projs = context.get_input(4) + gate_projs = context.get_input(5) + routing_weights = context.get_input(6) + final_hidden_states = context.get_input(7) + + # flatten expert_hit + flatten_shape = ops.constant([-1], dtype=np.int32) + expert_hit = ops.reshape(expert_hit, flatten_shape, False) + + num_active_experts = ops.shape_of(expert_hit, "i32") + const_zero = ops.constant(0, dtype=np.int32) + num_active_experts = ops.gather(num_active_experts, const_zero, const_zero) + + loop = ops.loop( + num_active_experts, + ops.constant(True, dtype="bool") + ) + + loop.set_function(body_model) + loop.set_invariant_input(expert_hit_param, expert_hit.output(0)) + loop.set_invariant_input(expert_mask_param, expert_mask_main) + loop.set_invariant_input(hidden_states_param, hidden_states) + loop.set_invariant_input(gate_projs_param, gate_projs) + loop.set_invariant_input(up_projs_param, up_projs) + loop.set_invariant_input(down_projs_param, down_projs) + loop.set_invariant_input(routing_weights_param, routing_weights) + loop.set_merged_input(final_hidden_states_param, final_hidden_states, final_hidden_states_res.output(0)) + + loop.set_special_body_ports([0, 0]) + + final_hidden_states = loop.get_iter_value(final_hidden_states_res.output(0), -1) + + return [final_hidden_states] class ChunkedAttentionCell(torch.nn.Module): @@ -7373,8 +7591,6 @@ def forward( output_cell = torch.cat([core_attn_out.flatten(), last_recurrent_state.flatten()], dim=0) return output_cell - #return core_attn_out, last_recurrent_state - #return last_recurrent_state def convert_chunked_attention_cell(context): @@ -7633,8 +7849,12 @@ def patched_forward( self.module_extensions = { ChunkedAttentionCell: ModuleExtension(ChunkedAttentionCell, "ChunkedAttentionCellOp"), + LoopBasedMoECell: ModuleExtension(LoopBasedMoECell, "LoopBasedMoECellOp"), } - self.conversion_extensions = [ConversionExtension("ChunkedAttentionCellOp", convert_chunked_attention_cell)] + self.conversion_extensions = [ + ConversionExtension("ChunkedAttentionCellOp", convert_chunked_attention_cell), + ConversionExtension("LoopBasedMoECellOp", convert_moe_cell), + ] def __enter__(self): from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock @@ -7658,6 +7878,7 @@ def __enter__(self): 
if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock): decoder_layer.mlp._orig_forward = decoder_layer.mlp.forward decoder_layer.mlp.forward = types.MethodType(patched_qwen3_next_sparse_moe_block, decoder_layer.mlp) + decoder_layer.mlp.moe_cell = LoopBasedMoECell() def __exit__(self, exc_type, exc_value, traceback): from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock From 9df28e3e7a06bb00c153e1bbe2c2444f0bc64ffe Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 28 Nov 2025 18:59:53 +0000 Subject: [PATCH 13/13] --- optimum/exporters/openvino/model_patcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 085be660b5..b8af891618 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7302,19 +7302,19 @@ def patched_qwen3_next_sparse_moe_block(self, hidden_states: torch.Tensor) -> to #gate_projs = torch.stack(gate_projs, dim=0) down_projs = torch.concat( - tuple(self.experts[i].down_proj.weight.unsqueeze(0) + tuple(self.experts[i].down_proj.weight.float().unsqueeze(0) for i in range(num_experts)), dim=0 ) gate_projs = torch.concat( - tuple(self.experts[i].gate_proj.weight.unsqueeze(0) + tuple(self.experts[i].gate_proj.weight.float().unsqueeze(0) for i in range(num_experts)), dim=0 ) up_projs = torch.concat( - tuple(self.experts[i].up_proj.weight.unsqueeze(0) + tuple(self.experts[i].up_proj.weight.float().unsqueeze(0) for i in range(num_experts)), dim=0 )
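
A few notes on the conversion approach used across this series, with small self-contained sketches. None of the code below is part of the diffs above; every module name, op name, and import path in the sketches is illustrative.

The patcher maps each torch-level helper module (ChunkedAttentionCell, LoopBasedMoECell) to a single custom operation via ModuleExtension and then lowers that operation to an OpenVINO subgraph with a matching ConversionExtension. A minimal pairing looks roughly like this; the AddOneCell module, the "AddOneOp" target name, and the openvino.frontend.pytorch import path are assumptions for the sake of the example and should be matched to the imports already used in model_patcher.py:

    import numpy as np
    import torch
    # Assumed import path; align with how model_patcher.py imports these classes.
    from openvino.frontend.pytorch import ModuleExtension, ConversionExtension

    class AddOneCell(torch.nn.Module):
        # Stand-in for ChunkedAttentionCell / LoopBasedMoECell: once mapped to a
        # custom op, only its input/output signature matters during tracing.
        def forward(self, x):
            return x + 1.0

    def convert_add_one(context):
        # Builds the OpenVINO subgraph that replaces the "AddOneOp" node.
        import openvino.opset15 as ops
        x = context.get_input(0)
        one = ops.constant(1.0, dtype=np.float32)
        return [ops.add(x, one).output(0)]

    module_extensions = {AddOneCell: ModuleExtension(AddOneCell, "AddOneOp")}
    conversion_extensions = [ConversionExtension("AddOneOp", convert_add_one)]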
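The chunked delta-rule converter expresses the per-chunk recurrence as an OpenVINO Loop: the body is its own ov.Model whose first parameter receives the iteration index and whose first result is the continue condition, chunk-sliced tensors are wired with set_sliced_input, and loop-carried state with set_merged_input. Below is a stripped-down version of that wiring that just sums the rows of a small matrix, mirroring the opset15 calls used in convert_chunked_attention_cell (exact argument conventions can vary slightly between OpenVINO releases):

    import numpy as np
    import openvino as ov
    import openvino.opset15 as ops

    # --- body graph: acc_out = acc_in + row_i ---
    iter_param = ops.parameter([], np.int32, "iter")       # receives the current iteration index
    row_i = ops.parameter([1, 4], np.float32, "row_i")     # one slice of the input per iteration
    acc_in = ops.parameter([1, 4], np.float32, "acc_in")   # loop-carried accumulator
    acc_out = ops.add(acc_in, row_i)
    body_cond = ops.constant([True], dtype=bool)           # run until the trip count is exhausted
    body = ov.Model([body_cond, acc_out], [iter_param, row_i, acc_in], "body")

    # --- main graph ---
    data = ops.parameter([8, 4], np.float32, "data")
    init_acc = ops.constant(np.zeros((1, 4), dtype=np.float32))
    loop = ops.loop(ops.constant(8, dtype=np.int32), ops.constant(True, dtype="bool"))
    loop.set_function(body)
    loop.set_sliced_input(row_i, data.output(0), 0, 1, 1, -1, 0)          # iterate over axis 0, one row at a time
    loop.set_merged_input(acc_in, init_acc.output(0), acc_out.output(0))  # carry the accumulator across iterations
    loop.set_special_body_ports([0, 0])                                   # body input 0 = iteration index, result 0 = condition
    total = loop.get_iter_value(acc_out.output(0), -1)                    # accumulator value after the last iteration

    model = ov.Model([ops.result(total)], [data], "row_sum_loop")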
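Patch 09 makes ChunkedAttentionCell return a single packed tensor: both results are flattened and concatenated inside the cell, and the caller recovers them by element count. The round trip in isolation, with placeholder shapes:

    import torch

    def pack_outputs(core_attn_out: torch.Tensor, last_recurrent_state: torch.Tensor) -> torch.Tensor:
        # Pack both results into one 1-D tensor so the traced cell has a single output.
        return torch.cat([core_attn_out.flatten(), last_recurrent_state.flatten()], dim=0)

    def unpack_outputs(packed: torch.Tensor, value_like: torch.Tensor, state_like: torch.Tensor):
        # Split the flat tensor back by element count and restore the original shapes.
        num_elems = value_like.numel()
        core_attn_out = packed[:num_elems].reshape(value_like.shape)
        last_recurrent_state = packed[num_elems:].reshape(state_like.shape)
        return core_attn_out, last_recurrent_state

    value = torch.randn(1, 2, 4, 8, 16)   # stands in for the chunked value tensor
    state = torch.randn(1, 2, 16, 16)     # stands in for the recurrent state
    out, st = unpack_outputs(pack_outputs(value, state), value, state)
    assert torch.equal(out, value) and torch.equal(st, state)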
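In the re-written MoE patch, the per-expert gate/up/down projection weights are stacked into three [num_experts, ...] tensors so the conversion loop body can select an expert with a single gather, and LoopBasedMoECell spells the expert MLP out as down(silu(gate(x)) * up(x)) via F.linear against the gathered weights; the accumulation into final_hidden_states is then lowered to scatter_nd_update with the "sum" reduction, which plays the role of torch's index_add_. A small self-check that the stacked-weight form matches calling an expert module directly (the Expert class below is a stand-in, not the actual Qwen3-Next expert implementation):

    import torch
    import torch.nn.functional as F

    class Expert(torch.nn.Module):
        # Stand-in expert MLP: down_proj(silu(gate_proj(x)) * up_proj(x)), no biases.
        def __init__(self, hidden, intermediate):
            super().__init__()
            self.gate_proj = torch.nn.Linear(hidden, intermediate, bias=False)
            self.up_proj = torch.nn.Linear(hidden, intermediate, bias=False)
            self.down_proj = torch.nn.Linear(intermediate, hidden, bias=False)

        def forward(self, x):
            return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    experts = [Expert(16, 32) for _ in range(4)]
    gate_w = torch.stack([e.gate_proj.weight for e in experts])   # [num_experts, intermediate, hidden]
    up_w = torch.stack([e.up_proj.weight for e in experts])       # [num_experts, intermediate, hidden]
    down_w = torch.stack([e.down_proj.weight for e in experts])   # [num_experts, hidden, intermediate]

    x = torch.randn(5, 16)
    idx = 2
    from_stacked = F.linear(F.silu(F.linear(x, gate_w[idx])) * F.linear(x, up_w[idx]), down_w[idx])
    assert torch.allclose(from_stacked, experts[idx](x), atol=1e-6)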