[Llama] Use rocm ukernel when available + use num_layer for pkv.
Stanley Winata committed Jan 31, 2024
1 parent c1dc94c commit 5ed212c
Showing 1 changed file with 13 additions and 10 deletions.
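Editor's note on the "use num_layer for pkv" half of the change: the preallocated past key/value cache is now sized by transformer layer count instead of attention head count. Each decoder layer contributes one key tensor and one value tensor, so the cache needs a leading dimension of num_hidden_layers * 2; the old HEADS * 2 sizing only happens to work for models whose layer count equals their head count. A minimal sketch of the resulting shape, assuming a Hugging Face Llama-style config (the numbers in the comments are just the transformers LlamaConfig defaults, and torch.float32 is a placeholder for the dtype variable the real code passes):

    import torch
    from transformers import LlamaConfig

    config = LlamaConfig()  # defaults: 32 layers, 32 heads, hidden_size 4096
    num_layers = config.num_hidden_layers
    heads = config.num_attention_heads
    hidden_dim = config.hidden_size // heads  # per-head dimension, 128 here
    batch_size = 1
    max_step_seq = config.max_position_embeddings - 1

    # One K and one V tensor per layer -> leading dim is num_layers * 2, not heads * 2.
    global_pkv = torch.zeros(
        size=(num_layers * 2, batch_size, max_step_seq, heads, hidden_dim),
        dtype=torch.float32,
    )
    print(global_pkv.shape)  # torch.Size([64, 1, 2047, 32, 128])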
python/turbine_models/custom_models/stateless_llama.py (23 changes: 13 additions & 10 deletions)
@@ -71,9 +71,9 @@
 """
 
 
-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
+def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
-    for i in range(heads * 2):
+    for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
         sliced = IREE.tensor_slice(
@@ -121,12 +121,13 @@ def export_transformer_model(
         token=hf_auth_token,
     )
     # TODO: generate these values instead of magic numbers
+    NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = mod.config.num_attention_heads
     HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
 
@@ -161,7 +162,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
@@ -172,7 +173,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
 
         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +184,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update = IREE.tensor_reshape(
                     state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
                 )
@@ -226,7 +227,7 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +244,7 @@ def run_cached_initialize(
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
@@ -278,7 +279,7 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update_window_state = IREE.tensor_slice(
                     self.global_state,
                     i,
@@ -339,12 +340,14 @@ def evict_kvcache_space(self):
                 [
                     "--iree-rocm-target-chip=" + target_triple,
                     "--iree-rocm-link-bc=true",
-                    "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                     "--iree-vm-bytecode-module-strip-source-map=true",
                     "--iree-opt-strip-assertions=true",
                     "--iree-vm-target-truncate-unsupported-floats",
                 ]
             )
+            ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
+            if target_triple in ukernel_supported_arch:
+                flags.extend(["--iree-rocm-enable-ukernels=argmax"])
         elif device == "cuda":
             flags.extend(
                 [
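On the ukernel half of the change: the argmax microkernel is enabled only when the ROCm target chip is one of the architectures listed in ukernel_supported_arch. A small self-contained sketch of that flag construction, mirroring the diff above (the helper name rocm_flags is illustrative, not something the file defines):

    def rocm_flags(target_triple: str) -> list[str]:
        # Base IREE flags for a ROCm target, taken from the diff above.
        flags = [
            "--iree-rocm-target-chip=" + target_triple,
            "--iree-rocm-link-bc=true",
            "--iree-vm-bytecode-module-strip-source-map=true",
            "--iree-opt-strip-assertions=true",
            "--iree-vm-target-truncate-unsupported-floats",
        ]
        # Enable the argmax ukernel only on chips the commit lists as supported.
        ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
        if target_triple in ukernel_supported_arch:
            flags.append("--iree-rocm-enable-ukernels=argmax")
        return flags

    # rocm_flags("gfx90a") includes the ukernel flag; rocm_flags("gfx906") does not.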
