diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 337390bb5..332e8a857 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -10,6 +10,9 @@ from iree.compiler.ir import Context from iree import runtime as ireert +from turbine_models.custom_models import remap_gguf +import safetensors + BATCH_SIZE = 1 MAX_STEP_SEQ = 4095 @@ -37,8 +40,9 @@ parser.add_argument("--vmfb_path", type=str, default="") parser.add_argument( "--external_weights", - action="store_true", - help="saves ir/vmfb without global weights for size and readability", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]", ) prompt = """[INST] <> @@ -65,7 +69,7 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim): def export_transformer_model( - hf_model_name, hf_auth_token, compile_to, external_weights=False, quantization=None + hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -88,11 +92,18 @@ def export_transformer_model( dtype=torch.float32, ) - if external_weights: - from turbine_models.custom_models import remap_gguf + mapper = {} + if external_weights is not None: + if external_weights == "safetensors": + mod_params = dict(mod.named_parameters()) + for name in mod_params: + mapper["params."+name]=name + if external_weight_file: + safetensors.torch.save_file(mod_params, external_weight_file) - tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS) - mapper = tensor_mapper.mapping + elif external_weights=="gguf": + tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS) + mapper = tensor_mapper.mapping class StateUpdateModule(CompiledModule): if external_weights: @@ -216,6 +227,7 @@ def forward(token0: torch.Tensor, *state0_flat): def run_vmfb_comparison(args): config = ireert.Config("local-task") + print(args.external_weight_file) if args.external_weight_file: from pathlib import Path @@ -237,7 +249,7 @@ def run_vmfb_comparison(args): param_module = ireert.create_io_parameters_module( config.vm_instance, index.create_provider(scope="model") ) - vm_modules.append(param_module) + vm_modules.insert(0, param_module) ctx = ireert.SystemContext( vm_modules=vm_modules, config=config, @@ -303,5 +315,6 @@ def get_token_from_logits(logits): args.hf_auth_token, args.compile_to, args.external_weights, + args.external_weight_file, args.quantization, )