Skip to content

Commit

Permalink
Add support for safetensor e2e in llama
Browse files (browse the repository at this point in the history)
  • Loading branch information
dan-garvey committed Nov 16, 2023
1 parent d32c8b6 commit 18411e6
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,9 @@
from iree.compiler.ir import Context
from iree import runtime as ireert

from turbine_models.custom_models import remap_gguf
import safetensors

BATCH_SIZE = 1
MAX_STEP_SEQ = 4095

Expand Down Expand Up @@ -37,8 +40,9 @@
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument(
"--external_weights",
action="store_true",
help="saves ir/vmfb without global weights for size and readability",
type=str,
default=None,
help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
)

prompt = """<s>[INST] <<SYS>>
Expand All @@ -65,7 +69,7 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):


def export_transformer_model(
hf_model_name, hf_auth_token, compile_to, external_weights=False, quantization=None
hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None,
):
state_schema = pytree.treespec_loads(json_schema)

Expand All @@ -88,11 +92,18 @@ def export_transformer_model(
dtype=torch.float32,
)

if external_weights:
from turbine_models.custom_models import remap_gguf
mapper = {}
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(mod.named_parameters())
for name in mod_params:
mapper["params."+name]=name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)

tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
mapper = tensor_mapper.mapping
elif external_weights=="gguf":
tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
mapper = tensor_mapper.mapping

class StateUpdateModule(CompiledModule):
if external_weights:
Expand Down Expand Up @@ -216,6 +227,7 @@ def forward(token0: torch.Tensor, *state0_flat):

def run_vmfb_comparison(args):
config = ireert.Config("local-task")
print(args.external_weight_file)

if args.external_weight_file:
from pathlib import Path
Expand All @@ -232,14 +244,26 @@ def run_vmfb_comparison(args):
mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
else:
sys.exit("no vmfb_path provided, required for run_vmfb")
vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)]
if args.external_weight_file:
param_module = ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
)
vm_modules.append(param_module)
# vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)]
# print(vm_modules)
# if args.external_weight_file:
# param_module = ireert.create_io_parameters_module(
# config.vm_instance, index.create_provider(scope="model")
# )
# vm_modules.append(param_module)
# print(vm_modules)
# ctx = ireert.SystemContext(
# vm_modules=vm_modules,
# config=config,
# )
ctx = ireert.SystemContext(
vm_modules=vm_modules,
vm_modules=[
ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
),
ireert.create_hal_module(config.vm_instance, config.device),
mod,
],
config=config,
)
tokenizer = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -303,5 +327,6 @@ def get_token_from_logits(logits):
args.hf_auth_token,
args.compile_to,
args.external_weights,
args.external_weight_file,
args.quantization,
)

0 comments on commit 18411e6

Please sign in to comment.