Add bf16 inference for llm model (#387) #388

Merged
merged 1 commit on Apr 10, 2024
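For context, a minimal usage sketch of the new option: the model constructor now reads a "data_type" key from its kvargs (default "float16"). Everything below except the "data_type" key itself (class name, import path, other kvargs) is illustrative and may not match the project's actual API.

import torch
from lightllm.models.llama.model import LlamaTpPartModel  # illustrative import path

kvargs = {
    "weight_dir": "/path/to/hf-weights",   # illustrative
    "max_total_token_num": 120000,         # illustrative
    "data_type": "bf16",                   # new: accepts "fp16"/"float16", "bf16"/"bfloat16", "fp32"/"float32"
}
model = LlamaTpPartModel(kvargs)
assert model.data_type == torch.bfloat16   # weights, KV cache and rotary caches follow this dtype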
28 changes: 20 additions & 8 deletions lightllm/common/basemodel/basemodel.py
@@ -45,7 +45,9 @@ def __init__(self, kvargs):
self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)

self.data_type = kvargs.get("data_type", "float16")

self._init_datatype()
self._init_config()
self._verify_must()
self._verify_params()
@@ -80,16 +82,16 @@ def _verify_params(self):

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(
self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode
self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
self.trans_layers_weight = [
self.transformer_weight_class(
i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode
i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
for i in range(self.config["n_layer"])
]
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -103,7 +105,7 @@ def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.world_size_ == 0
self.mem_manager = MemoryManager(
self.max_total_token_num,
dtype=torch.float16,
dtype=self.data_type,
head_num=self.config["num_attention_heads"] // self.world_size_,
head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
layer_num=self.config["n_layer"],
@@ -137,6 +139,16 @@ def _init_some_value(self):
self.vocab_size = self.config["vocab_size"]
return

def _init_datatype(self):
if self.data_type in ["fp16", "float16"]:
self.data_type = torch.float16
elif self.data_type in ["bf16", "bfloat16"]:
self.data_type = torch.bfloat16
elif self.data_type in ["fp32", "float32"]:
self.data_type = torch.float32
else:
raise ValueError(f"Unsupported datatype {self.data_type}!")

def _init_custom(self):
pass

@@ -223,7 +235,7 @@ def _prefill(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=torch.float16,
dtype=self.data_type,
device="cuda",
)

@@ -279,7 +291,7 @@ def _decode(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=torch.float16,
dtype=self.data_type,
device="cuda",
)
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
@@ -341,7 +353,7 @@ def splitfuse_forward(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=torch.float16,
dtype=self.data_type,
device="cuda",
)

3 changes: 2 additions & 1 deletion lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -27,7 +27,8 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
data_type = torch.float16 if data_type == 'fp16' else torch.float32
if isinstance(data_type, str):
data_type = torch.float16 if data_type == 'fp16' else torch.float32
if pre_post_layer is not None:
assert pre_post_layer.data_type_ == data_type, "type is not right"
if transformer_layer_list is not None:
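A minimal sketch of the backward-compatible dtype handling this hunk introduces: strings keep the old fp16-or-fp32 mapping, while torch.dtype values (e.g. torch.bfloat16 passed in from the model) go through unchanged. The helper name is illustrative, not part of the PR.

import torch

def _resolve_dtype(data_type):
    # Old behavior for strings: "fp16" -> float16, anything else -> float32.
    if isinstance(data_type, str):
        return torch.float16 if data_type == "fp16" else torch.float32
    # New behavior: an explicit torch.dtype passes through untouched.
    return data_type

assert _resolve_dtype("fp16") is torch.float16
assert _resolve_dtype(torch.bfloat16) is torch.bfloat16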
@@ -21,8 +21,8 @@ def load_hf_weights(self, weights):
(self.tp_rank_ + 1), :])
if 'lm_head.weight' in weights:
# print(weights['lm_head.weight'].shape)
self.lm_head_weight_ = nn.functional.normalize(weights['lm_head.weight'].to(
torch.float16).cuda())[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :]
self.lm_head_weight_ = self._cuda(
nn.functional.normalize(weights['lm_head.weight'])[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :])
if 'model.norm.weight' in weights:
self.final_norm_weight_ = self._cuda(weights['model.norm.weight'])

4 changes: 2 additions & 2 deletions lightllm/models/bloom/layer_infer/post_layer_infer.py
@@ -29,7 +29,7 @@ def soft_max(self, data):

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight, return_logics=False):
batch_size = infer_state.batch_size
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
if infer_state.is_prefill:
last_index = torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
last_input[:, :] = input_embdings[last_index, :]
@@ -44,7 +44,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
if self.world_size_ == 1:
gather_data = logic_batch
else:
gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=torch.float16)
gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings.dtype)
split_size = self.vocab_size_ // self.world_size_
dist.all_gather([gather_data[i * split_size: (i + 1) * split_size, :]
for i in range(self.world_size_)], logic_batch, group=None, async_op=False)
3 changes: 2 additions & 1 deletion lightllm/models/bloom/layer_weights/hf_load_utils.py
@@ -5,7 +5,8 @@


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
data_type = torch.float16 if data_type == 'fp16' else torch.float32
if isinstance(data_type, str):
data_type = torch.float16 if data_type == 'fp16' else torch.float32
if pre_post_layer is not None:
assert pre_post_layer.data_type_ == data_type, "type is not right"
if transformer_layer_list is not None:
6 changes: 3 additions & 3 deletions lightllm/models/bloom/model.py
@@ -40,13 +40,13 @@ def _reset_num_key_value_heads(self):
return

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.trans_layers_weight = [
self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
for i in range(self.config["n_layer"])
]
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -37,7 +37,7 @@ def _fwd_kernel_token_att2(
v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
acc += tl.sum(p_value[:, None] * v_value, 0)

acc = acc.to(tl.float16)
acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
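The Out.dtype.element_ty pattern used above (and in the other Triton kernels in this PR) is what makes the kernels dtype-agnostic: the fp32 accumulator is cast to whatever element type the output pointer carries (fp16, bf16 or fp32) instead of a hard-coded tl.float16. A standalone sketch of the same idea (kernel and tensor names are illustrative, not from this PR):

import torch
import triton
import triton.language as tl

@triton.jit
def _cast_store_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)  # compute in fp32
    # Cast to the element type of the output tensor (fp16, bf16 or fp32) before storing.
    tl.store(out_ptr + offs, x.to(out_ptr.dtype.element_ty), mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty(1024, device="cuda", dtype=torch.bfloat16)  # bf16 output, no kernel change needed
_cast_store_kernel[(triton.cdiv(x.numel(), 256),)](x, out, x.numel(), BLOCK=256)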
4 changes: 2 additions & 2 deletions lightllm/models/chatglm2/model.py
@@ -70,6 +70,6 @@ def _init_to_get_rotary(self, base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
6 changes: 3 additions & 3 deletions lightllm/models/gemma_2b/model.py
@@ -45,7 +45,7 @@ def _init_custom(self):

def _init_mem_manager(self):
self.mem_manager = MemoryManager(self.max_total_token_num,
dtype=torch.float16,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"], # [SYM] always == 1
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
@@ -73,7 +73,7 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

2 changes: 1 addition & 1 deletion lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py
@@ -56,7 +56,7 @@ def _gelu_and_mul_kernel(
).to(tl.float32)

gate = gelu(gate)
gate = gate.to(tl.float16)
gate = gate.to(input_ptr.dtype.element_ty)

tl.store(
input_ptr + res_offsets,
6 changes: 3 additions & 3 deletions lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -30,7 +30,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
if infer_state.is_splitfuse:
# for SplitFuse
batch_size = infer_state.batch_size
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
tmp_ = torch.cat(
[
torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
@@ -44,7 +44,7 @@ def token_forward(

if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs:
batch_size = infer_state.batch_size
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
last_index = (
torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
)
@@ -81,7 +81,7 @@ def token_forward(
if self.world_size_ == 1:
gather_data = logic_batch
else:
gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=torch.float16)
gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings.dtype)
split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
dist.all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
30 changes: 15 additions & 15 deletions lightllm/models/llama/model.py
@@ -55,7 +55,7 @@ def _verify_params(self):

def _init_mem_manager(self):
self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
dtype=torch.float16,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
@@ -74,21 +74,21 @@ def _init_custom(self):
return

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.trans_layers_weight = [
self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
for i in range(self.config["n_layer"])
]
if self.load_way == 'HF':
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict)
else:
load_ds_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -132,8 +132,8 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_dynamic_ntk_rotary(self):
@@ -145,22 +145,22 @@ def _init_to_get_dynamic_ntk_rotary(self):
else:
scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
max_seq_len = max(self.max_seq_length, max_position_embeddings)
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()

for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_yarn_rotary(self):
@@ -194,8 +194,8 @@ def _init_to_get_yarn_rotary(self):
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(torch.float16).cuda() * mscale
self._sin_cached = emb.sin().to(torch.float16).cuda() * mscale
self._cos_cached = emb.cos().to(self.data_type).cuda() * mscale
self._sin_cached = emb.sin().to(self.data_type).cuda() * mscale

return

2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/rmsnorm.py
@@ -35,7 +35,7 @@ def _rms_norm_fwd_fused(
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)


def rmsnorm_forward(x, weight, eps):
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/silu_and_mul.py
@@ -44,7 +44,7 @@ def _silu_and_mul_kernel(
).to(tl.float32)

gate = gate / (1 + tl.exp(-gate))
gate = gate.to(tl.float16)
gate = gate.to(input_ptr.dtype.element_ty)

tl.store(
input_ptr + res_offsets,
@@ -301,8 +301,8 @@ def _fwd_kernel_int8(
vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0
)

p = p.to(tl.float16)
acc += tl.dot(p, v.to(tl.float16) * v_scale)
p = p.to(V.dtype.element_ty)
acc += tl.dot(p, v.to(V.dtype.element_ty) * v_scale)

# update m_i and l_i
l_i = l_i_new
6 changes: 3 additions & 3 deletions lightllm/models/mistral/model.py
@@ -51,7 +51,7 @@ def _init_custom(self):

def _init_mem_manager(self):
self.mem_manager = MemoryManager(self.max_total_token_num, # [SYM] should be sliding window?
dtype=torch.float16,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"],
@@ -79,7 +79,7 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

@@ -41,7 +41,7 @@ def _fwd_kernel_token_att2(
v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, other=0.0) # [1, D] + [64, 1] = [64, D]
acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = [64, D] -> [D]

acc = acc.to(tl.float16)
acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)