diff --git a/tests/test_moe_layer.py b/tests/test_moe_layer.py
index 4cd596a..4c53a7e 100644
--- a/tests/test_moe_layer.py
+++ b/tests/test_moe_layer.py
@@ -4,6 +4,7 @@ import torch
 
 from transformers.models.llama.modeling_llama import LlamaConfig, LlamaMLP
 from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP
+from transformers.models.phi.modeling_phi import PhiConfig, PhiMLP
 
 from mixlora.model import LoraLinear, MixLoraConfig, MixLoraSparseMoe
 
@@ -72,6 +73,25 @@ def test_llama_forward(self):
                 input = torch.zeros(shape)
                 output: torch.Tensor = moe_layer(input)
                 self.assertEqual(output.shape, shape)
+
+    def test_phi_forward(self):
+        mlp_layer = PhiMLP(
+            PhiConfig(
+                vocab_size=128,
+                hidden_size=hidden_size,
+                intermediate_size=hidden_size * 2,
+                num_hidden_layers=8,
+                num_attention_heads=2,
+            )
+        )
+        moe_layer = dummy_moe_layer(
+            "phi", mlp_layer, hidden_size, ["fc1", "fc2"]
+        )
+        for shape in dummy_test_shapes(hidden_size):
+            with self.subTest(f"test for shape = {shape}"):
+                input = torch.zeros(shape)
+                output: torch.Tensor = moe_layer(input)
+                self.assertEqual(output.shape, shape)
 
     def test_phi3_forward(self):
         mlp_layer = Phi3MLP(