Skip to content

Commit

Permalink
[Llama] Use rocm ukernel when available + use num_layer for pkv.
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanley Winata committed Jan 31, 2024
1 parent c1dc94c commit b893a92
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import re
import json

os.environ["TORCH_LOGS"] = "dynamic"
from transformers import AutoTokenizer, AutoModelForCausalLM
Expand Down Expand Up @@ -71,9 +72,25 @@
"""


def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
def generate_schema(num_layers):
    """Return a JSON-serialized tree spec with one 2-tuple entry per layer.

    The result has the shape ``[1, <tuple node>]`` where the tuple node has
    ``num_layers`` children, and each child is itself a tuple node holding
    two leaf entries.  The caller feeds this string to
    ``pytree.treespec_loads``.

    Args:
        num_layers: Number of per-layer entries to emit.  ``0`` (or a
            negative value) produces an empty ``children_spec`` list.

    Returns:
        The schema encoded as a JSON string.
    """
    # NOTE: tuple nodes carry the *string* "null" as their context, whereas
    # leaf entries use a real JSON null (Python ``None``) for both "type"
    # and "context".  The two must not be confused.
    # A fresh dict is built for every layer so that no two entries alias the
    # same object; the serialized output does not depend on this.
    per_layer_specs = [
        {
            "type": "builtins.tuple",
            "context": "null",
            "children_spec": [
                {"type": None, "context": None, "children_spec": []},
                {"type": None, "context": None, "children_spec": []},
            ],
        }
        for _ in range(num_layers)
    ]
    schema = [
        1,
        {
            "type": "builtins.tuple",
            "context": "null",
            "children_spec": per_layer_specs,
        },
    ]
    return json.dumps(schema)


def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
all_pkv_tensors = []
for i in range(heads * 2):
for i in range(num_layers * 2):
# Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
# Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
sliced = IREE.tensor_slice(
Expand Down Expand Up @@ -105,10 +122,8 @@ def export_transformer_model(
torch_dtype=torch.float,
token=hf_auth_token,
)
if mod.config.num_attention_heads == 8:
state_schema = pytree.treespec_loads(json_schema_16)
else:
state_schema = pytree.treespec_loads(json_schema_64)
schema_json = generate_schema(mod.config.num_hidden_layers)
state_schema = pytree.treespec_loads(schema_json)
if streaming_llm:
enable_llama_pos_shift_attention(mod)
dtype = torch.float32
Expand All @@ -121,12 +136,13 @@ def export_transformer_model(
token=hf_auth_token,
)
# TODO: generate these values instead of magic numbers
NUM_LAYERS = mod.config.num_hidden_layers
HEADS = mod.config.num_attention_heads
HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
BATCH_SIZE = 1
MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
global_pkv = torch.zeros(
size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
dtype=dtype,
)

Expand Down Expand Up @@ -161,7 +177,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
self.global_seq_step = IREE.tensor_dim(
state[0], 1
) # ? dimension of arbitrarily 0th kv tensor
for i in range(HEADS * 2):
for i in range(NUM_LAYERS * 2):
slice_of_state = IREE.tensor_reshape(
state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
)
Expand All @@ -172,7 +188,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):

def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
state_arg = slice_up_to_step(
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
)
forw_const = (
[state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
Expand All @@ -183,7 +199,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
+ [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
)
token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
for i in range(HEADS * 2):
for i in range(NUM_LAYERS * 2):
update = IREE.tensor_reshape(
state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
)
Expand Down Expand Up @@ -226,7 +242,7 @@ def run_cached_initialize(
self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
):
state_arg = slice_up_to_step(
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
)
forw_const = (
[x.dynamic_dim(1) < MAX_STEP_SEQ]
Expand All @@ -243,7 +259,7 @@ def run_cached_initialize(
len_of_new_tokens = IREE.tensor_dim(
state[0], 1
) # ? dimension of arbitrarily 0th kv tensor
for i in range(HEADS * 2):
for i in range(NUM_LAYERS * 2):
slice_of_state = IREE.tensor_reshape(
state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
)
Expand Down Expand Up @@ -278,7 +294,7 @@ def evict_kvcache_space(self):
sink_size = 4
window_size = 252
most_recent_window = self.global_seq_step + (-window_size)
for i in range(HEADS * 2):
for i in range(NUM_LAYERS * 2):
update_window_state = IREE.tensor_slice(
self.global_state,
i,
Expand Down Expand Up @@ -339,12 +355,14 @@ def evict_kvcache_space(self):
[
"--iree-rocm-target-chip=" + target_triple,
"--iree-rocm-link-bc=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-opt-strip-assertions=true",
"--iree-vm-target-truncate-unsupported-floats",
]
)
ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
if target_triple in ukernel_supported_arch:
flags.extend(["--iree-rocm-enable-ukernels=argmax"])
elif device == "cuda":
flags.extend(
[
Expand Down

0 comments on commit b893a92

Please sign in to comment.