diff --git a/python/shark_turbine/transforms/quantization/mm_group_quant.py b/python/shark_turbine/transforms/quantization/mm_group_quant.py index 2dad4d831..527a986a6 100644 --- a/python/shark_turbine/transforms/quantization/mm_group_quant.py +++ b/python/shark_turbine/transforms/quantization/mm_group_quant.py @@ -63,19 +63,22 @@ def match(self, op: Operation): m=m, n=n, k=k, - element_type=self.builder.get_tensor_element_type(op.operands[0].type), + element_type=self.builder.get_tensor_element_type( + op.operands[0].type + ), ) +# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization GROUP_MATMUL_TEMPLATE = r""" module {{ - util.global private @{param_name}.quant {{noinline}} : tensor<{k}x{n_div}xi8> - util.global private @{param_name}.quant.scale {{noinline}} : tensor<{k}x{group0}x{element_type}> - util.global private @{param_name}.quant.zero_point {{noinline}} : tensor<{k}x{group0}x{element_type}> + util.global private @{param_name} {{noinline}} = #stream.parameter.named<"model"::"{param_name}"> : tensor<{k}x{n_div}xi8> + util.global private @{param_name}.quant.scale {{noinline}} = #stream.parameter.named<"model"::"{param_name}_scale"> : tensor<{k}x{group0}x{element_type}> + util.global private @{param_name}.quant.zero_point {{noinline}} = #stream.parameter.named<"model"::"{param_name}_zp"> : tensor<{k}x{group0}x{element_type}> func.func private @compute_mm_group_quant(%a : tensor<{m}x{n}x{element_type}>) -> tensor<{m}x{k}x{element_type}> {{ %c0 = arith.constant 0 : index - %weight_raw = util.global.load @{param_name}.quant : tensor<{k}x{n_div}xi8> + %weight_raw = util.global.load @{param_name} : tensor<{k}x{n_div}xi8> %m = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}> %k = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8> %scale = util.global.load @{param_name}.quant.scale : tensor<{k}x{group0}x{element_type}> @@ -131,7 +134,9 @@ def __init__(self, root_op: Operation, *, group_size: int = 128): def run(self): globals = self.globals - mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder)) + mms = match_children( + self.funcs, TransposedMMMatcher(globals, self.builder) + ) for mr in mms: if mr.k is None or mr.n is None: @@ -145,27 +150,31 @@ def run(self): def rewrite(self, mr: TransposedMMResult): none_to_q = lambda x: "?" if x is None else x - inline_module_asm = GROUP_MATMUL_TEMPLATE.format( - param_name=mr.param_name, - lowp_type="i4", - m=none_to_q(mr.m), - n=none_to_q(mr.n), - k=none_to_q(mr.k), - n_div=mr.n // 2, - group0=mr.n // self.group_size, - group1=self.group_size, - element_type=mr.element_type, - ) + # TODO (ian): make generalizable and not specific for brevitas + if "lm_head.weight" not in mr.param_name: + inline_module_asm = GROUP_MATMUL_TEMPLATE.format( + param_name=mr.param_name[8:], + lowp_type="i4", + m=none_to_q(mr.m), + n=none_to_q(mr.n), + k=none_to_q(mr.k), + n_div=mr.n // 2, + group0=mr.n // self.group_size, + group1=self.group_size, + element_type=mr.element_type, + ) - inline_module = Operation.parse(inline_module_asm, context=self.context) - actual_callee_name = self.merge_module(inline_module).translate_symbol( - "compute_mm_group_quant" - ) - with InsertionPoint(mr.op), mr.op.location: - results = self.builder.call_native( - actual_callee_name, [mr.op.result.type], mr.op.operands[0] + inline_module = Operation.parse( + inline_module_asm, context=self.context ) - self.replace_op(mr.op, *results) + actual_callee_name = self.merge_module( + inline_module + ).translate_symbol("compute_mm_group_quant") + with InsertionPoint(mr.op), mr.op.location: + results = self.builder.call_native( + actual_callee_name, [mr.op.result.type], mr.op.operands[0] + ) + self.replace_op(mr.op, *results) if __name__ == "__main__": diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 332e8a857..5b42db324 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -69,7 +69,12 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim): def export_transformer_model( - hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None, + hf_model_name, + hf_auth_token, + compile_to, + external_weights=None, + external_weight_file=None, + quantization=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -83,6 +88,7 @@ def export_transformer_model( use_fast=False, use_auth_token=hf_auth_token, ) + # TODO: generate these values instead of magic numbers HEADS = 32 HIDDEN_DIM = 128 @@ -97,12 +103,14 @@ def export_transformer_model( if external_weights == "safetensors": mod_params = dict(mod.named_parameters()) for name in mod_params: - mapper["params."+name]=name + mapper["params." + name] = name if external_weight_file: safetensors.torch.save_file(mod_params, external_weight_file) - elif external_weights=="gguf": - tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS) + elif external_weights == "gguf": + tensor_mapper = remap_gguf.TensorNameMap( + remap_gguf.MODEL_ARCH.LLAMA, HEADS + ) mapper = tensor_mapper.mapping class StateUpdateModule(CompiledModule): @@ -115,7 +123,9 @@ class StateUpdateModule(CompiledModule): global_state = export_global(abstractify(global_pkv), mutable=True) global_seq_step = export_global(AbstractIndex, mutable=True) - def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): + def run_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ] token, *state = self.initialize(x, constraints=init_const) self.global_seq_step = IREE.tensor_dim( @@ -135,9 +145,12 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)): self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM ) forw_const = [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) for x in state_arg[1:] + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] ] - token, *state_update = self.forward(x, *state_arg, constraints=forw_const) + token, *state_update = self.forward( + x, *state_arg, constraints=forw_const + ) for i in range(HEADS * 2): update = IREE.tensor_reshape( state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM @@ -171,32 +184,32 @@ def forward(token0: torch.Tensor, *state0_flat): state0 = pytree.tree_unflatten(state0_flat, state_schema) result = mod.forward(token0, past_key_values=state0) state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat] + state1_flat = [ + torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat + ] token1 = torch.argmax(result.logits[:, -1, :], dim=1) token1 = token1[None, :] return token1, *state1_flat - import_to = "IMPORT" if compile_to == "torch" else "INPUT" + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = StateUpdateModule(context=Context(), import_to=import_to) # TODO: Integrate with external parameters to actually be able to run # TODO: Make more generalizable to be able to quantize with all compile_to options - if quantization == "int4" and compile_to == "torch": + if args.quantization == "int4" and not compile_to == "linalg": from shark_turbine.transforms.quantization import mm_group_quant mm_group_quant.MMGroupQuantRewriterPass( CompiledModule.get_mlir_module(inst).operation ).run() - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() safe_name = re.sub("-", "_", safe_name) if compile_to != "vmfb": return module_str, tokenizer else: flags = [ - "--iree-input-type=tm_tensor", + "--iree-input-type=torch", "--iree-vm-bytecode-module-output-format=flatbuffer-binary", "--mlir-print-debuginfo", "--mlir-print-op-on-diagnostic=false", @@ -227,13 +240,9 @@ def forward(token0: torch.Tensor, *state0_flat): def run_vmfb_comparison(args): config = ireert.Config("local-task") - print(args.external_weight_file) if args.external_weight_file: - from pathlib import Path - index = ireert.ParameterIndex() - index.load(args.external_weight_file) safe_name = args.hf_model_name.split("/")[-1].strip() @@ -244,12 +253,17 @@ def run_vmfb_comparison(args): mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") else: sys.exit("no vmfb_path provided, required for run_vmfb") - vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] if args.external_weight_file: param_module = ireert.create_io_parameters_module( config.vm_instance, index.create_provider(scope="model") ) vm_modules.insert(0, param_module) + ctx = ireert.SystemContext( vm_modules=vm_modules, config=config,