Skip to content

Commit

Permalink
Add support for safetensor e2e in llama (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey authored Nov 16, 2023
1 parent 4117974 commit 247dadc
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from iree.compiler.ir import Context
from iree import runtime as ireert

from turbine_models.custom_models import remap_gguf
import safetensors

BATCH_SIZE = 1
MAX_STEP_SEQ = 4095

Expand Down Expand Up @@ -37,8 +40,9 @@
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument(
"--external_weights",
action="store_true",
help="saves ir/vmfb without global weights for size and readability",
type=str,
default=None,
help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
)

prompt = """<s>[INST] <<SYS>>
Expand All @@ -65,7 +69,7 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):


def export_transformer_model(
hf_model_name, hf_auth_token, compile_to, external_weights=False, quantization=None
hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None,
):
state_schema = pytree.treespec_loads(json_schema)

Expand All @@ -88,11 +92,18 @@ def export_transformer_model(
dtype=torch.float32,
)

if external_weights:
from turbine_models.custom_models import remap_gguf
mapper = {}
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(mod.named_parameters())
for name in mod_params:
mapper["params."+name]=name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)

tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
mapper = tensor_mapper.mapping
elif external_weights=="gguf":
tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
mapper = tensor_mapper.mapping

class StateUpdateModule(CompiledModule):
if external_weights:
Expand Down Expand Up @@ -216,6 +227,7 @@ def forward(token0: torch.Tensor, *state0_flat):

def run_vmfb_comparison(args):
config = ireert.Config("local-task")
print(args.external_weight_file)

if args.external_weight_file:
from pathlib import Path
Expand All @@ -237,7 +249,7 @@ def run_vmfb_comparison(args):
param_module = ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
)
vm_modules.append(param_module)
vm_modules.insert(0, param_module)
ctx = ireert.SystemContext(
vm_modules=vm_modules,
config=config,
Expand Down Expand Up @@ -303,5 +315,6 @@ def get_token_from_logits(logits):
args.hf_auth_token,
args.compile_to,
args.external_weights,
args.external_weight_file,
args.quantization,
)

0 comments on commit 247dadc

Please sign in to comment.