
Commit e6452fd

hiworldwzj, wangzaijun, shihaobai, and baishihao authored
【Feature】 Refactor the code and add support for multi-output features like beam search. (#409)
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
Co-authored-by: baishihao <baishihao@sensetime.com>
1 parent 4b037f6 commit e6452fd


58 files changed, +3076 −1127 lines (large commit; only part of the diff is shown below)

lightllm/common/basemodel/basemodel.py

+10-5
@@ -38,15 +38,19 @@ def __init__(self, kvargs):
         self.weight_dir_ = kvargs["weight_dir"]
         self.max_total_token_num = kvargs["max_total_token_num"]
         self.load_way = kvargs.get("load_way", "HF")
-        self.mode = [m.replace('int4weight', 'w4a16').replace('int8weight', 'w8a16') for m in kvargs.get("mode", [])]
+        self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])]
         self.weight_dict = kvargs.get("weight_dict", None)
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
-        self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
+        # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one may be enabled at a time.
+        # They mainly control how many tokens the prefill stage returns for downstream processing.
+        self.is_token_healing = kvargs.get("is_token_healing", False)
+        self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
+        assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
         self.data_type = kvargs.get("data_type", "float16")
-
+
         self._init_datatype()
         self._init_config()
         self._verify_must()

@@ -145,7 +149,7 @@ def _init_datatype(self):
         elif self.data_type in ["bf16", "bfloat16"]:
             self.data_type = torch.bfloat16
         elif self.data_type in ["fp32", "float32"]:
-            self.data_type =torch.float32
+            self.data_type = torch.float32
         else:
             raise ValueError(f"Unsupport datatype {self.data_type}!")

@@ -204,7 +208,8 @@ def _prefill(
     ):
         infer_state = self.infer_state_class()
         infer_state.is_prefill = True
-        infer_state.return_all_prompt_logprobs = self.return_all_prompt_logprobs
+        infer_state.is_token_healing = self.is_token_healing
+        infer_state.return_all_prompt_logics = self.return_all_prompt_logics
         infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
         infer_state.batch_size = batch_size
         infer_state.total_token_num = total_token_num
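Note: the constructor now selects between two mutually exclusive prefill modes via kvargs. A minimal sketch of how a caller might enable token healing; the class name and the other values below are illustrative placeholders, not part of this commit:

    # Hedged sketch: enabling the new token-healing prefill mode through kvargs.
    # "SomeLlamaTpPartModel" and the concrete values are placeholders for illustration.
    kvargs = {
        "weight_dir": "/path/to/weights",        # placeholder path
        "max_total_token_num": 120000,
        "load_way": "HF",
        "data_type": "float16",
        "is_token_healing": True,                # new flag added by this commit
        "return_all_prompt_logics": False,       # must stay False while token healing is on
    }
    model = SomeLlamaTpPartModel(kvargs)         # hypothetical subclass of the base model above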

lightllm/common/basemodel/infer_struct.py

+2-1
@@ -31,7 +31,8 @@ def __init__(self):
         self.kv_buffer = None

         self.is_splitfuse = False
-        self.return_all_prompt_logprobs = False
+        self.is_token_healing = False
+        self.return_all_prompt_logics = False
         self.use_dynamic_prompt_cache = False
         self.multimodal_params = None


lightllm/common/mem_manager.py

+13
@@ -1,5 +1,7 @@
+import os
 import torch
 from lightllm.utils.log_utils import init_logger
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

 logger = init_logger(__name__)

@@ -17,6 +19,13 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda")
         self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
         self.can_use_mem_size = size
+        # Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port serves as a per-instance marker on a single machine to avoid conflicts.
+        nccl_port = os.environ.get("_NCCL_PORT_", None)
+        assert nccl_port is not None
+        logger.info(f"mem manger get nccl port: {str(nccl_port)}")
+        self.shared_can_use_token_num = SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num")
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
@@ -78,6 +87,7 @@ def add_refs(self, token_index: torch.Tensor):
         has_used_tokens = torch.count_nonzero(state).item()
         all_tokens = len(state)
         self.can_use_mem_size -= all_tokens - has_used_tokens
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self.mem_state[token_index] += 1
         return

@@ -89,11 +99,13 @@ def decrease_refs(self, token_index: torch.Tensor):
         used_tokens = torch.count_nonzero(state).item()
         all_tokens = len(state)
         self.can_use_mem_size += all_tokens - used_tokens
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         return

     @torch.no_grad()
     def free_all(self):
         self.can_use_mem_size = len(self.mem_state)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self.mem_state[:] = 0

     @torch.no_grad()
@@ -110,6 +122,7 @@ def resize_mem(self, new_size):
         self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda")
         self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
         self.can_use_mem_size = size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
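Note: the memory manager now mirrors can_use_mem_size into a SharedInt named after the instance's nccl port, so the router process can read an up-to-date free-token count for scheduling. Below is a rough, self-contained sketch of how such a cross-process counter can work using only the standard library; it is not the actual SharedInt implementation, and the port value in the name is made up:

    # Hedged sketch of a cross-process integer, analogous in spirit to SharedInt above.
    import numpy as np
    from multiprocessing import shared_memory

    class SharedIntSketch:
        def __init__(self, name: str):
            try:
                # First process creates the 8-byte shared segment ...
                self.shm = shared_memory.SharedMemory(name=name, create=True, size=8)
            except FileExistsError:
                # ... later processes (e.g. the router) just attach to it.
                self.shm = shared_memory.SharedMemory(name=name)
            self.arr = np.ndarray((1,), dtype=np.int64, buffer=self.shm.buf)

        def set_value(self, value: int) -> None:
            self.arr[0] = value

        def get_value(self) -> int:
            return int(self.arr[0])

    # Writer side (inference worker), mirroring the calls added in the diff above:
    counter = SharedIntSketch("28765_mem_manger_can_use_token_num")  # port value is illustrative
    counter.set_value(100000)
    # Reader side (router) polls the same name for its scheduling estimate:
    print(SharedIntSketch("28765_mem_manger_can_use_token_num").get_value())  # -> 100000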

lightllm/common/req_manager.py

-2
@@ -37,5 +37,3 @@ def free_token(self, free_token_index):
     def free_all(self):
         self.can_use_req_size = len(self.req_state)
         self.req_state[:] = 0
-
-

lightllm/models/llama/layer_infer/post_layer_infer.py

+36-5
@@ -30,7 +30,9 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
         if infer_state.is_splitfuse:
             # for SplitFuse
             batch_size = infer_state.batch_size
-            last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
             tmp_ = torch.cat(
                 [
                     torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
@@ -42,16 +44,43 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
             last_input[:, :] = input_embdings[last_index, :]
             return last_input, batch_size

-        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs:
+        if infer_state.is_prefill and infer_state.is_token_healing:
             batch_size = infer_state.batch_size
-            last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
+            b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy()
+            select_index = []
+            start_index = 0
+            select_token_num = 0
+            for cur_len in b_seq_len_numpy:
+                if cur_len == 1:
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 1
+                else:
+                    select_index.append(start_index + cur_len - 2)
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 2
+
+            last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
+            last_input = torch.empty(
+                (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
+
+            last_input[:, :] = input_embdings[last_index, :]
+            return last_input, select_token_num
+
+        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
+            batch_size = infer_state.batch_size
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
             last_index = (
                 torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
             )
             last_input[:, :] = input_embdings[last_index, :]
             return last_input, batch_size

-        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logprobs:
+        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics:
             total_tokens = infer_state.total_token_num
             return input_embdings, total_tokens

@@ -82,7 +111,9 @@ def token_forward(
         if self.world_size_ == 1:
             gather_data = logic_batch
         else:
-            gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype)
+            gather_data = torch.empty(
+                (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
+            )
             split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
             dist.all_gather(
                 [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
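Note: the new token-healing branch in _slice_get_last_input keeps the hidden state of each request's last prefill token and, when the request contributed more than one new token, also the one before it, so later stages can re-score the boundary token. A self-contained sketch of that index selection over the flattened token dimension (the helper name is ours, not part of the commit):

    # Hedged sketch: reproduce the select_index computation from the diff above
    # for per-request prefill lengths packed into one flat token dimension.
    def token_healing_select(prefill_lens):
        select_index, start = [], 0
        for cur_len in prefill_lens:
            if cur_len == 1:
                select_index.append(start + cur_len - 1)   # only the last token exists
            else:
                select_index.append(start + cur_len - 2)   # second-to-last token
                select_index.append(start + cur_len - 1)   # last token
            start += cur_len
        return select_index

    # Three requests with 4, 1 and 3 new prefill tokens -> 8 packed rows in total.
    print(token_healing_select([4, 1, 3]))  # [2, 3, 4, 6, 7]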

lightllm/models/llava/llava_visual.py

+15-13
@@ -7,49 +7,49 @@


 class LlavaVisionModel:
-
     def __init__(self):
         pass

     def load_model(self, weight_dir):
         config_file = os.path.join(weight_dir, "config.json")
         config = json.load(open(config_file))
-        self.select_layer = config.get('mm_vision_select_layer', -2)
-        self.select_feature = config.get('mm_vision_select_feature', 'patch')
+        self.select_layer = config.get("mm_vision_select_layer", -2)
+        self.select_feature = config.get("mm_vision_select_feature", "patch")

         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
-        vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
+        vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336")
         if isinstance(vision_path, list):
             vision_path = vision_path[0]
         if vision_path.startswith("./"):
             vision_path = os.path.join(weight_dir, vision_path)

         from transformers import CLIPVisionModel, CLIPImageProcessor
+
         self.image_processor = CLIPImageProcessor.from_pretrained(vision_path)
         self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half()
         self.vision_tower.requires_grad_(False)
-        self.device = torch.device('cpu')
+        self.device = torch.device("cpu")

         # load projector weights
         self.projector_weights = {}
         for f in os.listdir(weight_dir):
             if f.endswith(".bin"):
                 d = torch.load(os.path.join(weight_dir, f), "cpu")
                 for k, v in d.items():
-                    if 'model.mm_projector' in k:
+                    if "model.mm_projector" in k:
                         self.projector_weights[k] = v.half()

-        assert 'model.mm_projector.0.weight' in self.projector_weights
-        assert 'model.mm_projector.0.bias' in self.projector_weights
-        assert 'model.mm_projector.2.weight' in self.projector_weights
-        assert 'model.mm_projector.2.bias' in self.projector_weights
+        assert "model.mm_projector.0.weight" in self.projector_weights
+        assert "model.mm_projector.0.bias" in self.projector_weights
+        assert "model.mm_projector.2.weight" in self.projector_weights
+        assert "model.mm_projector.2.bias" in self.projector_weights

     def cuda(self):
         self.vision_tower = self.vision_tower.cuda()
         for k, v in self.projector_weights.items():
             self.projector_weights[k] = v.cuda()
-        self.device = torch.device('cuda')
+        self.device = torch.device("cuda")
         return self

     # batch images infer
@@ -58,7 +58,7 @@ def forward(self, x):

         x = self.vision_tower(x, output_hidden_states=True)
         x = x.hidden_states[self.select_layer]
-        if self.select_feature == 'patch':
+        if self.select_feature == "patch":
             x = x[:, 1:].contiguous()
         B, L, N = x.shape
         x = x.view(-1, N)
@@ -84,10 +84,12 @@ def encode(self, image_items: List[Union[str, Image.Image]]):
             if isinstance(item, Image.Image):
                 image = item
             elif item.startswith("http://") or item.startswith("https://"):
+                import requests
+
                 image = Image.open(requests.get(item, stream=True).raw)
             else:
                 image = Image.open(item)
             images.append(image.convert("RGB"))

-        images = self.image_processor.preprocess(images, return_tensors='pt')['pixel_values']
+        images = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
         return self.forward(images)
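Note: encode() accepts a mix of PIL images, local file paths, and http(s) URLs; the deferred import requests added above is only needed for the URL branch. A hedged usage sketch, with placeholder paths and URLs:

    # Hypothetical usage of LlavaVisionModel as touched by this diff; paths/URLs are placeholders.
    vision = LlavaVisionModel()
    vision.load_model("/path/to/llava-model-dir")   # reads config.json and the projector *.bin weights
    vision.cuda()
    image_features = vision.encode([
        "https://example.com/cat.jpg",              # fetched via requests
        "/data/images/dog.png",                     # opened from local disk
    ])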

lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py

+1-1
@@ -69,7 +69,7 @@ def _get_o(
         self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized
     ) -> torch.Tensor:
         o_output = self._wquant_matmul_for_o(
-            input.view(-1, self.embed_dim_), layer_weight.o_weight_, infer_state=infer_state, bias=layer_weight.o_bias_
+            input, layer_weight.o_weight_, infer_state=infer_state, bias=layer_weight.o_bias_
         )
         return o_output

lightllm/server/__init__.py

+1
@@ -0,0 +1 @@
+from .router.token_load import TokenLoad
