diff --git a/python/turbine_models/model_builder.py b/python/turbine_models/model_builder.py index 03b56dc1a..5fc396641 100644 --- a/python/turbine_models/model_builder.py +++ b/python/turbine_models/model_builder.py @@ -67,30 +67,46 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: return compiled_binary +bert_model_list = [ + "bert-large-uncased", + "BAAI/bge-base-en-v1.5" +] + if __name__ == "__main__": import sys hf_id = sys.argv[-1] safe_name = hf_id.replace("/", "_").replace("-", "_") inp = torch.zeros(1, 1, dtype=torch.int64) + mask = torch.zeros(1, 1, dtype=torch.int64) model = HFTransformerBuilder(inp, hf_id) mapper=dict() mod_params = dict(model.model.named_parameters()) for name in mod_params: mapper["params." + name] = name -# safetensors.torch.save_file(mod_params, safe_name+".safetensors") + # safetensors.torch.save_file(mod_params, safe_name+".safetensors") class GlobalModule(CompiledModule): - params = export_parameters(model.model, external=True, external_scope="",) + # params = export_parameters(model.model, external=True, external_scope="", name_mapper=mapper.get) + params = export_parameters(model.model) compute = jittable(model.model.forward) def run(self, x=abstractify(inp)): return self.compute(x) - print("module defined") - inst = GlobalModule(context=Context()) - print("module inst") + class BertModule(CompiledModule): + params = export_parameters(model.model) + compute = jittable(model.model.forward) + + def run(self, x=abstractify(inp), mask=abstractify(mask)): + return self.compute(x, attention_mask=mask) + + print("defining module") + if hf_id in bert_model_list: + inst = BertModule(context=Context()) + else: + inst = GlobalModule(context=Context()) + print("getting mlir module") module = CompiledModule.get_mlir_module(inst) -# compiled = module.compile() - print("got mlir module") + print("writing mlir") with open(safe_name+".mlir", "w+") as f: f.write(str(module))