Added evict in decode + general cleanup and refactor.
raikonenfnu committed Jan 4, 2024
1 parent 73de744 commit b81e32f
Showing 4 changed files with 120 additions and 85 deletions.
12 changes: 12 additions & 0 deletions python/shark_turbine/aot/compiled_module.py
@@ -334,6 +334,7 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
continue
del_attr_keys.add(key)
info.def_attribute(key, value)

for key in del_attr_keys:
del dct[key]

@@ -343,6 +344,17 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
if key not in dct:
dct[key] = _blackhole_instance_attribute

# Inherit methods, globals, and exports from the parent class,
# e.g. for building a child class of StatelessLlama.
for base in bases:
if base is CompiledModule:
continue
base_exports = _all_compiled_module_class_infos[base].all_exports
for export_name in base_exports:
if export_name in info.all_exports:
continue
info.all_exports[export_name] = base_exports[export_name]

# Finish construction.
new_class = type.__new__(mcls, name, bases, dct)
_all_compiled_module_class_infos[new_class] = info
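
The loop added above makes a subclass of a CompiledModule inherit its parent's exports unless it defines its own with the same name; this is what lets the StreamingStateUpdateModule subclass introduced in stateless_llama.py further down reuse StateUpdateModule's exported functions. A minimal, self-contained sketch of that lookup order, with plain dicts and illustrative names standing in for the real per-class info bookkeeping:

_class_exports = {}

def register_exports(cls, bases, own_exports):
    # Child definitions win; parents only fill in exports the child lacks.
    exports = dict(own_exports)
    for base in bases:
        for name, fn in _class_exports.get(base, {}).items():
            exports.setdefault(name, fn)
    _class_exports[cls] = exports
    return exports

class Parent: ...
class Child(Parent): ...

register_exports(Parent, (), {"run_forward": lambda self: "parent forward"})
inherited = register_exports(Child, (Parent,), {"evict_kvcache_space": lambda self: "child evict"})
assert set(inherited) == {"run_forward", "evict_kvcache_space"}
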
44 changes: 22 additions & 22 deletions python/turbine_models/custom_models/llm_app.py
@@ -40,7 +40,7 @@
help="local-sync, local-task, cuda, vulkan, rocm",
)
parser.add_argument(
"--init_cache",
"--streaming_llm",
type=bool,
default=False,
help="Use KV-Cache in between user prompts/multi-dialogue.",
@@ -65,48 +65,50 @@ def append_user_prompt(history, input_prompt):


def append_bot_prompt(history, input_prompt):
user_prompt = f"{B_SYS} {input_prompt} {E_SYS}"
user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}"
history += user_prompt
return history


class SharkLLM(object):
def __init__(self, device, vmfb_path, external_weight_path, init_cache=False):
def __init__(self, device, vmfb_path, external_weight_path, streaming_llm=False):
self.runner = vmfbRunner(
device=device,
vmfb_path=vmfb_path,
external_weight_path=external_weight_path,
)
if streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
self.first_input = True
self.num_tokens = 0
self.last_prompt = None
self.init_cache = init_cache
self.streaming_llm = streaming_llm
self.prev_token_len = 0

def format_out(self, results):
return torch.tensor(results.to_host()[0][0])

def evict_kvcache_space(self):
self.runner.ctx.modules.state_update["evict_kvcache_space"]()
self.model["evict_kvcache_space"]()

def generate(self, input_ids):
# TODO: Replace with args.
if self.init_cache and self.runner.ctx.modules.state_update["get_seq_step"]() > 600:
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.runner.ctx.modules.state_update["evict_kvcache_space"]()
self.model["evict_kvcache_space"]()
turbine_results = []
# Only the not-yet-seen tokens are needed when streaming,
# because results for earlier tokens are already stored in the KV cache.
token_len = input_ids.shape[-1]
if self.init_cache:
if self.streaming_llm:
token_slice = max(self.prev_token_len - 1, 0)
input_ids = input_ids[:, token_slice:]
inputs = [ireert.asdevicearray(self.runner.config.device, input_ids)]
if self.first_input or not self.init_cache:
if self.first_input or not self.streaming_llm:
s = time.time()
results = self.runner.ctx.modules.state_update["run_initialize"](
*inputs
) # example_input_id
results = self.model["run_initialize"](*inputs) # example_input_id
e = time.time()
print(
f"num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}"
@@ -115,20 +117,18 @@ def generate(self, input_ids):
self.first_input = False
else:
s = time.time()
results = self.runner.ctx.modules.state_update["run_cached_initialize"](
*inputs
) # example_input_id
results = self.model["run_cached_initialize"](*inputs) # example_input_id
e = time.time()
print(
f"Cached num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}"
)
token_len += 1
s = time.time()
while self.format_out(results) != 2:
if self.init_cache and self.runner.ctx.modules.state_update["get_seq_step"]() > 600:
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.runner.ctx.modules.state_update["evict_kvcache_space"]()
results = self.runner.ctx.modules.state_update["run_forward"](results)
self.model["evict_kvcache_space"]()
results = self.model["run_forward"](results)
# uncomment to see tokens as they are emitted
# print(f"turbine: {tokenizer.decode(self.format_out(results))}")
turbine_results.append(self.format_out(results))
@@ -148,7 +148,7 @@ def run_llm(
hf_model_name,
hf_auth_token,
external_weight_path,
init_cache,
streaming_llm,
):
runner = vmfbRunner(
device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path
@@ -162,7 +162,7 @@
device=device,
vmfb_path=vmfb_path,
external_weight_path=external_weight_path,
init_cache=init_cache,
streaming_llm=streaming_llm,
)
prompt = system_prompt
while True:
@@ -171,7 +171,7 @@
initial_input = tokenizer(prompt, return_tensors="pt")
example_input_id = initial_input.input_ids
result = llm.generate(example_input_id)
bot_response = tokenizer.decode(result)
bot_response = tokenizer.decode(result, skip_special_tokens=True)
print(f"\nBOT: {bot_response}\n")
prompt = append_bot_prompt(prompt, bot_response)

@@ -186,5 +186,5 @@ def run_llm(
args.hf_model_name,
args.hf_auth_token,
args.external_weight_path,
args.init_cache,
args.streaming_llm,
)
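
As a standalone illustration of the slicing generate() does when streaming_llm is enabled (token values and lengths below are hypothetical): the full chat history is re-tokenized on every turn, but only the tail starting one position before prev_token_len is handed to the model, since earlier tokens already live in the KV cache.

import torch

input_ids = torch.tensor([[1, 529, 3924, 29958, 306, 763, 372]])  # 7-token history (made-up ids)
prev_token_len = 5                                                 # tokens already in the KV cache

token_slice = max(prev_token_len - 1, 0)   # back up one token, as generate() does
new_ids = input_ids[:, token_slice:]       # only this tail reaches run_cached_initialize
assert new_ids.shape == (1, 3)
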
@@ -91,7 +91,9 @@ def llama_pos_shift_attention_forward(

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

### Shift Pos: key pos is the pos in cache
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
145 changes: 83 additions & 62 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -8,7 +8,9 @@
from torch.utils import _pytree as pytree
from shark_turbine.aot import *
from iree.compiler.ir import Context
from llm_optimizations.streaming_llm.modify_llama import enable_llama_pos_shift_attention
from llm_optimizations.streaming_llm.modify_llama import (
enable_llama_pos_shift_attention,
)

from turbine_models.custom_models import remap_gguf
import safetensors
@@ -41,7 +43,6 @@
parser.add_argument(
"--precision", type=str, default="fp16", help="dtype of model [f16, f32]"
)

parser.add_argument(
"--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm"
)
@@ -53,6 +54,12 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument(
"--streaming_llm",
type=bool,
default=False,
help="Compile LLM with StreamingLLM optimizations",
)

# TODO (Dan): replace this with a file once I figure out paths on windows exe
json_schema = """
@@ -86,6 +93,7 @@ def export_transformer_model(
device=None,
target_triple=None,
vulkan_max_allocation=None,
streaming_llm=False,
):
state_schema = pytree.treespec_loads(json_schema)

@@ -134,7 +142,7 @@ class StateUpdateModule(CompiledModule):
else:
params = export_parameters(mod)
global_state = export_global(
abstractify(global_pkv), uninitialized=True, mutable=True
abstractify(global_pkv), uninitialized=False, mutable=True
)
global_seq_step = export_global(AbstractIndex, mutable=True)

@@ -153,57 +161,6 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
)
return token

def run_cached_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
state_arg = slice_up_to_step(
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
)
forw_const = (
[x.dynamic_dim(1) < MAX_STEP_SEQ]
+ [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
+ [
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
for x in state_arg[1:]
]
+ [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
)
token, *state = self.cached_initialize(x, *state_arg, constraints=forw_const)
len_of_new_tokens = IREE.tensor_dim(
state[0], 1
)  # seq-len dimension of the (arbitrarily chosen) 0th KV tensor
for i in range(HEADS * 2):
slice_of_state = IREE.tensor_reshape(
state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
)
self.global_state = IREE.tensor_update(
self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0
)
self.global_seq_step = self.global_seq_step + len_of_new_tokens
return token

# Streaming-LLM KVCache evict algorithm:
# slice1 = KVCache[0 : sink]
# slice2 = KVCache[seq_len - window_size : seq_len]
# KVCache = torch.cat([slice1, slice2])
# TODO: The source and destination slices can overlap.
# E.g. at token length 600 with sink size 4 and window size 508,
# KVCache[4:512] is replaced by KVCache[600-508 : (600-508)+508],
# i.e. KVCache[4:512] = KVCache[92:600], which overlaps heavily (92..512).
# We'd need to copy first and then replace, or keep the gap at least 2x the window.
def evict_kvcache_space(self):
# TODO: Replace hardcoded with global variable.
sink_size = 4
window_size = 252
most_recent_window = self.global_seq_step + (-window_size)
for i in range(HEADS * 2):
update_window_state = IREE.tensor_slice(
self.global_state, i, 0, (most_recent_window, window_size), (0, HEADS), (0, HIDDEN_DIM)
) # sequence context dim
self.global_state = IREE.tensor_update(
self.global_state, update_window_state, i, 0, sink_size, 0, 0
)
self.global_seq_step = self.global_seq_step.set(window_size + sink_size)
return self.global_seq_step

def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
state_arg = slice_up_to_step(
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
@@ -244,32 +201,95 @@ def initialize(input_ids):
return token1, *state1_flat

@jittable
def cached_initialize(input_ids, *state0_flat):
def forward(token0: torch.Tensor, *state0_flat):
# Unpad the states.
cur_token_len = state0_flat[0].size(1)
state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat]
state0 = pytree.tree_unflatten(state0_flat, state_schema)
result = mod.forward(input_ids, past_key_values=state0)
result = mod.forward(token0, past_key_values=state0)
state1_flat, _ = pytree.tree_flatten(result.past_key_values)
state1_flat = [torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat]
state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat]
token1 = torch.argmax(result.logits[:, -1, :], dim=1)
token1 = token1[None, :]
return token1, *state1_flat

class StreamingStateUpdateModule(StateUpdateModule):
def run_cached_initialize(
self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
):
state_arg = slice_up_to_step(
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
)
forw_const = (
[x.dynamic_dim(1) < MAX_STEP_SEQ]
+ [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
+ [
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
for x in state_arg[1:]
]
+ [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
)
token, *state = self.cached_initialize(
x, *state_arg, constraints=forw_const
)
len_of_new_tokens = IREE.tensor_dim(
state[0], 1
)  # seq-len dimension of the (arbitrarily chosen) 0th KV tensor
for i in range(HEADS * 2):
slice_of_state = IREE.tensor_reshape(
state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
)
self.global_state = IREE.tensor_update(
self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0
)
self.global_seq_step = self.global_seq_step + len_of_new_tokens
return token

@jittable
def forward(token0: torch.Tensor, *state0_flat):
def cached_initialize(input_ids, *state0_flat):
# Unpad the states.
cur_token_len = state0_flat[0].size(1)
state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat]
state0 = pytree.tree_unflatten(state0_flat, state_schema)
result = mod.forward(token0, past_key_values=state0)
result = mod.forward(input_ids, past_key_values=state0)
state1_flat, _ = pytree.tree_flatten(result.past_key_values)
state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat]
state1_flat = [
torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat
]
token1 = torch.argmax(result.logits[:, -1, :], dim=1)
token1 = token1[None, :]
return token1, *state1_flat
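
A small shape sketch of the cached_initialize path above (head count and hidden size are illustrative, not the model's real values): the forward call returns key/value states covering both the cached and the new tokens, and only the new positions are kept, transposed back into the [batch, seq, heads, hidden] layout that run_cached_initialize writes into global_state.

import torch

cur_token_len, new_tokens, heads, hidden = 5, 3, 32, 128        # illustrative sizes
kv = torch.zeros(1, heads, cur_token_len + new_tokens, hidden)  # layout returned by the HF forward pass
delta = torch.transpose(kv[:, :, cur_token_len:, :], 1, 2)      # what state1_flat keeps
assert delta.shape == (1, new_tokens, heads, hidden)
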

# Streaming-LLM KVCache evict algorithm:
# slice1 = KVCache[0 : sink]
# slice2 = KVCache[seq_len - window_size : seq_len]
# KVCache = torch.cat([slice1, slice2])
# TODO: Add a copy/move to handle overlapping source and destination data.
def evict_kvcache_space(self):
# TODO: Replace hardcoded with global variable.
sink_size = 4
window_size = 252
most_recent_window = self.global_seq_step + (-window_size)
for i in range(HEADS * 2):
update_window_state = IREE.tensor_slice(
self.global_state,
i,
0,
(most_recent_window, window_size),
(0, HEADS),
(0, HIDDEN_DIM),
) # sequence context dim
self.global_state = IREE.tensor_update(
self.global_state, update_window_state, i, 0, sink_size, 0, 0
)
self.global_seq_step = self.global_seq_step.set(window_size + sink_size)
return self.global_seq_step
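
The eviction itself is the sink-plus-recent-window scheme described in the comment above: keep the first sink_size positions and the most recent window_size positions, drop everything in between, and reset the sequence step to their sum (llm_app.py triggers this once get_seq_step() exceeds 600). A standalone PyTorch rendering on a single stand-in KV tensor, with the hard-coded sink_size=4 and window_size=252 and a hypothetical sequence length:

import torch

seq_len, sink_size, window_size = 600, 4, 252
kv = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1)  # stand-in [seq, dim] KV tensor
sink = kv[:sink_size]                    # slice1 = KVCache[0 : sink]
recent = kv[seq_len - window_size:]      # slice2 = KVCache[seq_len - window_size : seq_len]
kv = torch.cat([sink, recent], dim=0)    # KVCache = torch.cat([slice1, slice2])
assert kv.shape[0] == sink_size + window_size  # new seq step = 256
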

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = StateUpdateModule(context=Context(), import_to=import_to)
if streaming_llm:
print("Compiling with Streaming LLM")
inst = StreamingStateUpdateModule(context=Context(), import_to=import_to)
else:
inst = StateUpdateModule(context=Context(), import_to=import_to)
# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if quantization == "int4" and not compile_to == "linalg":
@@ -353,6 +373,7 @@ def forward(token0: torch.Tensor, *state0_flat):
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.streaming_llm,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
