@@ -30,7 +30,9 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
         if infer_state.is_splitfuse:
             # for SplitFuse
             batch_size = infer_state.batch_size
-            last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
             tmp_ = torch.cat(
                 [
                     torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
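Note (not part of the patch): this hunk and the next select one hidden state per request out of the flattened prefill embedding tensor by fancy indexing with cumulative token counts. A minimal standalone sketch with toy shapes; embed_dim, token_counts, and the sizes below are assumptions for illustration only:

import torch

embed_dim = 4                                     # assumed hidden size
token_counts = torch.tensor([2, 3, 1])            # assumed new tokens per request
total_token_num = int(token_counts.sum())         # 6 tokens in the flattened batch
input_embdings = torch.randn(total_token_num, embed_dim)

# Position of each request's last token in the flattened tensor.
last_index = torch.cumsum(token_counts, dim=0, dtype=torch.long) - 1   # tensor([1, 4, 5])

last_input = torch.empty((len(token_counts), embed_dim), dtype=input_embdings.dtype)
last_input[:, :] = input_embdings[last_index, :]  # one hidden state per request
print(last_index.tolist(), last_input.shape)      # [1, 4, 5] torch.Size([3, 4])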
@@ -42,16 +44,43 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
             last_input[:, :] = input_embdings[last_index, :]
             return last_input, batch_size

-        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs:
+        if infer_state.is_prefill and infer_state.is_token_healing:
             batch_size = infer_state.batch_size
-            last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
+            b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy()
+            select_index = []
+            start_index = 0
+            select_token_num = 0
+            for cur_len in b_seq_len_numpy:
+                if cur_len == 1:
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 1
+                else:
+                    select_index.append(start_index + cur_len - 2)
+                    select_index.append(start_index + cur_len - 1)
+                    start_index += cur_len
+                    select_token_num += 2
+
+            last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
+            last_input = torch.empty(
+                (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
+
+            last_input[:, :] = input_embdings[last_index, :]
+            return last_input, select_token_num
+
+        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
+            batch_size = infer_state.batch_size
+            last_input = torch.empty(
+                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
+            )
             last_index = (
                 torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
             )
             last_input[:, :] = input_embdings[last_index, :]
             return last_input, batch_size

-        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logprobs:
+        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics:
             total_tokens = infer_state.total_token_num
             return input_embdings, total_tokens

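Note (not part of the patch): the new token-healing branch keeps two rows per request, the last and the second-to-last new token, when a request contributes more than one token, and a single row otherwise. A small self-contained sketch of the index selection; the per-request lengths are assumed toy values standing in for infer_state.b_seq_len - infer_state.b_ready_cache_len:

b_seq_len_numpy = [1, 3, 2]      # assumed new-token counts per request

select_index = []
start_index = 0
select_token_num = 0
for cur_len in b_seq_len_numpy:
    if cur_len == 1:
        select_index.append(start_index + cur_len - 1)      # only token -> keep it
        select_token_num += 1
    else:
        select_index.append(start_index + cur_len - 2)      # second-to-last token
        select_index.append(start_index + cur_len - 1)      # last token
        select_token_num += 2
    start_index += cur_len

print(select_index, select_token_num)                        # [0, 2, 3, 4, 5] 5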
@@ -82,7 +111,9 @@ def token_forward(
         if self.world_size_ == 1:
             gather_data = logic_batch
         else:
-            gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype)
+            gather_data = torch.empty(
+                (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
+            )
             split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
             dist.all_gather(
                 [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
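Note (not part of the patch): in the tensor-parallel case each rank holds a slice of the logits along the vocab dimension, and the np.linspace boundaries decide which rows of gather_data each rank fills. A rough single-process sketch of that layout; world_size, vocab_size, and token_num are made-up values, and torch.cat stands in for the in-place dist.all_gather:

import numpy as np
import torch

world_size = 4          # assumed tensor-parallel size
vocab_size = 32000      # assumed vocabulary size
token_num = 3           # assumed number of tokens in the batch

# Rank i owns rows [split_indexes[i], split_indexes[i + 1]) of the full logits.
split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
print(split_indexes.tolist())          # [0, 8000, 16000, 24000, 32000]

# Stand-ins for each rank's partial logits of shape [vocab_shard, token_num].
shards = [
    torch.randn(int(split_indexes[i + 1] - split_indexes[i]), token_num)
    for i in range(world_size)
]

# dist.all_gather writes every rank's shard into the matching row slice of
# gather_data; concatenating along the vocab dimension reproduces that layout.
gather_data = torch.cat(shards, dim=0)
print(gather_data.shape)               # torch.Size([32000, 3])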