Skip to content

Commit

Permalink
Add bert support
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Feb 2, 2024
1 parent 29f4259 commit 8f7474d
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions python/turbine_models/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,46 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
return compiled_binary


bert_model_list = [
"bert-large-uncased",
"BAAI/bge-base-en-v1.5"
]

if __name__ == "__main__":
import sys
hf_id = sys.argv[-1]
safe_name = hf_id.replace("/", "_").replace("-", "_")
inp = torch.zeros(1, 1, dtype=torch.int64)
mask = torch.zeros(1, 1, dtype=torch.int64)
model = HFTransformerBuilder(inp, hf_id)
mapper=dict()
mod_params = dict(model.model.named_parameters())
for name in mod_params:
mapper["params." + name] = name
# safetensors.torch.save_file(mod_params, safe_name+".safetensors")
# safetensors.torch.save_file(mod_params, safe_name+".safetensors")
class GlobalModule(CompiledModule):
params = export_parameters(model.model, external=True, external_scope="",)
# params = export_parameters(model.model, external=True, external_scope="", name_mapper=mapper.get)
params = export_parameters(model.model)
compute = jittable(model.model.forward)

def run(self, x=abstractify(inp)):
return self.compute(x)

print("module defined")
inst = GlobalModule(context=Context())
print("module inst")
class BertModule(CompiledModule):
params = export_parameters(model.model)
compute = jittable(model.model.forward)

def run(self, x=abstractify(inp), mask=abstractify(mask)):
return self.compute(x, attention_mask=mask)

print("defining module")
if hf_id in bert_model_list:
inst = BertModule(context=Context())
else:
inst = GlobalModule(context=Context())
print("getting mlir module")
module = CompiledModule.get_mlir_module(inst)
# compiled = module.compile()
print("got mlir module")
print("writing mlir")
with open(safe_name+".mlir", "w+") as f:
f.write(str(module))

Expand Down

0 comments on commit 8f7474d

Please sign in to comment.