add support for phi3-mini (#433) #435

Merged · 1 commit · Jun 13, 2024
3 changes: 3 additions & 0 deletions README.md
@@ -49,6 +49,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
- [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
- [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)

> When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
@@ -61,6 +62,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework

> Stablelm requires the parameter '--trust_remote_code'.

> For Phi-3, only the Mini and Small variants are supported.

## Get started

### Requirements
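Note on the new Phi-3 entry above: once a Phi-3 checkpoint is served (typically by launching `python -m lightllm.server.api_server --model_dir <path-to-phi-3-mini> --trust_remote_code`; the exact flags depend on your deployment and are not part of this diff), it can be queried over the HTTP generate endpoint. A minimal client sketch, with the host, port, and payload shape assumed to match the README's existing curl example:

```python
# Minimal sketch: query a locally running LightLLM server hosting Phi-3-mini.
# Assumes the server is already up on 127.0.0.1:8080; the /generate endpoint and
# payload fields mirror the README's curl example and are assumptions here.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 32},
    },
)
print(resp.json())
```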
@@ -19,6 +19,9 @@ def load_hf_weights(self, weights):
if "lm_head.weight" in weights:
# print(weights['lm_head.weight'].shape)
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
if tie_word_embeddings:
self.lm_head_weight_ = self.wte_weight_
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

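The `tie_word_embeddings` branch added above matters for Phi-3-mini, which ships no separate `lm_head.weight` and reuses the token-embedding matrix as the output projection. A minimal sketch of what that tying means, outside LightLLM's weight classes (shapes and names below are illustrative):

```python
import torch

vocab_size, hidden = 32064, 3072                  # illustrative Phi-3-mini-like shapes
wte_weight = torch.randn(vocab_size, hidden)      # token embedding table

tie_word_embeddings = True                        # would be read from config.json
if tie_word_embeddings:
    lm_head_weight = wte_weight                   # share the same tensor, no copy
else:
    lm_head_weight = torch.randn(vocab_size, hidden)

hidden_states = torch.randn(4, hidden)            # e.g. last hidden states of 4 tokens
logits = hidden_states @ lm_head_weight.t()       # [4, vocab_size]
assert logits.shape == (4, vocab_size)
```

Because the two names point at the same tensor, the checkpoint stays smaller and the input and output projections can never drift apart.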
142 changes: 111 additions & 31 deletions lightllm/models/llama/model.py
@@ -1,6 +1,7 @@
import os
import json
import torch
import math
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -17,6 +18,7 @@

logger = init_logger(__name__)


class LlamaTpPartModel(TpPartBaseModel):
# weight class
pre_and_post_weight_class = LlamaPreAndPostLayerWeight
@@ -34,14 +36,14 @@ class LlamaTpPartModel(TpPartBaseModel):
def __init__(self, kvargs):
super().__init__(kvargs)
return

def _init_config(self):
super()._init_config()
# rename key
# repair_config()
self._reset_num_key_value_heads()
return
return

def _reset_num_key_value_heads(self):
if "num_key_value_heads" not in self.config:
self.config["num_key_value_heads"] = self.config["num_attention_heads"]
@@ -52,13 +54,15 @@ def _verify_params(self):
assert self.config["num_key_value_heads"] % self.world_size_ == 0
assert self.config["num_attention_heads"] % self.world_size_ == 0
return

def _init_mem_manager(self):
self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
self.mem_manager = select_mem_manager_class(self.mode)(
self.max_total_token_num,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"],
)
return

def _init_custom(self):
@@ -67,37 +71,51 @@ def _init_custom(self):
"""
if self.config.get("use_rope_yarn", False):
self._init_to_get_yarn_rotary()
elif self.config.get("use_dynamic_ntk", False) or (self.config.get("rope_scaling", None) is not None and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"):
elif self.config.get("use_dynamic_ntk", False) or (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"
):
self._init_to_get_dynamic_ntk_rotary()
elif (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "su"
):
self._init_to_su_rotary()
else:
self._init_to_get_rotary()
return

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.pre_post_weight = self.pre_and_post_weight_class(
self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
self.trans_layers_weight = [
self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.transformer_weight_class(
i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
for i in range(self.config["n_layer"])
]
if self.load_way == 'HF':
if self.load_way == "HF":
load_hf_weights(
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict)
weight_dict=self.weight_dict,
)
else:
load_ds_weights(
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict,
prefix='model.layers.',
num_layer=self.config["n_layer"])
prefix="model.layers.",
num_layer=self.config["n_layer"],
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
return
[weight.verify_load() for weight in self.trans_layers_weight]
return

def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -112,8 +130,7 @@ def _init_to_get_rotary(self, default_base=10000):
max_seq_len = self.config["max_sequence_length"]
else:
max_position_embeddings = self.config.get(
"max_position_embeddings",
2048 if base <= 10000.0 + 1e-5 else 16384
"max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
)
max_seq_len = max_position_embeddings * rope_scaling_factor

@@ -124,11 +141,13 @@
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim-2))) #Base change formula
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass

inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

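The `# Base change formula` comment in the hunk above implements the NTK-aware base rescaling base' = base * alpha^(d / (d - 2)), where d is the rotary (partial) head dimension. A quick standalone check with illustrative numbers (how `ntk_alpha` itself is derived sits in the collapsed lines and is not reproduced here):

```python
base = 10000.0
partial_head_dim = 128   # illustrative rotary dimension
ntk_alpha = 2.0          # illustrative; the collapsed lines above derive it at runtime

# base' = base * alpha ** (d / (d - 2)), matching the "Base change formula" line above
new_base = base * ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))
print(round(new_base, 1))  # ~20221.3: a modest alpha roughly doubles the effective base
```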
@@ -147,24 +166,37 @@ def _init_to_get_dynamic_ntk_rotary(self):
max_seq_len = max(self.max_seq_length, max_position_embeddings)
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))

inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()

for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
new_base = base * (
(scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
) ** (partial_head_dim / (partial_head_dim - 2))
inv_freq = 1.0 / (
new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.tensor(
[
seq_loc_index,
],
device="cpu",
dtype=torch.float32,
)
freqs = torch.outer(t, inv_freq)
self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_yarn_rotary(self):
from .yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_mscale

dim = self.head_dim_
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
base = self.config.get("rope_theta", 10000.0)
@@ -183,10 +215,12 @@ def _init_to_get_yarn_rotary(self):
inv_freq_interpolation = 1.0 / (scale * pos_freqs)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).float().cuda()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1 - linear_ramp_mask(low, high, dim // 2).float().cuda()
) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation

# Build here to make `torch.jit.trace` work.
max_seq_len_cached = max_position_embeddings
@@ -199,4 +233,50 @@ def _init_to_get_yarn_rotary(self):

return

def _init_to_su_rotary(self):
rope_scaling = self.config["rope_scaling"]
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position_embeddings = self.config["original_max_position_embeddings"]
max_position_embeddings = self.config.get("max_position_embeddings", original_max_position_embeddings)
base = self.config.get("rope_theta", 10000.0)
short_factor = torch.tensor(short_factor, dtype=torch.float32, device="cpu")
long_factor = torch.tensor(long_factor, dtype=torch.float32, device="cpu")

scale = max_position_embeddings / original_max_position_embeddings
if scale <= 1.0:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

max_seq_len = max(self.max_seq_length, max_position_embeddings)
self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (
short_factor
* base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
)
t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[0:original_max_position_embeddings, :] = (
(torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)
self._sin_cached[0:original_max_position_embeddings, :] = (
(torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)

inv_freq = 1.0 / (
long_factor
* base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
)
t = torch.arange(original_max_position_embeddings, max_seq_len, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[original_max_position_embeddings:, :] = (
(torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)
self._sin_cached[original_max_position_embeddings:, :] = (
(torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)

return
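The new `_init_to_su_rotary` handles `rope_scaling` type `"su"` (the SuScaled / LongRoPE-style scheme used by the long-context Phi-3 checkpoints): the per-dimension `short_factor` and `long_factor` arrays rescale the inverse frequencies below and above `original_max_position_embeddings`, and a single magnitude correction sqrt(1 + log(scale) / log(original_max_position_embeddings)) is folded into the cached cos/sin tables. A standalone sketch of that magnitude factor, using Phi-3-mini-128k-like context lengths (assumed for illustration, not read from any config shown in this diff):

```python
import math

# Assumed Phi-3-mini-128k-style context lengths (illustrative only).
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    rope_scaling_factor = 1.0
else:
    # Same expression as in _init_to_su_rotary above.
    rope_scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings)
    )

print(scale, round(rope_scaling_factor, 4))  # 32.0 1.1902
```

So even at a 32x context extension, the cached cos/sin values are only scaled by about 1.19; the heavy lifting is done by the per-frequency factors.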
90 changes: 90 additions & 0 deletions lightllm/models/phi3/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,90 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from functools import partial

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import (
context_attention_fwd,
context_attention_fwd_no_prompt_cache,
)
from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer):
""" """

def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
return

def _bind_attention(self):
self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self)
self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self)
return

def _get_qkv(self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight):
q = torch.mm(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_)
torch.mm(
input_emb.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
return

def _context_attention_kernel(
self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
if infer_state.use_dynamic_prompt_cache:
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
)
else:
context_attention_fwd_no_prompt_cache(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

return o_tensor

def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
)
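In `Phi3TransformerLayerInfer._get_qkv` and `_context_attention_kernel` above, K and V share one fused buffer: along dim 1 the first `tp_k_head_num_` heads are K and the next `tp_v_head_num_` heads are V, and each kernel slices out the range it needs. A minimal sketch of that layout with illustrative head counts:

```python
import torch

tokens, k_heads, v_heads, head_dim = 8, 4, 4, 96  # illustrative tensor-parallel shard sizes

# One fused buffer per layer: K heads first, then V heads, along dim 1.
cache_kv = torch.randn(tokens, k_heads + v_heads, head_dim)

cache_k = cache_kv[:, 0:k_heads, :]                  # as in kv[:, 0 : self.tp_k_head_num_, :]
cache_v = cache_kv[:, k_heads:k_heads + v_heads, :]  # as in kv[:, tp_k : tp_k + tp_v, :]

assert cache_k.shape == (tokens, k_heads, head_dim)
assert cache_v.shape == (tokens, v_heads, head_dim)
# The slices are views into cache_kv, so kernels can write K/V in place without copies.
```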