Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] refactoring unit tests #9

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 0 additions & 76 deletions tests/test_llama.py

This file was deleted.

97 changes: 97 additions & 0 deletions tests/test_moe_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import unittest
from typing import List

import torch
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaMLP
from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP

from mixlora.model import LoraLinear, MixLoraConfig, MixLoraSparseMoe


def dummy_moe_layer(
model_type: str,
mlp_layer: torch.nn.Module,
hidden_size: int,
mlp_projections: List[str],
):
config = MixLoraConfig.from_config(
{
"bias": "none",
"peft_type": "MIXLORA",
"r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"target_modules": [],
"routing_strategy": "mixtral",
"num_experts": 8,
"act_fn": "silu",
"top_k": 2,
"base_model_name_or_path": "DUMMY",
"task_type": "CAUSAL_LM",
}
)
config.model_type_ = model_type
moe_layer = MixLoraSparseMoe(mlp_layer, config)
gate_layer = torch.nn.Linear(hidden_size, config.num_experts_, bias=False)
torch.nn.init.normal_(gate_layer.weight)
moe_layer.gate_ = gate_layer.weight
for proj_name in mlp_projections:
base_layer: torch.nn.Linear = getattr(mlp_layer, proj_name)
torch.nn.init.normal_(base_layer.weight)
for expert_idx in range(config.num_experts_):
moe_layer.experts_[f"experts.{expert_idx}.{proj_name}"] = LoraLinear(
base_layer, config
)

return moe_layer


def dummy_test_shapes(hidden_size: int):
return [(2, 8, hidden_size), (1, 16, hidden_size), (4, 4, hidden_size)]


hidden_size = 16


class MoeLayerTestCase(unittest.TestCase):
def test_llama_forward(self):
mlp_layer = LlamaMLP(
LlamaConfig(
vocab_size=128,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_hidden_layers=8,
num_attention_heads=2,
)
)
moe_layer = dummy_moe_layer(
"llama", mlp_layer, hidden_size, ["gate_proj", "down_proj", "up_proj"]
)
for shape in dummy_test_shapes(hidden_size):
with self.subTest(f"test for shape = {shape}"):
input = torch.zeros(shape)
output: torch.Tensor = moe_layer(input)
self.assertEqual(output.shape, shape)

def test_phi3_forward(self):
mlp_layer = Phi3MLP(
Phi3Config(
vocab_size=128,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_hidden_layers=8,
num_attention_heads=2,
)
)
moe_layer = dummy_moe_layer(
"phi3", mlp_layer, hidden_size, ["gate_up_proj", "down_proj"]
)
for shape in dummy_test_shapes(hidden_size):
with self.subTest(f"test for shape = {shape}"):
input = torch.zeros(shape)
output: torch.Tensor = moe_layer(input)
self.assertEqual(output.shape, shape)


if __name__ == "__main__":
unittest.main()
61 changes: 0 additions & 61 deletions tests/test_phi3.py

This file was deleted.

Loading