[Feature] Dynamic prompt cache (#356)
Something like RadixAttention.
Link: https://lmsys.org/blog/2024-01-17-sglang/

---------

Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: Wu SiYu <wu.siyu@hotmail.com>
Co-authored-by: Siyu Wu <wusiyu1@sensetime.com>
6 people authored Mar 18, 2024
1 parent e2b5168 commit f1f66fe
Showing 40 changed files with 2,374 additions and 1,147 deletions.
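The commit message only gestures at RadixAttention, so as a rough mental model (all names below are illustrative, not the code added in this commit): a dynamic prompt cache indexes previously computed KV-cache slots by token prefix, so a new request can reuse the longest cached prefix (its "ready cache") and run prefill only on the remainder. A minimal sketch of that idea:

```python
# Illustrative sketch only -- not the implementation added in this commit.
# A dynamic prompt cache maps token prefixes to already-computed KV-cache slots,
# so a new prompt only runs prefill on the part that is not cached yet.
from typing import Dict, List, Optional, Tuple


class PrefixCacheNode:
    def __init__(self) -> None:
        self.children: Dict[int, "PrefixCacheNode"] = {}  # next token id -> child node
        self.kv_index: Optional[int] = None  # KV-cache slot holding this token's K/V


class PrefixCache:
    def __init__(self) -> None:
        self.root = PrefixCacheNode()

    def match_prefix(self, tokens: List[int]) -> Tuple[int, List[int]]:
        """Return (ready_cache_len, cached KV slot indices) for the longest cached prefix."""
        node, kv_indices = self.root, []
        for tok in tokens:
            child = node.children.get(tok)
            if child is None or child.kv_index is None:
                break
            kv_indices.append(child.kv_index)
            node = child
        return len(kv_indices), kv_indices

    def insert(self, tokens: List[int], kv_indices: List[int]) -> None:
        """Record the KV slots of a finished prefill so later prompts can reuse them."""
        node = self.root
        for tok, kv_idx in zip(tokens, kv_indices):
            node = node.children.setdefault(tok, PrefixCacheNode())
            node.kv_index = kv_idx
```

Against that picture, the `b_ready_cache_len` tensors threaded through the diffs below carry the per-request length of the matched prefix, and prefill allocates and computes KV only for the remaining tokens.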
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,4 +11,4 @@ repos:
hooks:
- id: flake8
additional_dependencies: [flake8-typing-imports==1.9.0]
args: ['--config=.flake8', '--max-line-length=120', '--ignore=C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606']
args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606']
270 changes: 172 additions & 98 deletions lightllm/common/basemodel/basemodel.py

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions lightllm/common/basemodel/infer_struct.py
@@ -2,6 +2,7 @@
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager


class InferStateInfo:
"""
Information structure used at inference time
@@ -12,23 +13,26 @@ def __init__(self):
self.total_token_num = None
self.b_req_idx = None
self.b_start_loc = None
self.b_ready_cache_len = None  # used only by prefill for the prompt cache
self.b_seq_len = None
# max_len_in_batch means different things in the prefill and decode stages:
# in prefill it is the maximum, over requests, of the number of input tokens (excluding the already-cached part)
# in decode it is the maximum total sequence length over requests
self.max_len_in_batch = None
self.is_prefill = None

self.mem_manager: MemoryManager = None
self.req_manager: ReqManager = None

self.mem_is_contiguous = None
self.mem_index = None
self.mem_start = None
self.mem_start = None
self.mem_end = None
self.kv_buffer = None

self.is_splitfuse = False
self.return_all_prompt_logprobs = False
self.multimodal_params = None


def init_some_extra_state(self, model, input_ids : torch.Tensor):
def init_some_extra_state(self, model, input_ids: torch.Tensor):
pass
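To make the `max_len_in_batch` comment in `InferStateInfo` above concrete, here is a tiny example with invented values (a sketch, not code from this commit):

```python
import torch

# Hypothetical prefill batch of two requests (values invented for illustration).
b_seq_len = torch.tensor([10, 7])           # total prompt length per request
b_ready_cache_len = torch.tensor([6, 0])    # tokens already covered by the prompt cache
new_tokens = b_seq_len - b_ready_cache_len  # tokens that actually need prefill: [4, 7]

# In prefill, max_len_in_batch excludes the already-cached part ...
prefill_max_len_in_batch = int(new_tokens.max())  # 7
# ... while in decode it is the maximum total sequence length per request.
decode_max_len_in_batch = int(b_seq_len.max())    # 10
```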
23 changes: 12 additions & 11 deletions lightllm/common/basemodel/splitfuse_infer_struct.py
@@ -3,6 +3,7 @@
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager


class SplitFuseInferStateInfo:
"""
Information structure used at inference time
@@ -12,26 +13,26 @@ class SplitFuseInferStateInfo:

def __init__(self):
self.batch_size = None

self.decode_req_num = None
self.decode_total_token_num = None
self.decode_b_req_idx : torch.Tensor = None
self.decode_b_start_loc : torch.Tensor = None
self.decode_b_seq_len : torch.Tensor = None
self.decode_b_req_idx: torch.Tensor = None
self.decode_b_start_loc: torch.Tensor = None
self.decode_b_seq_len: torch.Tensor = None
self.decode_max_len_in_batch = None

self.prefill_req_num = None
self.prefill_b_req_idx : torch.Tensor = None
self.prefill_b_split_start_loc : torch.Tensor = None
self.prefill_b_split_seq_len : torch.Tensor = None
self.prefill_b_req_idx: torch.Tensor = None
self.prefill_b_split_start_loc: torch.Tensor = None
self.prefill_b_split_ready_cache_len: torch.Tensor = None
self.prefill_max_split_seq_len_in_batch = None
self.prefill_b_seq_len : torch.Tensor = None
self.prefill_b_seq_len: torch.Tensor = None

self.mem_manager: MemoryManager = None
self.req_manager: ReqManager = None

self.mem_is_contiguous = None
self.mem_start = None
self.mem_start = None
self.mem_end = None
self.mem_index = None
self.kv_buffer = None
@@ -59,6 +60,6 @@ def create_inner_decode_infer_status(self):

self.inner_decode_infer_status = infer_status
return infer_status
def init_some_extra_state(self, model, input_ids : torch.Tensor):

def init_some_extra_state(self, model, input_ids: torch.Tensor):
pass
@@ -6,9 +6,15 @@

@triton.jit
def _fwd_kernel_copy_kv_index_to_req(
req_to_token_indexs, b_req_idx, b_split_seq_len, cumsum_split_seq_len, b_seq_len, memindex,
stride_req_to_token_b, stride_req_to_token_s,
BLOCK_M: tl.constexpr
req_to_token_indexs,
b_req_idx,
b_split_seq_len,
cumsum_split_seq_len,
b_seq_len,
memindex,
stride_req_to_token_b,
stride_req_to_token_s,
BLOCK_M: tl.constexpr,
):
cur_index = tl.program_id(0)
cur_req_idx = tl.load(b_req_idx + cur_index)
@@ -18,24 +24,36 @@ def _fwd_kernel_copy_kv_index_to_req(

store_end = tl.load(b_seq_len + cur_index)
store_start = store_end - q_split_len

off_m = tl.arange(0, BLOCK_M)
for block_start in range(0, q_split_len, BLOCK_M):
read_index = tl.load(memindex + q_mem_start + block_start + off_m, mask = q_mem_start + block_start + off_m < q_mem_end, other=0)
tl.store(req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (block_start + store_start + off_m), read_index,
mask = block_start + store_start + off_m < store_end)
read_index = tl.load(
memindex + q_mem_start + block_start + off_m, mask=q_mem_start + block_start + off_m < q_mem_end, other=0
)
tl.store(
req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (block_start + store_start + off_m),
read_index,
mask=block_start + store_start + off_m < store_end,
)
return


@torch.no_grad()
def splitfuse_copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_split_seq_len, b_seq_len, memindex):
def splitfuse_copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_ready_cache_len, b_seq_len, memindex):
batch_size = b_seq_len.shape[0]
grid = (batch_size,)
num_warps = 1
b_split_seq_len = b_seq_len - b_ready_cache_len
cumsum_split_seq_len = torch.cumsum(b_split_seq_len, dim=0)
_fwd_kernel_copy_kv_index_to_req[grid](
req_to_token_indexs, b_req_idx, b_split_seq_len, cumsum_split_seq_len, b_seq_len, memindex,
req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
req_to_token_indexs,
b_req_idx,
b_split_seq_len,
cumsum_split_seq_len,
b_seq_len,
memindex,
req_to_token_indexs.stride(0),
req_to_token_indexs.stride(1),
BLOCK_M=32,
num_warps=num_warps,
num_stages=1,
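For readability, a pure-PyTorch reference of what the Triton kernel above computes (a sketch, not code from this commit): newly allocated KV slots are written into the per-request token index table, skipping the positions already covered by the prompt cache.

```python
import torch

def copy_kv_index_to_req_reference(req_to_token_indexs, b_req_idx, b_ready_cache_len, b_seq_len, memindex):
    # Tokens newly prefilled for each request (the cached prefix needs no new slots).
    b_split_seq_len = b_seq_len - b_ready_cache_len
    ends = torch.cumsum(b_split_seq_len, dim=0)  # exclusive end offsets into memindex
    starts = ends - b_split_seq_len              # start offsets into memindex
    for i in range(b_seq_len.shape[0]):
        store_start = int(b_ready_cache_len[i])  # cached prefix positions stay untouched
        store_end = int(b_seq_len[i])
        req_to_token_indexs[b_req_idx[i], store_start:store_end] = memindex[int(starts[i]) : int(ends[i])]
```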
14 changes: 10 additions & 4 deletions lightllm/common/infer_utils.py
@@ -1,9 +1,15 @@
def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, max_len_in_batch, alloc_mem_index):
def init_req_to_token_indexes(
req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
):
start_index = 0
b_seq_len_numpy = b_seq_len.cpu().numpy()
b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
b_req_idx_numpy = b_req_idx.cpu().numpy()
for i in range(len(b_seq_len)):
cur_seq_len = b_seq_len_numpy[i]
req_to_token_indexs[b_req_idx_numpy[i], 0:cur_seq_len] = alloc_mem_index[start_index:start_index + cur_seq_len]
start_index += cur_seq_len
return
cur_ready_cache_len = b_ready_cache_len_numpy[i]
req_to_token_indexs[b_req_idx_numpy[i], cur_ready_cache_len:cur_seq_len] = alloc_mem_index[
start_index : start_index + cur_seq_len - cur_ready_cache_len
]
start_index += cur_seq_len - cur_ready_cache_len
return
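The essential change here is that positions `[0, b_ready_cache_len)` of a request's row already point at cached KV slots, so only the tail receives freshly allocated indices. A hypothetical call, with invented shapes and values and assuming the module path shown in the diff:

```python
import torch
from lightllm.common.infer_utils import init_req_to_token_indexes

# Request 0 has 6 of its 10 prompt tokens in the prompt cache, request 1 has none cached,
# so 4 + 7 = 11 newly allocated KV slots are distributed across the two rows.
req_to_token_indexs = torch.zeros((8, 16), dtype=torch.long)
b_req_idx = torch.tensor([0, 1])
b_seq_len = torch.tensor([10, 7])
b_ready_cache_len = torch.tensor([6, 0])
alloc_mem_index = torch.arange(100, 111)

init_req_to_token_indexes(
    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len,
    max_len_in_batch=10, alloc_mem_index=alloc_mem_index,
)
# Row 0, columns 6..9 now hold 100..103; row 1, columns 0..6 hold 104..110.
```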
47 changes: 25 additions & 22 deletions lightllm/common/mem_manager.py
@@ -2,59 +2,63 @@
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False):
self.size = size
self.size = size
self.dtype = dtype
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.always_copy = always_copy

# mem_state now uses reference counting, which makes it easier to add token sharing later (e.g. for beam search)
self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda")
self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
self.can_use_mem_size = size
self._init_buffers(size, dtype, head_num, head_dim, layer_num)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]

self.kv_buffer = [
torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
]

def _free_buffers(self):
self.kv_buffer = None

@torch.no_grad()
def alloc(self, need_size):
if need_size > self.can_use_mem_size:
logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
logger.warn(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
return None
can_use_index = torch.nonzero(self.mem_state == 0).view(-1)
select_index = can_use_index[0 : need_size]
select_index = can_use_index[0:need_size]
self.add_refs(select_index)
return select_index

@torch.no_grad()
def alloc_contiguous(self, need_size):
if self.always_copy:
return None
if need_size > self.can_use_mem_size:
logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
logger.warn(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
return None

can_use_index = torch.nonzero(self.mem_state == 0).view(-1)
can_use_index_size = len(can_use_index)
can_use_index = can_use_index[0 : can_use_index_size - need_size + 1][(can_use_index[need_size - 1: ] - can_use_index[0 : can_use_index_size - need_size + 1]) == need_size - 1]
can_use_index = can_use_index[0 : can_use_index_size - need_size + 1][
(can_use_index[need_size - 1 :] - can_use_index[0 : can_use_index_size - need_size + 1]) == need_size - 1
]
if can_use_index.shape[0] == 0:
# logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
return None
start = can_use_index[0].item()
end = start + need_size
select_index = self.indexes[start : end]
select_index = self.indexes[start:end]
self.add_refs(select_index)
return select_index, start, end

@torch.no_grad()
def free(self, free_index):
"""_summary_
@@ -67,7 +71,7 @@ def free(self, free_index):
if self.can_use_mem_size == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
return

@torch.no_grad()
def add_refs(self, token_index: torch.Tensor):
state = self.mem_state[token_index]
@@ -76,33 +80,32 @@ def add_refs(self, token_index: torch.Tensor):
self.can_use_mem_size -= all_tokens - has_used_tokens
self.mem_state[token_index] += 1
return

@torch.no_grad()
def decrease_refs(self, token_index: torch.Tensor):
self.mem_state[token_index] -= 1
token_index, counts = token_index.unique(return_counts=True)
self.mem_state[token_index] -= counts
state = self.mem_state[token_index]
used_tokens = torch.count_nonzero(state).item()
all_tokens = len(state)
self.can_use_mem_size += all_tokens - used_tokens
return


@torch.no_grad()
def free_all(self):
self.can_use_mem_size = len(self.mem_state)
self.mem_state[:] = 0

@torch.no_grad()
def resize_mem(self, new_size):
"""
just for test code
"""
size = new_size
size = new_size
dtype = self.dtype
head_num = self.head_num
head_dim = self.head_dim
layer_num = self.layer_num
always_copy = self.always_copy

self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda")
self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
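One behavioral fix worth calling out in `decrease_refs`: switching to `unique(return_counts=True)` means a KV slot that appears several times in `token_index` has its reference count decremented once per occurrence. A small standalone illustration (values invented):

```python
import torch

mem_state = torch.tensor([2, 1, 0, 3])   # per-slot reference counts
token_index = torch.tensor([0, 0, 3])    # slot 0 is released twice, slot 3 once

# A naive `mem_state[token_index] -= 1` would decrement slot 0 only once,
# because repeated indices in an indexed assignment do not accumulate.
idx, counts = token_index.unique(return_counts=True)
mem_state[idx] -= counts
print(mem_state)  # tensor([0, 1, 0, 2])
```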
2 changes: 1 addition & 1 deletion lightllm/models/bloom/layer_infer/post_layer_infer.py
@@ -31,7 +31,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
batch_size = infer_state.batch_size
last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
if infer_state.is_prefill:
last_index = torch.cumsum(infer_state.b_seq_len, dim=0, dtype=torch.long) - 1
last_index = torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
last_input[:, :] = input_embdings[last_index, :]
else:
last_input[:, :] = input_embdings[-batch_size:, :]
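The `last_index` change follows from the prompt cache: during prefill, `input_embdings` now contains only the non-cached tokens, so each request's last token sits at `cumsum(b_seq_len - b_ready_cache_len) - 1`. A quick check with invented values:

```python
import torch

b_seq_len = torch.tensor([10, 7])
b_ready_cache_len = torch.tensor([6, 0])
new_tokens = b_seq_len - b_ready_cache_len  # rows actually present in input_embdings: [4, 7]

last_index = torch.cumsum(new_tokens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([3, 10]) -> request 0 ends at row 3, request 1 at row 10
```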
41 changes: 31 additions & 10 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -63,16 +63,37 @@ def _context_attention_kernel(
self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
layer_weight.tp_alibi,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)
import triton
if triton.__version__ >= "2.1.0":
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_req_idx,
layer_weight.tp_alibi,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
)
elif triton.__version__ == "2.0.0":
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
layer_weight.tp_alibi,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)
else:
assert False

return o_tensor

def _token_attention_kernel(
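The triton >= 2.1.0 branch no longer reads K/V from the batch-local `kv` tensor; it addresses the global `kv_buffer` through `req_to_token_indexs`, which is what lets cached prompt tokens (positions below `b_ready_cache_len`) be reused without recomputation. A simplified gather, sketching only the addressing (hypothetical helper, not the kernel itself):

```python
import torch

def gather_kv_for_request(kv_buffer, req_to_token_indexs, req_idx, seq_len):
    # kv_buffer: (pool_size, 2 * head_num, head_dim) for one layer;
    # req_to_token_indexs maps each token position of a request to its KV slot.
    token_slots = req_to_token_indexs[req_idx, :seq_len]
    return kv_buffer[token_slots]  # (seq_len, 2 * head_num, head_dim)
```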
