diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py
index d58c0b3d447..a6a3aaf8c2d 100644
--- a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py
+++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py
@@ -32,8 +32,10 @@ class AdditionalTransformsLLAMA(AdditionalTransformsBase):
 
-    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]])
-    CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]])
+    POSITION_IDS_MATCHING_PATTERN = dict(
+        op_type="Range", children_ops=[["Reshape"], ["Unsqueeze"]]
+    )
+    CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="ScatterND")
     SLICE_MAX_INT_NAME = "slice_max_int"
 
     def transform(self, model: ModelProto) -> ModelProto:
@@ -69,12 +71,12 @@ def transform(self, model: ModelProto) -> ModelProto:
                 f"found {len(position_ids_nodes)}"
             )
 
-        model = self.inject_positions(model, position_ids_nodes, "Unsqueeze")
+        model = self.inject_positions(model, position_ids_nodes)
 
         causal_mask_nodes = self.find_nodes_by_pattern(
             model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
         )
-        model = self.inject_causal_mask(model, causal_mask_nodes, "Add")
+        model = self.inject_causal_mask(model, causal_mask_nodes)
         model = self.adjust_causal_mask(model)
         return model
diff --git a/src/sparseml/transformers/sparsification/sparse_tokenizer.py b/src/sparseml/transformers/sparsification/sparse_tokenizer.py
index 09a8dd47b79..02903c9faf2 100644
--- a/src/sparseml/transformers/sparsification/sparse_tokenizer.py
+++ b/src/sparseml/transformers/sparsification/sparse_tokenizer.py
@@ -30,7 +30,13 @@ class SparseAutoTokenizer(AutoTokenizer):
     """
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        pad_with_eos_token: bool = False,
+        *inputs,
+        **kwargs,
+    ):
         """
         A wrapper around the AutoTokenizer.from_pretrained method that
         enables the loading of tokenizer from SparseZoo stubs
@@ -40,6 +46,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         files is passed to the AutoTokenizer.from_pretrained method
         :param pretrained_model_name_or_path: the name of or path to the model
             to load
+        :param pad_with_eos_token: if True, set the pad token to be the eos token (
+            required for many causal language models)
         :return tokenizer: the loaded tokenizer from pretrained
         """
         if str(pretrained_model_name_or_path).startswith("zoo:"):
@@ -53,4 +61,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                     tokenizer_file = file
             tokenizer_file.download()
             pretrained_model_name_or_path = os.path.dirname(tokenizer_file.path)
-        return super().from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        tokenizer = super().from_pretrained(
+            pretrained_model_name_or_path, *inputs, **kwargs
+        )
+        if pad_with_eos_token:
+            tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
diff --git a/tests/sparseml/export/transformers/test_generative_transformers.py b/tests/sparseml/export/transformers/test_generative_transformers.py
index 1a9af341cb9..a9e5e39d984 100644
--- a/tests/sparseml/export/transformers/test_generative_transformers.py
+++ b/tests/sparseml/export/transformers/test_generative_transformers.py
@@ -21,11 +21,34 @@
 import pytest
 import torch
+from deepsparse import TextGeneration
 from huggingface_hub import snapshot_download
 from sparseml import export
+from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
 from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer
 
 
+@pytest.mark.parametrize("model", ["Xenova/llama2.c-stories15M"])
+def test_kv_cache_injection(tmp_path, model):
+    export(
+        model=SparseAutoModelForCausalLM.from_pretrained(model),
+        tokenizer=SparseAutoTokenizer.from_pretrained(model, pad_with_eos_token=True),
+        target_path=tmp_path,
+        graph_optimizations="none",
+    )
+    model_path = os.path.join(tmp_path, "deployment")
+    onnx_file_path = os.path.join(model_path, "model.onnx")
+
+    onnx_model = onnx.load(onnx_file_path, load_external_data=False)
+    injector = KeyValueCacheInjector(model_path=model_path)
+    injector.export(onnx_model, onnx_file_path)
+
+    llama_pipeline = TextGeneration(model_path=model_path, engine_type="onnxruntime")
+
+    inference = llama_pipeline("Who is the president of the United States?")
+    shutil.rmtree(tmp_path)
+
+
 @pytest.mark.parametrize(
     "stub, task",
     [("roneneldan/TinyStories-1M", "text-generation")],
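
For reference, the workflow that the new test_kv_cache_injection test exercises looks roughly like the sketch below when run outside pytest. It is a minimal sketch assembled from the calls in this diff: the output directory and prompt are illustrative placeholders, and deepsparse must be installed for the TextGeneration pipeline.

# Sketch of the export -> KV cache injection -> inference flow from the test above.
# Paths, the model stub, and the prompt are placeholders, not part of the diff.
import os

import onnx
from deepsparse import TextGeneration

from sparseml import export
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer

model_stub = "Xenova/llama2.c-stories15M"  # same tiny LLaMA stub used by the test
target_path = "./llama_export"  # placeholder output directory

# Export the model to ONNX; the new pad_with_eos_token flag sets pad_token = eos_token
export(
    model=SparseAutoModelForCausalLM.from_pretrained(model_stub),
    tokenizer=SparseAutoTokenizer.from_pretrained(model_stub, pad_with_eos_token=True),
    target_path=target_path,
    graph_optimizations="none",
)

# Inject the KV cache into the exported graph, mirroring test_kv_cache_injection
model_path = os.path.join(target_path, "deployment")
onnx_file_path = os.path.join(model_path, "model.onnx")
onnx_model = onnx.load(onnx_file_path, load_external_data=False)
KeyValueCacheInjector(model_path=model_path).export(onnx_model, onnx_file_path)

# Run the injected model through the DeepSparse text-generation pipeline
pipeline = TextGeneration(model_path=model_path, engine_type="onnxruntime")
print(pipeline("Once upon a time"))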