From 42914be4879fd278e41b0a3c286048f4e95f5660 Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Thu, 30 Nov 2023 01:02:21 +0000
Subject: [PATCH] move mmgroupquant

Add an optional `pre_import_passes` argument to `CompiledModule.__new__`.
Each entry is a rewriter `Pass` class; it is instantiated on the module's
root operation and run after `finalize_construct()` and before
`run_import()`.

stateless_llama.py no longer runs `MMGroupQuantRewriterPass` on the module
after `StateUpdateModule` has been instantiated; it hands the pass class
to the constructor through `pre_import_passes` instead.

The argument defaults to `None`, so callers that do not pass it are
unaffected.

---
 python/shark_turbine/aot/compiled_module.py            | 10 ++++++++++
 python/turbine_models/custom_models/stateless_llama.py | 10 +++-------
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/shark_turbine/aot/compiled_module.py b/python/shark_turbine/aot/compiled_module.py
index dc8efc89f..6ba26820b 100644
--- a/python/shark_turbine/aot/compiled_module.py
+++ b/python/shark_turbine/aot/compiled_module.py
@@ -45,6 +45,8 @@
     "CompiledModule",
 ]
 
+from shark_turbine.transforms.rewriter import Pass  # for type annotations
+
 ################################################################################
 # Data structures
 ################################################################################
@@ -472,6 +474,7 @@ def __new__(
         context: Optional[Context] = None,
         module_op: Optional[Operation] = None,
         import_to: Union[ImportPhase, None, str] = "full",
+        pre_import_passes: Optional[List[Pass]] = None,
     ):
         import_to = ImportPhase.parse(import_to)
         self = super().__new__(cls)
@@ -538,6 +541,13 @@ def invoke_with_self(*args, **kwargs):
             do_export(proc_def)
 
         module_builder.finalize_construct()
+
+        # `run_import` transforms the module from torch to linalg, so passes
+        # like MMGroupQuantRewriterPass need to be run before it is called.
+        root_op = CompiledModule.get_mlir_module(self).operation
+        for pass_cls in pre_import_passes or ():
+            pass_cls(root_op).run()
+
         CompiledModule.run_import(self, import_to)
         return self
 
diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py
index edffd085a..a2c1e4f22 100644
--- a/python/turbine_models/custom_models/stateless_llama.py
+++ b/python/turbine_models/custom_models/stateless_llama.py
@@ -197,15 +197,11 @@ def forward(token0: torch.Tensor, *state0_flat):
             return token1, *state1_flat
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = StateUpdateModule(context=Context(), import_to=import_to)
-    # TODO: Integrate with external parameters to actually be able to run
-    # TODO: Make more generalizable to be able to quantize with all compile_to options
+    pre_import_passes = []
     if quantization == "int4" and not compile_to == "linalg":
         from shark_turbine.transforms.quantization import mm_group_quant
-
-        mm_group_quant.MMGroupQuantRewriterPass(
-            CompiledModule.get_mlir_module(inst).operation
-        ).run()
+        pre_import_passes.append(mm_group_quant.MMGroupQuantRewriterPass)
+    inst = StateUpdateModule(context=Context(), import_to=import_to, pre_import_passes=pre_import_passes)
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = hf_model_name.split("/")[-1].strip()
     safe_name = re.sub("-", "_", safe_name)