Merge pull request #4 from TUDB-Labs/phi_support
[feature] support phi family models
mikecovlee authored Jul 29, 2024
2 parents 172c842 + 46a49cf commit e40cbdc
Showing 4 changed files with 160 additions and 62 deletions.
99 changes: 61 additions & 38 deletions mixlora/config.py
@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from transformers.activations import ACT2FN
@@ -14,6 +14,30 @@ class AdapterConfig:
dtype_: torch.dtype = None


lora_target_modules = {
# LLaMA names
"q_proj": False,
"k_proj": False,
"v_proj": False,
"o_proj": False,
"gate_proj": False,
"down_proj": False,
"up_proj": False,
# Phi names (q_proj, k_proj and v_proj are shared with the LLaMA list)
"dense": False,
"fc1": False,
"fc2": False,
# Phi3 names (o_proj and down_proj are shared with the LLaMA list)
"qkv_proj": False,
"gate_up_proj": False,
}


@dataclass
class LoraConfig(AdapterConfig):
# Weight-Decomposed Low-Rank Adaptation
@@ -45,35 +69,28 @@ def check(self) -> "LoraConfig":

return self

def from_config(self, config: Dict[str, any]) -> "LoraConfig":
self.use_dora_ = config.get("use_dora", False)
self.use_rslora_ = config.get("use_rslora", False)
self.lora_init_ = config.get("lora_init", "original")
self.lora_r_ = config["r"]
self.lora_alpha_ = config["lora_alpha"]
self.lora_dropout_ = config["lora_dropout"]
self.target_modules_ = {
# LLaMA names
"q_proj": False,
"k_proj": False,
"v_proj": False,
"o_proj": False,
"gate_proj": False,
"down_proj": False,
"up_proj": False,
}
@staticmethod
def from_config(config: Dict[str, any]) -> "LoraConfig":
lora_config = LoraConfig()
lora_config.use_dora_ = config.get("use_dora", False)
lora_config.use_rslora_ = config.get("use_rslora", False)
lora_config.lora_init_ = config.get("lora_init", "original")
lora_config.lora_r_ = config["r"]
lora_config.lora_alpha_ = config["lora_alpha"]
lora_config.lora_dropout_ = config["lora_dropout"]
lora_config.target_modules_ = copy.deepcopy(lora_target_modules)
if isinstance(config["target_modules"], List):
for target in config["target_modules"]:
if target in self.target_modules_:
self.target_modules_[target] = True
if target in lora_target_modules:
lora_config.target_modules_[target] = True
elif isinstance(config["target_modules"], Dict):
for target, value in config["target_modules"].items():
if target in self.target_modules_:
self.target_modules_[target] = value
if target in lora_target_modules:
lora_config.target_modules_[target] = value
else:
raise ValueError("broken config item: target_modules")

return self
return lora_config

def export(self) -> Dict[str, any]:
config = {}
@@ -109,7 +126,7 @@ class MixLoraConfig(LoraConfig):
jitter_noise_: float = None
router_loss_: bool = True
num_experts_: int = None
act_fn_: str = None
act_fn_: Optional[str] = None
# mixtral config
top_k_: int = None

@@ -141,30 +158,36 @@ def check(self) -> "MixLoraConfig":

return self

def from_config(self, config: Dict[str, any]) -> "MixLoraConfig":
assert config["peft_type"] == "MIXLORA"
super().from_config(config)
@staticmethod
def from_config(config: Dict[str, any]) -> "MixLoraConfig":
lora_config = MixLoraConfig(**LoraConfig.from_config(config).__dict__)
assert (
"peft_type" in config
and config["peft_type"] == "MIXLORA"
and "routing_strategy" in config
and config["routing_strategy"] == "mixtral"
), "MixLoraConfig only supports MixLoRA models with 'mixtral' routing_strategy."
if "expert_lora" in config:
expert_config = copy.deepcopy(config)
expert_config.update(config["expert_lora"])
self.expert_config_ = LoraConfig().from_config(expert_config)
self.router_aux_loss_coef_ = config.get(
lora_config.expert_config_ = LoraConfig.from_config(expert_config)
lora_config.router_aux_loss_coef_ = config.get(
"router_aux_loss_coef", 0.001
) # for training
self.routing_strategy_ = config["routing_strategy"]
self.router_loss_ = config.get("router_loss", True)
self.num_experts_ = config["num_experts"]
lora_config.routing_strategy_ = config["routing_strategy"]
lora_config.router_loss_ = config.get("router_loss", True)
lora_config.num_experts_ = config["num_experts"]
# silu for mixtral or gelu_new for switch transformers
# left blank to automatically use the original act_fn of FFN
self.act_fn_ = config.get("act_fn", None)
if self.routing_strategy_ == "mixtral":
self.router_init_range_ = config.get("router_init_range", 0.02)
self.jitter_noise_ = config.get("jitter_noise", 0.0)
self.top_k_ = config.get("top_k", 2)
lora_config.act_fn_ = config.get("act_fn", None)
if lora_config.routing_strategy_ == "mixtral":
lora_config.router_init_range_ = config.get("router_init_range", 0.02)
lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
lora_config.top_k_ = config.get("top_k", 2)
else:
raise NotImplementedError()

return self
return lora_config

def export(self) -> Dict[str, any]:
config = super().export()
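With from_config now a staticmethod, an adapter_config.json dictionary maps straight to a MixLoraConfig. A minimal sketch, assuming MixLoraConfig is importable from mixlora.config as the file path suggests, with hypothetical values for a Phi-3 style adapter (only keys that appear in this diff are set):

from mixlora.config import MixLoraConfig

# Hypothetical adapter_config.json contents for a Phi-3 family adapter.
raw_config = {
    "peft_type": "MIXLORA",
    "routing_strategy": "mixtral",
    "num_experts": 8,
    "top_k": 2,
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "target_modules": ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
}

config = MixLoraConfig.from_config(raw_config)
# The listed module names are flipped to True in config.target_modules_;
# every other entry copied from lora_target_modules stays False.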
113 changes: 94 additions & 19 deletions mixlora/model.py
@@ -33,6 +33,8 @@ def _mixtral_slice_tensor(
"gemma2": "_llama_forward",
"qwen2": "_llama_forward",
"mistral": "_llama_forward",
"phi": "_phi_forward",
"phi3": "_phi3_forward",
}
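The two new entries map the transformers model_type of Phi and Phi-3 decoders to the forward variants defined further down. A self-contained sketch of the name-based dispatch this table enables (class and variable names here are illustrative, not taken from the file):

_model_forward = {
    "phi": "_phi_forward",
    "phi3": "_phi3_forward",
}

class _Demo:
    def _phi_forward(self):
        return "phi MLP path"

    def _phi3_forward(self):
        return "phi3 MLP path"

# Resolve the architecture-specific forward by its string name and call it.
print(getattr(_Demo(), _model_forward["phi3"])())  # phi3 MLP path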


@@ -66,47 +68,120 @@ def __init__(
def _llama_forward(
self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
):
common_w1 = self.base_layer_.gate_proj(hidden_states.to(input_dtype)).to(
common_gate = self.base_layer_.gate_proj(hidden_states.to(input_dtype)).to(
hidden_states.dtype
)
common_w3 = self.base_layer_.up_proj(hidden_states.to(input_dtype)).to(
common_up = self.base_layer_.up_proj(hidden_states.to(input_dtype)).to(
hidden_states.dtype
)
final_expert_states = []
for expert_idx in range(self.num_experts_):
_, top_x = torch.where(expert_mask[expert_idx])
lora_w1: Optional[Lora] = self.experts_.get(
lora_gate: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.gate_proj", None
)
lora_w2: Optional[Lora] = self.experts_.get(
lora_down: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.down_proj", None
)
lora_w3: Optional[Lora] = self.experts_.get(
lora_up: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.up_proj", None
)
if lora_w1 is not None:
if lora_gate is not None:
lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
w1 = lora_w1(
_mixtral_slice_tensor(common_w1, top_x, input_dtype), lora_data
gate_states = lora_gate(
_mixtral_slice_tensor(common_gate, top_x, input_dtype), lora_data
)
else:
lora_data = None
w1 = _mixtral_slice_tensor(common_w1, top_x, input_dtype)
gate_states = _mixtral_slice_tensor(common_gate, top_x, input_dtype)

if lora_w3 is not None:
if lora_up is not None:
lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
w3 = lora_w3(
_mixtral_slice_tensor(common_w3, top_x, input_dtype), lora_data
up_states = lora_up(
_mixtral_slice_tensor(common_up, top_x, input_dtype), lora_data
)
else:
lora_data = None
w3 = _mixtral_slice_tensor(common_w3, top_x, input_dtype)
up_states = _mixtral_slice_tensor(common_up, top_x, input_dtype)

act_result = self.act_(w1) * w3
act_result = self.act_(gate_states) * up_states

if lora_w2 is not None:
if lora_down is not None:
final_expert_states.append(
lora_w2(self.base_layer_.down_proj(act_result), act_result)
lora_down(self.base_layer_.down_proj(act_result), act_result)
)
else:
final_expert_states.append(self.base_layer_.down_proj(act_result))

return final_expert_states

def _phi_forward(
self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
):
common_fc1 = self.base_layer_.fc1(hidden_states.to(input_dtype)).to(
hidden_states.dtype
)
final_expert_states = []
for expert_idx in range(self.num_experts_):
_, top_x = torch.where(expert_mask[expert_idx])
lora_fc1: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.fc1", None
)
lora_fc2: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.fc2", None
)
if lora_fc1 is not None:
lora_data = _mixtral_slice_tensor(hidden_states, top_x, input_dtype)
act_result = self.act_(
lora_fc1(
_mixtral_slice_tensor(common_fc1, top_x, input_dtype), lora_data
)
)
else:
act_result = self.act_(
_mixtral_slice_tensor(common_fc1, top_x, input_dtype)
)

if lora_fc2 is not None:
final_expert_states.append(
lora_fc2(self.base_layer_.fc2(act_result), act_result)
)
else:
final_expert_states.append(self.base_layer_.fc2(act_result))

return final_expert_states

def _phi3_forward(
self, expert_mask: torch.Tensor, hidden_states: torch.Tensor, input_dtype
):
common_gate_up = self.base_layer_.gate_up_proj(
hidden_states.to(input_dtype)
).to(hidden_states.dtype)
final_expert_states = []
for expert_idx in range(self.num_experts_):
_, top_x = torch.where(expert_mask[expert_idx])
lora_gate_up: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.gate_up_proj", None
)
lora_down: Optional[Lora] = self.experts_.get(
f"experts.{expert_idx}.down_proj", None
)
if lora_gate_up is not None:
gate_up_states = lora_gate_up(
_mixtral_slice_tensor(common_gate_up, top_x, input_dtype),
_mixtral_slice_tensor(hidden_states, top_x, input_dtype),
)
else:
gate_up_states = _mixtral_slice_tensor(
common_gate_up, top_x, input_dtype
)

gate_states, up_states = gate_up_states.chunk(2, dim=-1)
act_result = up_states * self.act_(gate_states)

if lora_down is not None:
final_expert_states.append(
lora_down(self.base_layer_.down_proj(act_result), act_result)
)
else:
final_expert_states.append(self.base_layer_.down_proj(act_result))
@@ -260,9 +335,9 @@ def load_adapter_weights(
with open(
name_or_path + os.sep + "adapter_config.json", "r", encoding="utf8"
) as fp:
config = MixLoraConfig(adapter_name_=adapter_name, dtype_=dtype).from_config(
json.load(fp)
)
config = MixLoraConfig.from_config(json.load(fp))
config.adapter_name_ = adapter_name
config.dtype_ = dtype

weights = torch.load(
name_or_path + os.sep + "adapter_model.bin", map_location=device
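In _phi3_forward the fused gate_up_proj output is split with chunk(2, dim=-1), since Phi-3 packs the gate and up projections into a single matrix. A standalone sketch of that split with toy shapes (sizes illustrative; SiLU shown as the activation for illustration):

import torch
import torch.nn.functional as F

hidden_states = torch.randn(4, 3072)                # 4 tokens, hidden size 3072
gate_up_proj = torch.nn.Linear(3072, 2 * 8192, bias=False)
gate_up = gate_up_proj(hidden_states)               # (4, 16384), fused projection
gate_states, up_states = gate_up.chunk(2, dim=-1)   # two (4, 8192) halves
act_result = up_states * F.silu(gate_states)        # same recipe as in the diff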
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mixlora"
version = "0.1.1"
version = "0.1.2"
description = "State-of-the-art Parameter-Efficient MoE Fine-tuning Method"
readme = "README.md"
requires-python = ">=3.8"
@@ -14,8 +14,8 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"torch==2.3.1",
"transformers==4.42.4",
"torch>=2.3.0,<2.4.0",
"transformers>=4.43.0,<4.44.0",
"huggingface_hub",
]

4 changes: 2 additions & 2 deletions tests/generate.py
@@ -29,9 +29,9 @@ def main(
)
output = tokenizer.batch_decode(
outputs.detach().cpu().numpy(), skip_special_tokens=True
)[0][len(instruction) :]
)[0][input_ids.shape[-1] :]

print(output)
print(f"\nOutput: {prompter.get_response(output)}\n")


if __name__ == "__main__":
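The updated slice trims the prompt by the number of input tokens rather than the character length of the instruction string. A toy sketch of the token-count offset (values illustrative): for decoder-only models, generate() returns the prompt tokens followed by the newly generated tokens, so input_ids.shape[-1] marks where the continuation starts.

import torch

input_ids = torch.tensor([[101, 102, 103]])          # 3 prompt tokens
outputs = torch.tensor([[101, 102, 103, 7, 8, 9]])   # prompt + 3 generated tokens
new_tokens = outputs[:, input_ids.shape[-1]:]        # tensor([[7, 8, 9]])
# new_tokens is what one would pass to tokenizer.batch_decode(...)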
