|
95 | 95 | LlamaModelPatcher, |
96 | 96 | LlavaImageEmbeddingModelPatcher, |
97 | 97 | LlavaQwen2ImageEmbeddingsModelPatcher, |
| 98 | + MambaPatcher, |
98 | 99 | MiniCPM3Patcher, |
99 | 100 | MiniCPMModelPatcher, |
100 | 101 | MiniCPMVImageEmbeddingsModelPatcher, |
@@ -2880,3 +2881,132 @@ def patch_model_for_export( |
2880 | 2881 | self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
2881 | 2882 | ) -> "ModelPatcher": |
2882 | 2883 | return DeepseekPatcher(self, model, model_kwargs=model_kwargs) |
| 2884 | + |
| 2885 | + |
| 2886 | +class MambaCacheDummyInputGenerator(DummyInputGenerator): |
| 2887 | + """ |
| 2888 | + Generates dummy Mamba cache inputs: per-layer SSM states, per-layer convolution states, and cache_position. |
| 2889 | + """ |
| 2890 | + |
| 2891 | + SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position") |
| 2892 | + |
| 2893 | + def __init__( |
| 2894 | + self, |
| 2895 | + task: str, |
| 2896 | + normalized_config, |
| 2897 | + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], |
| 2898 | + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], |
| 2899 | + **kwargs, |
| 2900 | + ): |
| 2901 | + self.normalized_config = normalized_config |
| 2902 | + self.batch_size = batch_size |
| 2903 | + self.sequence_length = sequence_length |
| 2904 | + self.intermediate_size = self.normalized_config.config.intermediate_size |
| 2905 | + self.ssm_state_size = self.normalized_config.config.state_size |
| 2906 | + self.conv_kernel_size = self.normalized_config.config.conv_kernel |
| 2907 | + |
| 2908 | + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): |
| 2909 | + if input_name == "past_ssm_states": |
| 2910 | + ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size] |
| 2911 | + return [ |
| 2912 | + self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype) |
| 2913 | + for _ in range(self.normalized_config.num_layers) |
| 2914 | + ] |
| 2915 | + |
| 2916 | + elif input_name == "past_conv_states": |
| 2917 | + conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size] |
| 2918 | + return [ |
| 2919 | + self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype) |
| 2920 | + for _ in range(self.normalized_config.num_layers) |
| 2921 | + ] |
| 2922 | + |
| 2923 | + elif input_name == "cache_position": |
| 2924 | + return self.random_int_tensor( |
| 2925 | + shape=[self.conv_kernel_size], |
| 2926 | + max_value=self.sequence_length, |
| 2927 | + framework=framework, |
| 2928 | + dtype=int_dtype, |
| 2929 | + ) |
| 2930 | + |
| 2931 | + raise ValueError(f"Unsupported input name {input_name}") |
| 2932 | + |
| 2933 | + |
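As a reference for reviewers, the tensors this generator returns mirror the per-layer cache shapes `MambaCache` keeps in `transformers`: one SSM state of shape `[batch, intermediate_size, ssm_state_size]` and one rolling convolution state of shape `[batch, intermediate_size, conv_kernel]` per layer. A minimal shape sketch, assuming illustrative config values roughly in the ballpark of `state-spaces/mamba-130m-hf` (the exporter reads the real values from the model config through the normalized config):

```python
import torch

# Assumed, illustrative config values (roughly mamba-130m); the exporter reads
# the real ones from config.intermediate_size / state_size / conv_kernel.
batch_size, num_layers = 2, 24
intermediate_size, ssm_state_size, conv_kernel = 1536, 16, 4

# One tensor per layer, matching the lists MambaCacheDummyInputGenerator returns.
past_ssm_states = [torch.rand(batch_size, intermediate_size, ssm_state_size) for _ in range(num_layers)]
past_conv_states = [torch.rand(batch_size, intermediate_size, conv_kernel) for _ in range(num_layers)]

# cache_position spans the convolution window; the generator draws random ints
# bounded by the dummy sequence length, arange is used here just for clarity.
cache_position = torch.arange(conv_kernel, dtype=torch.int64)

print(past_ssm_states[0].shape)   # torch.Size([2, 1536, 16])
print(past_conv_states[0].shape)  # torch.Size([2, 1536, 4])
```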
| 2934 | +@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers") |
| 2935 | +class MambaOpenVINOConfig(TextDecoderOnnxConfig): |
| 2936 | + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator) |
| 2937 | + DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator |
| 2938 | + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig |
| 2939 | + |
| 2940 | + @property |
| 2941 | + def inputs(self) -> Dict[str, Dict[int, str]]: |
| 2942 | + if self.use_past_in_inputs: |
| 2943 | + common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} |
| 2944 | + self.add_past_key_values(common_inputs, direction="inputs") |
| 2945 | + # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} |
| 2946 | + common_inputs["cache_position"] = {0: "cache_sequence_length"} |
| 2947 | + else: |
| 2948 | + common_inputs = { |
| 2949 | + "input_ids": {0: "batch_size", 1: "sequence_length"}, |
| 2950 | + # "attention_mask": {0: "batch_size", 1: "sequence_length"}, |
| 2951 | + "cache_position": {0: "cache_sequence_length"}, |
| 2952 | + } |
| 2953 | + return common_inputs |
| 2954 | + |
| 2955 | + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): |
| 2956 | + """ |
| 2957 | + Fills `inputs_or_outputs` mapping with the dynamic axes of the Mamba cache (SSM and convolution states), considering the direction. |
| 2958 | + |
| 2959 | + Args: |
| 2960 | + inputs_or_outputs (`Dict[str, Dict[int, str]]`): |
| 2961 | + The mapping to fill. |
| 2962 | + direction (`str`): |
| 2963 | + either "inputs" or "outputs", it specifies whether `inputs_or_outputs` is the input mapping or the |
| 2964 | + output mapping, this is important for axes naming. |
| 2965 | + """ |
| 2966 | + if direction not in ["inputs", "outputs"]: |
| 2967 | + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') |
| 2968 | + |
| 2969 | + if direction == "inputs": |
| 2970 | + ssm_name = "past_ssm_states" |
| 2971 | + conv_name = "past_conv_states" |
| 2972 | + else: |
| 2973 | + ssm_name = "present_ssm_states" |
| 2974 | + conv_name = "present_conv_states" |
| 2975 | + |
| 2976 | + for i in range(self._normalized_config.num_layers): |
| 2977 | + inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"} |
| 2978 | + |
| 2979 | + for i in range(self._normalized_config.num_layers): |
| 2980 | + inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"} |
| 2981 | + |
| 2982 | + def patch_model_for_export( |
| 2983 | + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
| 2984 | + ): |
| 2985 | + return MambaPatcher(self, model, model_kwargs) |
| 2986 | + |
| 2987 | + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): |
| 2988 | + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) |
| 2989 | + |
| 2990 | + dummy_inputs = {} |
| 2991 | + input_names = [key for key in self.inputs.keys() if not key.startswith("past_")] |
| 2992 | + if self.use_past_in_inputs and self.use_cache_branch is not False: |
| 2993 | + input_names.extend(["past_ssm_states", "past_conv_states"]) |
| 2994 | + |
| 2995 | + for input_name in input_names: |
| 2996 | + input_was_inserted = False |
| 2997 | + for dummy_input_gen in dummy_inputs_generators: |
| 2998 | + if dummy_input_gen.supports_input(input_name): |
| 2999 | + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( |
| 3000 | + dummy_input_gen, |
| 3001 | + input_name, |
| 3002 | + framework, |
| 3003 | + input_shapes=kwargs, |
| 3004 | + ) |
| 3005 | + input_was_inserted = True |
| 3006 | + break |
| 3007 | + if not input_was_inserted: |
| 3008 | + raise RuntimeError( |
| 3009 | + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' |
| 3010 | + ) |
| 3011 | + |
| 3012 | + return dummy_inputs |
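Once this config (together with `MambaPatcher`) is registered, export should go through the usual optimum-intel entry points. A hedged usage sketch; the checkpoint name is a placeholder, and whether generation works end to end still depends on the patcher and runtime changes elsewhere in this PR:

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # placeholder Mamba checkpoint

# export=True converts the PyTorch model to OpenVINO IR using the
# "text-generation-with-past" task registered above (cache inputs included).
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Mamba is a state-space model that", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```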