From c69fe32c638e52433016df8d1a6746db3e7e70da Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 21 May 2024 20:52:25 +0400
Subject: [PATCH] Add support export for new architectures (#720)

* update codegen config for support codegen2

* add support DBRX

* add qwen2moe support

* fix test models

* baichuan sdpa

* apply review comments

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
---
 optimum/exporters/openvino/model_configs.py |  47 +++
 optimum/exporters/openvino/model_patcher.py | 343 +++++++++++++++++++-
 tests/openvino/test_modeling.py             |   4 +
 tests/openvino/utils_tests.py               |   3 +
 4 files changed, 396 insertions(+), 1 deletion(-)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 8feeafd619..d69adc9da3 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -20,6 +20,7 @@
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
+    CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
     LlamaOnnxConfig,
@@ -44,6 +45,8 @@
     AquilaModelPatcher,
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
+    CodeGenModelPatcher,
+    DBRXModelPatcher,
     GemmaModelPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
@@ -112,6 +115,15 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+@register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
 @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
@@ -738,3 +750,38 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "codegen",
+    *["feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past"],
+    library_name="transformers",
+)
+class CodeGenOpenVINOConfig(CodeGenOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "dbrx",
+    *["text-generation", "text-generation-with-past"],
+    library_name="transformers",
+)
+class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_attention_heads="n_heads",
+        hidden_size="d_model",
+        num_layers="n_layers",
+        num_key_value_heads="attn_config.kv_n_heads",
+        allow_new=True,
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return DBRXModelPatcher(self, model, model_kwargs=model_kwargs)
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 33fd77cba3..93a8430522 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -20,6 +20,7 @@
 
 import torch
 import torch.nn.functional as F
+from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import is_tf_available
 
@@ -43,6 +44,9 @@
     from transformers.modeling_tf_utils import TFPreTrainedModel
 
 
+BETTERTRANSFORMER_IGNORE = ("codegen",)
+
+
 def patch_model_with_bettertransformer(model):
     COLOR_RED = "\033[1;31m"
     COLOR_RESET = "\033[0m"
@@ -81,6 +85,10 @@ def patch_model_with_bettertransformer(model):
     # model already has required SDPA implementation
     if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
         return model
+
+    if model.config.model_type in BETTERTRANSFORMER_IGNORE:
+        return model
+
     try:
         model = model.to_bettertransformer()
     except Exception as e:
@@ -665,6 +673,72 @@ def _baichuan13b_atten_forward(
     return attn_output, attn_weights, past_key_value
 
 
+# Adapted from https://huggingface.co/baichuan-inc/Baichuan-7B/blob/262c8cb58b6d3615c208d9230baa869fddee2adb/modeling_baichuan.py#L181
+def _baichuan7b_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    bsz, q_len, _ = hidden_states.size()
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    if not output_attentions:
+        attn_weights = None
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim)
+        )
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
 class BaichuanModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
@@ -712,13 +786,18 @@ def forward(
             for layer in self._model.model.layers:
                 layer.self_attn._orig_forward = layer.self_attn.forward
                 layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+        else:
+            for layer in self._model.model.layers:
+                layer.self_attn._orig_forward = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model, "_orig_forward"):
             self._model.forward = self._model._orig_forward
 
