From 5ed212c190550594d7d5f3c0da90bf9f66ddb4f0 Mon Sep 17 00:00:00 2001
From: Stanley Winata
Date: Wed, 31 Jan 2024 03:09:27 -0600
Subject: [PATCH] [Llama] Use rocm ukernel when available + use num_layer for pkv.

---
 .../custom_models/stateless_llama.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py
index 169f01012..25e810cdf 100644
--- a/python/turbine_models/custom_models/stateless_llama.py
+++ b/python/turbine_models/custom_models/stateless_llama.py
@@ -71,9 +71,9 @@
 """
 
 
-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
+def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
-    for i in range(heads * 2):
+    for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
         sliced = IREE.tensor_slice(
@@ -121,12 +121,13 @@ def export_transformer_model(
         token=hf_auth_token,
     )
     # TODO: generate these values instead of magic numbers
+    NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = mod.config.num_attention_heads
     HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
 
@@ -161,7 +162,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
@@ -172,7 +173,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
 
         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +184,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update = IREE.tensor_reshape(
                     state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
                 )
@@ -226,7 +227,7 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +244,7 @@ def run_cached_initialize(
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
@@ -278,7 +279,7 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update_window_state = IREE.tensor_slice(
                     self.global_state,
                     i,
@@ -339,12 +340,14 @@ def evict_kvcache_space(self):
                 [
                     "--iree-rocm-target-chip=" + target_triple,
                     "--iree-rocm-link-bc=true",
-                    "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                     "--iree-vm-bytecode-module-strip-source-map=true",
                     "--iree-opt-strip-assertions=true",
                     "--iree-vm-target-truncate-unsupported-floats",
                 ]
             )
+            ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
+            if target_triple in ukernel_supported_arch:
+                flags.extend(["--iree-rocm-enable-ukernels=argmax"])
         elif device == "cuda":
             flags.extend(
                 [
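
Note (not part of the patch): below is a minimal, self-contained sketch of the two ideas the diff relies on, assuming a hypothetical stand-in for mod.config rather than a real Hugging Face model. The past-key-value buffer is sized by num_hidden_layers because each transformer layer contributes one key and one value tensor (the old HEADS * 2 sizing only worked when num_attention_heads happened to equal num_hidden_layers, as in Llama-2-7B's 32/32), and the ROCm argmax ukernel flag is added only for chips in the supported set. FakeLlamaConfig, pkv_shape, and rocm_flags are illustrative names, not functions in stateless_llama.py.

import torch


class FakeLlamaConfig:
    # Stand-in for mod.config; the real values come from the Hugging Face model.
    num_hidden_layers = 32
    num_attention_heads = 32
    hidden_size = 4096
    max_position_embeddings = 4096


def pkv_shape(config, batch_size=1):
    """Shape of the global past-key-value buffer, as sized in the patch."""
    num_layers = config.num_hidden_layers
    heads = config.num_attention_heads
    hidden_dim = config.hidden_size // heads
    max_step_seq = config.max_position_embeddings - 1
    # One K and one V tensor per transformer layer, hence num_layers * 2.
    return (num_layers * 2, batch_size, max_step_seq, heads, hidden_dim)


def rocm_flags(target_triple):
    """ROCm compile flags, gating the argmax ukernel on supported chips."""
    flags = [
        "--iree-rocm-target-chip=" + target_triple,
        "--iree-rocm-link-bc=true",
    ]
    ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
    if target_triple in ukernel_supported_arch:
        flags.append("--iree-rocm-enable-ukernels=argmax")
    return flags


if __name__ == "__main__":
    config = FakeLlamaConfig()
    global_pkv = torch.zeros(size=pkv_shape(config), dtype=torch.float16)
    print(global_pkv.shape)      # torch.Size([64, 1, 4095, 32, 128])
    print(rocm_flags("gfx90a"))  # includes --iree-rocm-enable-ukernels=argmax
    print(rocm_flags("gfx906"))  # base flags only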