From b81e32fe5dda3a1ae9a570f5e18dc6ba3d6d8b4c Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Thu, 4 Jan 2024 01:08:14 -0800 Subject: [PATCH] Added evict in decode + general cleanup and refactor. --- python/shark_turbine/aot/compiled_module.py | 12 ++ .../turbine_models/custom_models/llm_app.py | 44 +++--- .../streaming_llm/modify_llama.py | 4 +- .../custom_models/stateless_llama.py | 145 ++++++++++-------- 4 files changed, 120 insertions(+), 85 deletions(-) diff --git a/python/shark_turbine/aot/compiled_module.py b/python/shark_turbine/aot/compiled_module.py index 9808ffeb4..a190159e4 100644 --- a/python/shark_turbine/aot/compiled_module.py +++ b/python/shark_turbine/aot/compiled_module.py @@ -334,6 +334,7 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): continue del_attr_keys.add(key) info.def_attribute(key, value) + for key in del_attr_keys: del dct[key] @@ -343,6 +344,17 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): if key not in dct: dct[key] = _blackhole_instance_attribute + # Inheritting methods, globals, and export from parent class. + # Use case such as building a child-class to StatelessLlama. + for base in bases: + if base is CompiledModule: + continue + base_exports = _all_compiled_module_class_infos[base].all_exports + for export_name in base_exports: + if export_name in info.all_exports: + continue + info.all_exports[export_name] = base_exports[export_name] + # Finish construction. new_class = type.__new__(mcls, name, bases, dct) _all_compiled_module_class_infos[new_class] = info diff --git a/python/turbine_models/custom_models/llm_app.py b/python/turbine_models/custom_models/llm_app.py index f51877fa1..447454f1e 100644 --- a/python/turbine_models/custom_models/llm_app.py +++ b/python/turbine_models/custom_models/llm_app.py @@ -40,7 +40,7 @@ help="local-sync, local-task, cuda, vulkan, rocm", ) parser.add_argument( - "--init_cache", + "--streaming_llm", type=bool, default=False, help="Use KV-Cache in between user prompts/multi-dialogue.", @@ -65,48 +65,50 @@ def append_user_prompt(history, input_prompt): def append_bot_prompt(history, input_prompt): - user_prompt = f"{B_SYS} {input_prompt} {E_SYS}" + user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" history += user_prompt return history class SharkLLM(object): - def __init__(self, device, vmfb_path, external_weight_path, init_cache=False): + def __init__(self, device, vmfb_path, external_weight_path, streaming_llm=False): self.runner = vmfbRunner( device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path, ) + if streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update self.first_input = True self.num_tokens = 0 self.last_prompt = None - self.init_cache = init_cache + self.streaming_llm = streaming_llm self.prev_token_len = 0 def format_out(self, results): return torch.tensor(results.to_host()[0][0]) def evict_kvcache_space(self): - self.runner.ctx.modules.state_update["evict_kvcache_space"]() + self.model["evict_kvcache_space"]() def generate(self, input_ids): # TODO: Replace with args. 
- if self.init_cache and self.runner.ctx.modules.state_update["get_seq_step"]() > 600: + if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") - self.runner.ctx.modules.state_update["evict_kvcache_space"]() + self.model["evict_kvcache_space"]() turbine_results = [] # Only need not seen token for init cache # Because we have stored the res in KV-cache. token_len = input_ids.shape[-1] - if self.init_cache: + if self.streaming_llm: token_slice = max(self.prev_token_len - 1, 0) input_ids = input_ids[:, token_slice:] inputs = [ireert.asdevicearray(self.runner.config.device, input_ids)] - if self.first_input or not self.init_cache: + if self.first_input or not self.streaming_llm: s = time.time() - results = self.runner.ctx.modules.state_update["run_initialize"]( - *inputs - ) # example_input_id + results = self.model["run_initialize"](*inputs) # example_input_id e = time.time() print( f"num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}" @@ -115,9 +117,7 @@ def generate(self, input_ids): self.first_input = False else: s = time.time() - results = self.runner.ctx.modules.state_update["run_cached_initialize"]( - *inputs - ) # example_input_id + results = self.model["run_cached_initialize"](*inputs) # example_input_id e = time.time() print( f"Cached num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}" @@ -125,10 +125,10 @@ def generate(self, input_ids): token_len += 1 s = time.time() while self.format_out(results) != 2: - if self.init_cache and self.runner.ctx.modules.state_update["get_seq_step"]() > 600: + if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") - self.runner.ctx.modules.state_update["evict_kvcache_space"]() - results = self.runner.ctx.modules.state_update["run_forward"](results) + self.model["evict_kvcache_space"]() + results = self.model["run_forward"](results) # uncomment to see tokens as they are emitted # print(f"turbine: {tokenizer.decode(self.format_out(results))}") turbine_results.append(self.format_out(results)) @@ -148,7 +148,7 @@ def run_llm( hf_model_name, hf_auth_token, external_weight_path, - init_cache, + streaming_llm, ): runner = vmfbRunner( device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path @@ -162,7 +162,7 @@ def run_llm( device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path, - init_cache=init_cache, + streaming_llm=streaming_llm, ) prompt = system_prompt while True: @@ -171,7 +171,7 @@ def run_llm( initial_input = tokenizer(prompt, return_tensors="pt") example_input_id = initial_input.input_ids result = llm.generate(example_input_id) - bot_response = tokenizer.decode(result) + bot_response = tokenizer.decode(result, skip_special_tokens=True) print(f"\nBOT: {bot_response}\n") prompt = append_bot_prompt(prompt, bot_response) @@ -186,5 +186,5 @@ def run_llm( args.hf_model_name, args.hf_auth_token, args.external_weight_path, - args.init_cache, + args.streaming_llm, ) diff --git a/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py index 7e89a4cd4..a496b461f 100644 --- a/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py +++ b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py @@ -91,7 +91,9 @@ def llama_pos_shift_attention_forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE 
models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) ### Shift Pos: key pos is the pos in cache key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 198f23ae7..0b5c2f50c 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -8,7 +8,9 @@ from torch.utils import _pytree as pytree from shark_turbine.aot import * from iree.compiler.ir import Context -from llm_optimizations.streaming_llm.modify_llama import enable_llama_pos_shift_attention +from llm_optimizations.streaming_llm.modify_llama import ( + enable_llama_pos_shift_attention, +) from turbine_models.custom_models import remap_gguf import safetensors @@ -41,7 +43,6 @@ parser.add_argument( "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) - parser.add_argument( "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" ) @@ -53,6 +54,12 @@ help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") +parser.add_argument( + "--streaming_llm", + type=bool, + default=False, + help="Compile LLM with StreamingLLM optimizations", +) # TODO (Dan): replace this with a file once I figure out paths on windows exe json_schema = """ @@ -86,6 +93,7 @@ def export_transformer_model( device=None, target_triple=None, vulkan_max_allocation=None, + streaming_llm=False, ): state_schema = pytree.treespec_loads(json_schema) @@ -134,7 +142,7 @@ class StateUpdateModule(CompiledModule): else: params = export_parameters(mod) global_state = export_global( - abstractify(global_pkv), uninitialized=True, mutable=True + abstractify(global_pkv), uninitialized=False, mutable=True ) global_seq_step = export_global(AbstractIndex, mutable=True) @@ -153,57 +161,6 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): ) return token - def run_cached_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): - state_arg = slice_up_to_step( - self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM - ) - forw_const = ( - [x.dynamic_dim(1) < MAX_STEP_SEQ] - + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] - + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) - for x in state_arg[1:] - ] - + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] - ) - token, *state = self.cached_initialize(x, *state_arg, constraints=forw_const) - len_of_new_tokens = IREE.tensor_dim( - state[0], 1 - ) # ? dimension of arbitrarily 0th kv tensor - for i in range(HEADS * 2): - slice_of_state = IREE.tensor_reshape( - state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM - ) - self.global_state = IREE.tensor_update( - self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0 - ) - self.global_seq_step = self.global_seq_step + len_of_new_tokens - return token - - # Streaming-LLM KVCache evict algorithm: - # slice1 = KVCache[0 : sink] - # slice2 = KVCache[seq_len - window_size : seq_len] - # KVCache = torch.cat([slice1, slice2]) - # TODO: There is actual overlap of data. 
- # For e.g at token length 600, sink size 4, and window size 508 - # Then KVCache[4:512] going to be replaced by KVCache[600-508: (600-508)+508] - # => KVCache[4:512] = KVCache[92:600] => Much overlap of data(i.e 92->512) - # => We'd need to do a copy and then replace. Or we can make the gap at least 2X. - def evict_kvcache_space(self): - # TODO: Replace hardcoded with global variable. - sink_size = 4 - window_size = 252 - most_recent_window = self.global_seq_step + (-window_size) - for i in range(HEADS * 2): - update_window_state = IREE.tensor_slice( - self.global_state, i, 0, (most_recent_window, window_size), (0, HEADS), (0, HIDDEN_DIM) - ) # sequence context dim - self.global_state = IREE.tensor_update( - self.global_state, update_window_state, i, 0, sink_size, 0, 0 - ) - self.global_seq_step = self.global_seq_step.set(window_size + sink_size) - return self.global_seq_step - def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): state_arg = slice_up_to_step( self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM @@ -244,32 +201,95 @@ def initialize(input_ids): return token1, *state1_flat @jittable - def cached_initialize(input_ids, *state0_flat): + def forward(token0: torch.Tensor, *state0_flat): # Unpad the states. - cur_token_len = state0_flat[0].size(1) state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(input_ids, past_key_values=state0) + result = mod.forward(token0, past_key_values=state0) state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat] + state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat] token1 = torch.argmax(result.logits[:, -1, :], dim=1) token1 = token1[None, :] return token1, *state1_flat + class StreamingStateUpdateModule(StateUpdateModule): + def run_cached_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + state_arg = slice_up_to_step( + self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM + ) + forw_const = ( + [x.dynamic_dim(1) < MAX_STEP_SEQ] + + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] + ) + token, *state = self.cached_initialize( + x, *state_arg, constraints=forw_const + ) + len_of_new_tokens = IREE.tensor_dim( + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor + for i in range(HEADS * 2): + slice_of_state = IREE.tensor_reshape( + state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_state = IREE.tensor_update( + self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0 + ) + self.global_seq_step = self.global_seq_step + len_of_new_tokens + return token + @jittable - def forward(token0: torch.Tensor, *state0_flat): + def cached_initialize(input_ids, *state0_flat): # Unpad the states. 
+ cur_token_len = state0_flat[0].size(1) state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(token0, past_key_values=state0) + result = mod.forward(input_ids, past_key_values=state0) state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat] + state1_flat = [ + torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat + ] token1 = torch.argmax(result.logits[:, -1, :], dim=1) token1 = token1[None, :] return token1, *state1_flat + # Streaming-LLM KVCache evict algorithm: + # slice1 = KVCache[0 : sink] + # slice2 = KVCache[seq_len - window_size : seq_len] + # KVCache = torch.cat([slice1, slice2]) + # TODO: Add move to handle overlap of data. + def evict_kvcache_space(self): + # TODO: Replace hardcoded with global variable. + sink_size = 4 + window_size = 252 + most_recent_window = self.global_seq_step + (-window_size) + for i in range(HEADS * 2): + update_window_state = IREE.tensor_slice( + self.global_state, + i, + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_state = IREE.tensor_update( + self.global_state, update_window_state, i, 0, sink_size, 0, 0 + ) + self.global_seq_step = self.global_seq_step.set(window_size + sink_size) + return self.global_seq_step + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = StateUpdateModule(context=Context(), import_to=import_to) + if streaming_llm: + print("Compiling with Streaming LLM") + inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) + else: + inst = StateUpdateModule(context=Context(), import_to=import_to) # TODO: Integrate with external parameters to actually be able to run # TODO: Make more generalizable to be able to quantize with all compile_to options if quantization == "int4" and not compile_to == "linalg": @@ -353,6 +373,7 @@ def forward(token0: torch.Tensor, *state0_flat): args.device, args.iree_target_triple, args.vulkan_max_allocation, + args.streaming_llm, ) safe_name = args.hf_model_name.split("/")[-1].strip() safe_name = re.sub("-", "_", safe_name)
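
Note on the compiled_module.py change: the metaclass now copies the parent class's exports into the child class's info, so a subclass of an existing CompiledModule (e.g. StreamingStateUpdateModule subclassing StateUpdateModule above) keeps run_initialize/run_forward/get_seq_step while only adding run_cached_initialize and evict_kvcache_space. The toy metaclass below illustrates just that merge with plain Python dicts; it is not the real CompiledModule machinery, and ToyExportMeta/_class_infos are made-up names.

# Toy illustration of the export inheritance added in this patch; the real metaclass
# tracks ExportProcDef/global info, not bare callables.
_class_infos = {}


class ToyExportMeta(type):
    def __new__(mcls, name, bases, dct):
        exports = {k: v for k, v in dct.items() if callable(v) and not k.startswith("_")}
        # Mirror the new loop over `bases`: inherit parent exports not already defined.
        for base in bases:
            for key, value in _class_infos.get(base, {}).items():
                exports.setdefault(key, value)
        cls = super().__new__(mcls, name, bases, dct)
        _class_infos[cls] = exports
        return cls


class Parent(metaclass=ToyExportMeta):
    def run_initialize(self):
        return "init"

    def run_forward(self):
        return "forward"


class Child(Parent):
    def evict_kvcache_space(self):
        return "evict"


assert set(_class_infos[Child]) == {"run_initialize", "run_forward", "evict_kvcache_space"}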
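
Note on SharkLLM.generate() in streaming mode: after the first turn, only the tokens past max(prev_token_len - 1, 0) are fed to run_cached_initialize, decode then loops run_forward until the Llama-2 EOS id 2, and the cache is evicted whenever get_seq_step() passes the hard-coded 600. The sketch below mirrors only that control flow; StubModel is a hypothetical stand-in for the compiled streaming_state_update module, and the token ids it returns are fake.

class StubModel:
    """Hypothetical stand-in for runner.ctx.modules.streaming_state_update."""

    def __init__(self):
        self.seq_step = 0

    def get_seq_step(self):
        return self.seq_step

    def evict_kvcache_space(self):
        self.seq_step = 4 + 252  # sink_size + window_size, as in the patch

    def run_cached_initialize(self, new_tokens):
        self.seq_step += len(new_tokens)
        return 7  # fake next-token id

    def run_forward(self, token):
        self.seq_step += 1
        return 2 if self.seq_step >= 20 else 7  # eventually emit EOS (id 2)


def generate_streaming(model, input_ids, prev_token_len):
    # Feed only the tokens the cache has not seen (mirrors token_slice in generate()).
    new_tokens = input_ids[max(prev_token_len - 1, 0):]
    out = [model.run_cached_initialize(new_tokens)]
    while out[-1] != 2:  # 2 == EOS, same stop condition as the patch
        if model.get_seq_step() > 600:  # same eviction threshold as the patch
            model.evict_kvcache_space()
        out.append(model.run_forward(out[-1]))
    return out


print(generate_streaming(StubModel(), list(range(12)), prev_token_len=0))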
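
Note on evict_kvcache_space in StreamingStateUpdateModule: it keeps the first sink_size positions as an attention sink plus the most recent window_size positions, compacts them to the front of the cache, and resets global_seq_step to sink_size + window_size (4 + 252 = 256). Below is a minimal PyTorch sketch of that policy, assuming the [HEADS * 2, 1, seq, HEADS, HIDDEN_DIM] layout of global_pkv; it illustrates the algorithm only, not the IREE.tensor_slice/tensor_update lowering in the patch, and the demo shapes are made up.

import torch


def evict_kvcache_space_sketch(kv_cache, seq_step, sink_size=4, window_size=252):
    # Keep positions [0:sink_size] plus [seq_step - window_size : seq_step], packed
    # together at the front of the sequence dimension (dim 2).
    recent = kv_cache[:, :, seq_step - window_size : seq_step].clone()
    # .clone() sidesteps the source/destination overlap noted in the patch's TODO.
    kv_cache[:, :, sink_size : sink_size + window_size] = recent
    return sink_size + window_size  # new value for global_seq_step


if __name__ == "__main__":
    cache = torch.zeros(8, 1, 1024, 4, 16)  # small made-up shapes
    new_step = evict_kvcache_space_sketch(cache, seq_step=600)
    assert new_step == 256  # 4 sink tokens + the 252 most recent tokens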
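
One note on the new --streaming_llm flags in llm_app.py and stateless_llama.py: argparse with type=bool converts the raw string, so any non-empty value (including --streaming_llm=False) parses as True. If an explicit on/off switch is wanted, the usual pattern is action="store_true"; a minimal sketch, offered as a suggestion rather than part of the patch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--streaming_llm",
    action="store_true",
    default=False,
    help="Compile LLM with StreamingLLM optimizations",
)

assert parser.parse_args(["--streaming_llm"]).streaming_llm is True
assert parser.parse_args([]).streaming_llm is False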