-            for layer in self._model.model.layers:
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
 
 
@@ -1328,3 +1407,265 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+class CodeGenModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # The whole codegen BetterTransformer patch replaces attn.forward and does not cover codegen2.
+        # To avoid breaking the model at the tracing stage, we reduce the BetterTransformer patch to _attn only.
+        from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product
+
+        for layer in self._model.transformer.h:
+            if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
+                orig_self_attn_fwd = layer.attn._attn
+                layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
+                layer.attn._orig_attn = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.transformer.h:
+            if hasattr(layer.attn, "_orig_attn"):
+                layer.attn._attn = layer.attn._orig_attn
+
+
+# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
+def _dbrx_experts_forward(
+    self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+):
+    bsz, q_len, hidden_size = x.shape
+    x = x.view(-1, hidden_size)
+    out = torch.zeros_like(x)
+
+    expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+    # Chunk experts at once to avoid storing full parameter multiple times in autograd
+    w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+    v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+    w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+    for expert_idx in range(0, self.moe_num_experts):
+        topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+
+        # Difference with original: removed
+        # if token_idx.shape[0] == 0:
+        #     continue
+        # because loop interruption depends on input data and may affect torchscript tracing
+
+        token_list = token_idx
+        topk_list = topk_idx
+
+        expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+        expert_out = (
+            self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+            * top_weights[token_list, topk_list, None]
+        )
+
+        out.index_add_(0, token_idx, expert_out)
+
+    out = out.reshape(bsz, q_len, hidden_size)
+    return out
+
+
+# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
+def _dbrx_update_causal_mask_legacy(
+    self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
+) -> Optional[torch.Tensor]:
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
+        target_length = self.config.max_position_embeddings
+    else:  # dynamic cache
+        target_length = (
+            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+        )
+    # difference with original modeling
+    # removed target_length = int(target_length).
+    # Casting to int leads to constant folding during tracing that makes it impossible to use the model for sequences of different lengths
+    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        if attention_mask.dim() == 2:
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+        elif attention_mask.dim() == 4:
+            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+            # cache. In that case, the 4D attention mask attends to the newest tokens only.
+            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                offset = cache_position[0]
+            else:
+                offset = 0
+            mask_shape = attention_mask.shape
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+    ):
+        # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(input_tensor, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
+        if not is_tracing and torch.any(attention_mask != 1):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+# adapted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204
+def _dbrx_update_causal_mask_latest(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: Cache,
+    output_attentions: bool,
+):
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+if is_transformers_version(">", "4.40.2"):
+    _dbrx_update_causal_mask = _dbrx_update_causal_mask_latest
+else:
+    _dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy
+
+
+class DBRXModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        # dbrx has some accuracy issues with bf16 with transformers >= 4.40
+        # fill the causal mask in a slightly different way to avoid overflow on some platforms
+        self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
+        self._model.transformer._update_causal_mask = types.MethodType(
+            _dbrx_update_causal_mask, self._model.transformer
+        )
+
+        for block in self._model.transformer.blocks:
+            rotary_emb = block.norm_attn_norm.attn.rotary_emb
+            # initialize inv_freq for torchscript tracing
+            if rotary_emb.inv_freq is None:
+                inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
+                rotary_emb.inv_freq = inv_freq
+            # remove continue-operator from iteration loop over experts
+            block.ffn.experts._orig_forward = block.ffn.experts.forward
+            block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
+        for block in self._model.transformer.blocks:
+            block.ffn.experts.forward = block.ffn.experts._orig_forward
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 692720a972..cb5ac52ed7 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -527,6 +527,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "bloom",
         "chatglm",
         "codegen",
+        "codegen2",
         # "data2vec-text", # TODO : enable when enabled in exporters
         "gemma",
         "gpt2",
@@ -561,6 +562,8 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "aquila2",
         "xverse",
         "internlm",
+        "dbrx",
+        "qwen2-moe",
     )
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
@@ -577,6 +580,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "aquila2",
         "xverse",
         "internlm",
+        "codegen2",
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index d4364d192a..91500cfc63 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -37,9 +37,11 @@
     "cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
     "chatglm": "katuni4ka/tiny-random-chatglm2",
     "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
+    "codegen2": "katuni4ka/tiny-random-codegen2",
     "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
     "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
     "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
+    "dbrx": "katuni4ka/tiny-random-dbrx",
     "deberta": "hf-internal-testing/tiny-random-deberta",
     "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
     "deit": "hf-internal-testing/tiny-random-deit",
@@ -91,6 +93,7 @@
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
     "qwen": "katuni4ka/tiny-random-qwen",
     "qwen2": "Qwen/Qwen1.5-0.5B",
+    "qwen2-moe": "katuni4ka/tiny-random-qwen1.5-moe",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
"hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer",