diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 61de6d4bb..9e29817ce 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -324,7 +324,7 @@ def get_token_from_logits(logits): if args.run_vmfb: run_vmfb_comparison(args) else: - export_transformer_model( + mod_str, _ = export_transformer_model( args.hf_model_name, args.hf_auth_token, args.compile_to, @@ -332,3 +332,8 @@ def get_token_from_logits(logits): args.external_weight_file, args.quantization, ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) +