Adds support to use brevitas quantized weights for stateless_llama
- Modifies mm_group_quant to work with brevitas safetensors (see the sketch below); still needs work to generalize
- Changes the compiler to take torch as its input type so that the torch IR can be quantized
IanNod committed Nov 17, 2023
1 parent a7b6ec6 commit 3460784
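
The rewritten GROUP_MATMUL_TEMPLATE in mm_group_quant.py now loads the packed int4 weight, its scale, and its zero point from externally stored parameters named "{param_name}", "{param_name}_scale", and "{param_name}_zp" in the "model" scope, so the brevitas export is expected to produce a safetensors file whose keys follow that convention. The sketch below is only an illustration of that expected layout, not code from this commit; the tensor name, shapes, and dtypes are hypothetical, with the key naming and the per-group scale/zero-point shapes taken from the template.

import torch
import safetensors.torch

# Hypothetical example of the parameter file the template expects.
group_size = 128
k, n = 4096, 4096  # illustrative weight dimensions

params = {
    # packed int4 weight: two 4-bit values per byte, hence n // 2 int8 columns
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(k, n // 2, dtype=torch.int8),
    # one scale and one zero point per group of `group_size` columns
    "model.layers.0.self_attn.q_proj.weight_scale": torch.ones(k, n // group_size, dtype=torch.float32),
    "model.layers.0.self_attn.q_proj.weight_zp": torch.zeros(k, n // group_size, dtype=torch.float32),
}
safetensors.torch.save_file(params, "model_int4.safetensors")

On the compiler side, the model is now exported to the torch dialect (import_to="IMPORT"), MMGroupQuantRewriterPass runs on that torch IR, and the compile flags use --iree-input-type=torch instead of tm_tensor; run_vmfb_comparison then exposes the external safetensors file to the module through ireert.ParameterIndex and create_io_parameters_module.
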
Showing 2 changed files with 66 additions and 43 deletions.
59 changes: 34 additions & 25 deletions python/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -63,19 +63,22 @@ def match(self, op: Operation):
m=m,
n=n,
k=k,
element_type=self.builder.get_tensor_element_type(op.operands[0].type),
element_type=self.builder.get_tensor_element_type(
op.operands[0].type
),
)


# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization
GROUP_MATMUL_TEMPLATE = r"""
module {{
util.global private @{param_name}.quant {{noinline}} : tensor<{k}x{n_div}xi8>
util.global private @{param_name}.quant.scale {{noinline}} : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name}.quant.zero_point {{noinline}} : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name} {{noinline}} = #stream.parameter.named<"model"::"{param_name}"> : tensor<{k}x{n_div}xi8>
util.global private @{param_name}.quant.scale {{noinline}} = #stream.parameter.named<"model"::"{param_name}_scale"> : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name}.quant.zero_point {{noinline}} = #stream.parameter.named<"model"::"{param_name}_zp"> : tensor<{k}x{group0}x{element_type}>
func.func private @compute_mm_group_quant(%a : tensor<{m}x{n}x{element_type}>) -> tensor<{m}x{k}x{element_type}> {{
%c0 = arith.constant 0 : index
%weight_raw = util.global.load @{param_name}.quant : tensor<{k}x{n_div}xi8>
%weight_raw = util.global.load @{param_name} : tensor<{k}x{n_div}xi8>
%m = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}>
%k = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8>
%scale = util.global.load @{param_name}.quant.scale : tensor<{k}x{group0}x{element_type}>
@@ -131,7 +134,9 @@ def __init__(self, root_op: Operation, *, group_size: int = 128):

def run(self):
globals = self.globals
mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
mms = match_children(
self.funcs, TransposedMMMatcher(globals, self.builder)
)

for mr in mms:
if mr.k is None or mr.n is None:
@@ -145,27 +150,31 @@ def run(self):

def rewrite(self, mr: TransposedMMResult):
none_to_q = lambda x: "?" if x is None else x
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
param_name=mr.param_name,
lowp_type="i4",
m=none_to_q(mr.m),
n=none_to_q(mr.n),
k=none_to_q(mr.k),
n_div=mr.n // 2,
group0=mr.n // self.group_size,
group1=self.group_size,
element_type=mr.element_type,
)
# TODO (ian): make generalizable and not specific for brevitas
if "lm_head.weight" not in mr.param_name:
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
param_name=mr.param_name[8:],
lowp_type="i4",
m=none_to_q(mr.m),
n=none_to_q(mr.n),
k=none_to_q(mr.k),
n_div=mr.n // 2,
group0=mr.n // self.group_size,
group1=self.group_size,
element_type=mr.element_type,
)

inline_module = Operation.parse(inline_module_asm, context=self.context)
actual_callee_name = self.merge_module(inline_module).translate_symbol(
"compute_mm_group_quant"
)
with InsertionPoint(mr.op), mr.op.location:
results = self.builder.call_native(
actual_callee_name, [mr.op.result.type], mr.op.operands[0]
inline_module = Operation.parse(
inline_module_asm, context=self.context
)
self.replace_op(mr.op, *results)
actual_callee_name = self.merge_module(
inline_module
).translate_symbol("compute_mm_group_quant")
with InsertionPoint(mr.op), mr.op.location:
results = self.builder.call_native(
actual_callee_name, [mr.op.result.type], mr.op.operands[0]
)
self.replace_op(mr.op, *results)


if __name__ == "__main__":
50 changes: 32 additions & 18 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -69,7 +69,12 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):


def export_transformer_model(
hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None,
hf_model_name,
hf_auth_token,
compile_to,
external_weights=None,
external_weight_file=None,
quantization=None,
):
state_schema = pytree.treespec_loads(json_schema)

@@ -83,6 +88,7 @@ def export_transformer_model(
use_fast=False,
use_auth_token=hf_auth_token,
)

# TODO: generate these values instead of magic numbers
HEADS = 32
HIDDEN_DIM = 128
@@ -97,12 +103,14 @@
if external_weights == "safetensors":
mod_params = dict(mod.named_parameters())
for name in mod_params:
mapper["params."+name]=name
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)

elif external_weights=="gguf":
tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
elif external_weights == "gguf":
tensor_mapper = remap_gguf.TensorNameMap(
remap_gguf.MODEL_ARCH.LLAMA, HEADS
)
mapper = tensor_mapper.mapping

class StateUpdateModule(CompiledModule):
@@ -115,7 +123,9 @@ class StateUpdateModule(CompiledModule):
global_state = export_global(abstractify(global_pkv), mutable=True)
global_seq_step = export_global(AbstractIndex, mutable=True)

def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
def run_initialize(
self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
):
init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
token, *state = self.initialize(x, constraints=init_const)
self.global_seq_step = IREE.tensor_dim(
@@ -135,9 +145,12 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)):
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
)
forw_const = [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + [
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) for x in state_arg[1:]
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
for x in state_arg[1:]
]
token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
token, *state_update = self.forward(
x, *state_arg, constraints=forw_const
)
for i in range(HEADS * 2):
update = IREE.tensor_reshape(
state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
@@ -171,32 +184,32 @@ def forward(token0: torch.Tensor, *state0_flat):
state0 = pytree.tree_unflatten(state0_flat, state_schema)
result = mod.forward(token0, past_key_values=state0)
state1_flat, _ = pytree.tree_flatten(result.past_key_values)
state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat]
state1_flat = [
torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat
]
token1 = torch.argmax(result.logits[:, -1, :], dim=1)
token1 = token1[None, :]
return token1, *state1_flat

import_to = "IMPORT" if compile_to == "torch" else "INPUT"
import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = StateUpdateModule(context=Context(), import_to=import_to)

# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if quantization == "int4" and compile_to == "torch":
if args.quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant

mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
).run()

module_str = str(CompiledModule.get_mlir_module(inst))

safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
if compile_to != "vmfb":
return module_str, tokenizer
else:
flags = [
"--iree-input-type=tm_tensor",
"--iree-input-type=torch",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--mlir-print-debuginfo",
"--mlir-print-op-on-diagnostic=false",
@@ -227,13 +240,9 @@

def run_vmfb_comparison(args):
config = ireert.Config("local-task")
print(args.external_weight_file)

if args.external_weight_file:
from pathlib import Path

index = ireert.ParameterIndex()

index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
@@ -244,12 +253,17 @@ def run_vmfb_comparison(args):
mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
else:
sys.exit("no vmfb_path provided, required for run_vmfb")
vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)]

vm_modules = [
mod,
ireert.create_hal_module(config.vm_instance, config.device),
]
if args.external_weight_file:
param_module = ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
)
vm_modules.insert(0, param_module)

ctx = ireert.SystemContext(
vm_modules=vm_modules,
config=config,
