Adds gpu support for stateless llama
 - Adds support for compiling the vmfb to different device backends, including cpu/vulkan/rocm/cuda
 - Adds flags for device, iree_target_triple (backend/device-specific info), and vulkan_max_allocation
 - Minor fix to gen_external_params when not doing quantization
IanNod committed Dec 1, 2023
1 parent cd063df commit 21430d6
Showing 2 changed files with 64 additions and 13 deletions.
72 changes: 61 additions & 11 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -48,6 +48,16 @@
"--precision", type=str, default="fp16", help="dtype of model [f16, f32]"
)

parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan")
# TODO: Bring in detection for target triple
parser.add_argument(
"--iree_target_triple",
type=str,
default="",
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

prompt = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
"""
@@ -79,6 +89,9 @@ def export_transformer_model(
external_weight_file=None,
quantization=None,
precision=None,
device=None,
target_triple=None,
max_alloc=None,
):
state_schema = pytree.treespec_loads(json_schema)
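A hypothetical call with the new keyword arguments might look like the sketch below. The new parameters (device, target_triple, max_alloc) come from this hunk; hf_model_name, hf_auth_token, and compile_to are assumed from the rest of the file and are not shown in this excerpt, and every value is a placeholder.

from turbine_models.custom_models.stateless_llama import export_transformer_model

# Illustrative only: this triggers a full export/compile and requires the
# turbine_models package plus access to the Hugging Face model.
export_transformer_model(
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    hf_auth_token="hf_xxx",                         # placeholder token
    compile_to="vmfb",                              # assumed parameter, not in this excerpt
    external_weight_file=None,
    quantization=None,
    precision="f16",
    device="vulkan",                                # new: selects the HAL backend
    target_triple="rdna2-unknown-linux",            # new: backend/device-specific triple
    max_alloc="4294967296",                         # new: vulkan max allocation size
)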

@@ -214,28 +227,62 @@ def forward(token0: torch.Tensor, *state0_flat):
else:
flags = [
"--iree-input-type=torch",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--mlir-print-debuginfo",
"--mlir-print-op-on-diagnostic=false",
"--iree-llvmcpu-target-cpu-features=host",
"--iree-llvmcpu-target-triple=x86_64-linux-gnu",
"--iree-llvmcpu-enable-microkernels",
"--iree-llvmcpu-stack-allocation-limit=256000",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-opt-const-expr-hoisting=False",
]

if device == "cpu":
flags.append("--iree-llvmcpu-enable-microkernels")
device = "llvm-cpu"
elif device == "vulkan":
flags.extend(
[
"--iree-hal-target-backends=vulkan-spirv",
"--iree-vulkan-target-triple=" + target_triple,
"--iree-spirv-index-bits=64",
"--iree-opt-const-eval=false",
"--iree-stream-resource-max-allocation-size=" + max_alloc,
]
)
elif device == "rocm":
flags.extend(
[
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=" + target_triple,
"--iree-rocm-link-bc=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--iree-opt-strip-assertions=true",
"--verify=false",
"--iree-vm-target-truncate-unsupported-floats",
]
)
elif device == "cuda":
flags.extend(
[
"--iree-hal-target-backends=cuda",
"--iree-hal-cuda-llvm-target-arch=" + target_triple,
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
)
else:
print("incorrect device: ", device)
import iree.compiler as ireec

flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=["llvm-cpu"],
target_backends=[device],
extra_args=flags,
)
with open(f"{safe_name}.vmfb", "wb+") as f:
@@ -245,7 +292,7 @@ def forward(token0: torch.Tensor, *state0_flat):


def run_vmfb_comparison(args):
config = ireert.Config("local-task")
config = ireert.Config(args.device)

if args.external_weight_file:
index = ireert.ParameterIndex()
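One possible mapping from the compile-time --device values to IREE runtime driver names, assuming the standard HAL driver names (this is an assumption, not something the commit adds): the CPU path is served by the "local-task" driver rather than a driver literally named "cpu".

import iree.runtime as ireert  # assumes the iree-runtime package is installed

_DRIVER_FOR_DEVICE = {
    "cpu": "local-task",
    "vulkan": "vulkan",
    "rocm": "rocm",
    "cuda": "cuda",
}

def runtime_config(device: str) -> ireert.Config:
    # Fall back to passing the name through unchanged if it is already a driver name.
    return ireert.Config(_DRIVER_FOR_DEVICE.get(device, device))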
@@ -277,7 +324,7 @@ def run_vmfb_comparison(args):
tokenizer = AutoTokenizer.from_pretrained(
args.hf_model_name,
use_fast=False,
use_auth_token=args.hf_auth_token,
token=args.hf_auth_token,
)
initial_input = tokenizer(prompt, return_tensors="pt")
example_input_id = initial_input.input_ids
@@ -338,6 +385,9 @@ def get_token_from_logits(logits):
args.external_weight_file,
args.quantization,
args.precision,
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
5 changes: 3 additions & 2 deletions python/turbine_models/gen_external_params/gen_external_params.py
@@ -24,6 +24,7 @@

def quantize(model, quantization, dtype):
accumulates = dtype
int_weights = {}
if quantization in ["int4", "int8"]:
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -48,7 +49,6 @@ class DummyLinearWeightBlockQuantHandler(LinearWeightBlockQuantHandler):
def forward(self, x):
raise NotImplementedError

int_weights = {}
for prefix, layer in model.named_modules():
if isinstance(layer, QuantLinear):
print(f"Exporting layer {prefix}")
@@ -68,7 +68,8 @@ def forward(self, x):
if "wrapped_scaling_impl" in k or "wrapped_zero_point_impl" in k:
del all_weights[k]

all_weights.update(int_weights)
if len(int_weights) != 0:
all_weights.update(int_weights)
return all_weights
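The guarded update above addresses the no-quantization case called out in the commit message: int_weights is now initialized at the top of the function instead of inside the quantization branch. A reduced sketch of the resulting control flow (merge_weights is a made-up name for illustration; the real function is quantize in this file):

def merge_weights(all_weights, quantization):
    int_weights = {}
    if quantization in ["int4", "int8"]:
        # ... quantized tensors would be collected here in the real code ...
        int_weights = {"layer0.weight:int4": b"..."}  # placeholder
    if len(int_weights) != 0:
        all_weights.update(int_weights)
    # When quantization is None, int_weights stays empty and
    # all_weights is returned unchanged.
    return all_weights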

