[Llama] Use rocm ukernel when available + use num_layer for pkv. (#381)
Use the ROCm ukernel to improve performance and fix
#380.

Additionally, this adds a fix to stateless llama to handle models whose layer count is not 32. Currently the PKV (past key/value) cache is sized by the number of attention heads. That happens to work because the number of attention heads equals the number of layers for many of the models we are looking at, but once that assumption breaks, stateless llama runs into issues. This PR also fixes that minor bug by sizing the cache from the number of hidden layers instead.
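
To make the sizing rule concrete, here is a minimal sketch (hypothetical config values, not taken from this repo) of why the cache must be keyed on layer count rather than head count:

import torch

# Hypothetical config where heads != layers (values kept small for the sketch).
NUM_LAYERS = 16   # stand-in for mod.config.num_hidden_layers
HEADS = 8         # stand-in for mod.config.num_attention_heads
HIDDEN_DIM = 256 // HEADS   # hidden_size / num_attention_heads
BATCH_SIZE = 1
MAX_STEP_SEQ = 127          # max_position_embeddings - 1

# One key slice and one value slice per layer, so the leading dim is
# NUM_LAYERS * 2. Sizing it as HEADS * 2 would give 16 slices here
# instead of the 32 the model actually needs.
global_pkv = torch.zeros(
    size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
    dtype=torch.float32,
)
assert global_pkv.shape[0] == 2 * NUM_LAYERS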
raikonenfnu authored Feb 1, 2024
1 parent c1dc94c commit da57fe3
Showing 2 changed files with 41 additions and 21 deletions.
51 changes: 30 additions & 21 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -1,6 +1,7 @@
import os
import sys
import re
import json

os.environ["TORCH_LOGS"] = "dynamic"
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -61,19 +62,26 @@
    help="Compile LLM with StreamingLLM optimizations",
)

# TODO (Dan): replace this with a file once I figure out paths on windows exe
json_schema_64 = """
[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, 
{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
"""

json_schema_16 = """
[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
"""
def generate_schema(num_layers):
    # Build the pytree spec for one (key, value) tuple per layer and dump it to JSON.
    null = None
    schema = [1, {"type": "builtins.tuple", "context": "null", "children_spec": []}]
    kv_schema_per_layer = {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
            {"type": null, "context": null, "children_spec": []},
            {"type": null, "context": null, "children_spec": []},
        ],
    }
    for i in range(num_layers):
        schema[1]["children_spec"].append(kv_schema_per_layer)
    return json.dumps(schema)
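# A quick illustration (hypothetical usage): generate_schema(1) should
# produce, modulo whitespace,
#   [1, {"type": "builtins.tuple", "context": "null", "children_spec":
#     [{"type": "builtins.tuple", "context": "null", "children_spec":
#       [{"type": null, "context": null, "children_spec": []},
#        {"type": null, "context": null, "children_spec": []}]}]}]
# and generate_schema(32) should match the old json_schema_64 string, so
# pytree.treespec_loads(generate_schema(num_layers)) recovers the same spec.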


def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
    all_pkv_tensors = []
    for i in range(heads * 2):
    for i in range(num_layers * 2):
        # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
        # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
        sliced = IREE.tensor_slice(
@@ -105,10 +113,8 @@ def export_transformer_model(
        torch_dtype=torch.float,
        token=hf_auth_token,
    )
    if mod.config.num_attention_heads == 8:
        state_schema = pytree.treespec_loads(json_schema_16)
    else:
        state_schema = pytree.treespec_loads(json_schema_64)
    schema_json = generate_schema(mod.config.num_hidden_layers)
    state_schema = pytree.treespec_loads(schema_json)
    if streaming_llm:
        enable_llama_pos_shift_attention(mod)
    dtype = torch.float32
@@ -121,12 +127,13 @@
        token=hf_auth_token,
    )
    # TODO: generate these values instead of magic numbers
    NUM_LAYERS = mod.config.num_hidden_layers
    HEADS = mod.config.num_attention_heads
    HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
    BATCH_SIZE = 1
    MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
    global_pkv = torch.zeros(
        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
        dtype=dtype,
    )

@@ -161,7 +168,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
        self.global_seq_step = IREE.tensor_dim(
            state[0], 1
        )  # ? dimension of arbitrarily 0th kv tensor
        for i in range(HEADS * 2):
        for i in range(NUM_LAYERS * 2):
            slice_of_state = IREE.tensor_reshape(
                state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
            )
@@ -172,7 +179,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):

    def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
        state_arg = slice_up_to_step(
            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
        )
        forw_const = (
            [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +190,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
            + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
        )
        token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
        for i in range(HEADS * 2):
        for i in range(NUM_LAYERS * 2):
            update = IREE.tensor_reshape(
                state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
            )
@@ -226,7 +233,7 @@ def run_cached_initialize(
        self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
    ):
        state_arg = slice_up_to_step(
            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
        )
        forw_const = (
            [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +250,7 @@ def run_cached_initialize(
        len_of_new_tokens = IREE.tensor_dim(
            state[0], 1
        )  # ? dimension of arbitrarily 0th kv tensor
        for i in range(HEADS * 2):
        for i in range(NUM_LAYERS * 2):
            slice_of_state = IREE.tensor_reshape(
                state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
            )
@@ -278,7 +285,7 @@ def evict_kvcache_space(self):
        sink_size = 4
        window_size = 252
        most_recent_window = self.global_seq_step + (-window_size)
        for i in range(HEADS * 2):
        for i in range(NUM_LAYERS * 2):
            update_window_state = IREE.tensor_slice(
                self.global_state,
                i,
@@ -339,12 +346,14 @@ def evict_kvcache_space(self):
            [
                "--iree-rocm-target-chip=" + target_triple,
                "--iree-rocm-link-bc=true",
                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                "--iree-vm-bytecode-module-strip-source-map=true",
                "--iree-opt-strip-assertions=true",
                "--iree-vm-target-truncate-unsupported-floats",
            ]
        )
        ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
        if target_triple in ukernel_supported_arch:
            flags.extend(["--iree-rocm-enable-ukernels=argmax"])
    elif device == "cuda":
        flags.extend(
            [
