diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 169f01012..ad6239848 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -1,6 +1,7 @@ import os import sys import re +import json os.environ["TORCH_LOGS"] = "dynamic" from transformers import AutoTokenizer, AutoModelForCausalLM @@ -71,9 +72,25 @@ """ -def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim): +def generate_schema(num_layers): + null = None + schema = [1, {"type": "builtins.tuple", "context": "null", "children_spec": []}] + kv_schema_per_layer = { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + {"type": null, "context": null, "children_spec": []}, + {"type": null, "context": null, "children_spec": []}, + ], + } + for i in range(num_layers): + schema[1]["children_spec"].append(kv_schema_per_layer) + return json.dumps(schema) + + +def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers): all_pkv_tensors = [] - for i in range(heads * 2): + for i in range(num_layers * 2): # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim] # Generates tensor<1 x 1 x seq_step x heads x hidden_dim> sliced = IREE.tensor_slice( @@ -105,10 +122,8 @@ def export_transformer_model( torch_dtype=torch.float, token=hf_auth_token, ) - if mod.config.num_attention_heads == 8: - state_schema = pytree.treespec_loads(json_schema_16) - else: - state_schema = pytree.treespec_loads(json_schema_64) + schema_json = generate_schema(mod.config.num_hidden_layers) + state_schema = pytree.treespec_loads(schema_json) if streaming_llm: enable_llama_pos_shift_attention(mod) dtype = torch.float32 @@ -121,12 +136,13 @@ def export_transformer_model( token=hf_auth_token, ) # TODO: generate these values instead of magic numbers + NUM_LAYERS = mod.config.num_hidden_layers HEADS = mod.config.num_attention_heads HIDDEN_DIM = int(mod.config.hidden_size / HEADS) BATCH_SIZE = 1 MAX_STEP_SEQ = mod.config.max_position_embeddings - 1 global_pkv = torch.zeros( - size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM), + size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM), dtype=dtype, ) @@ -161,7 +177,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): self.global_seq_step = IREE.tensor_dim( state[0], 1 ) # ? dimension of arbitrarily 0th kv tensor - for i in range(HEADS * 2): + for i in range(NUM_LAYERS * 2): slice_of_state = IREE.tensor_reshape( state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM ) @@ -172,7 +188,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): state_arg = slice_up_to_step( - self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM + self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS ) forw_const = ( [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] @@ -183,7 +199,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] ) token, *state_update = self.forward(x, *state_arg, constraints=forw_const) - for i in range(HEADS * 2): + for i in range(NUM_LAYERS * 2): update = IREE.tensor_reshape( state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM ) @@ -226,7 +242,7 @@ def run_cached_initialize( self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) ): state_arg = slice_up_to_step( - self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM + self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS ) forw_const = ( [x.dynamic_dim(1) < MAX_STEP_SEQ] @@ -243,7 +259,7 @@ def run_cached_initialize( len_of_new_tokens = IREE.tensor_dim( state[0], 1 ) # ? dimension of arbitrarily 0th kv tensor - for i in range(HEADS * 2): + for i in range(NUM_LAYERS * 2): slice_of_state = IREE.tensor_reshape( state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM ) @@ -278,7 +294,7 @@ def evict_kvcache_space(self): sink_size = 4 window_size = 252 most_recent_window = self.global_seq_step + (-window_size) - for i in range(HEADS * 2): + for i in range(NUM_LAYERS * 2): update_window_state = IREE.tensor_slice( self.global_state, i, @@ -339,12 +355,14 @@ def evict_kvcache_space(self): [ "--iree-rocm-target-chip=" + target_triple, "--iree-rocm-link-bc=true", - "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", "--iree-vm-bytecode-module-strip-source-map=true", "--iree-opt-strip-assertions=true", "--iree-vm-target-truncate-unsupported-floats", ] ) + ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"} + if target_triple in ukernel_supported_arch: + flags.extend(["--iree-rocm-enable-ukernels=argmax"]) elif device == "cuda": flags.extend( [