diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml index 1a8cb28bab..e3a7518a6f 100644 --- a/.github/workflows/test_inc.yml +++ b/.github/workflows/test_inc.yml @@ -32,8 +32,9 @@ jobs: python -m pip install --upgrade pip pip install cmake pip install py-cpuinfo + pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cpu pip install .[neural-compressor,diffusers,tests] - pip install intel-extension-for-transformers + pip install intel-extension-for-transformers==1.4.1 pip install peft - name: Test with Pytest @@ -42,7 +43,7 @@ jobs: - name: Test IPEX run: | pip uninstall -y intel-extension-for-transformers - pip install torch==2.1.0 torchaudio==2.1.0 torchvision==0.16 --extra-index-url https://download.pytorch.org/whl/cpu - pip install intel-extension-for-pytorch==2.1.100 + pip install torch==2.3.0 torchaudio==2.3.0 torchvision==0.18 --extra-index-url https://download.pytorch.org/whl/cpu + pip install intel-extension-for-pytorch==2.3.0 pytest tests/neural_compressor/test_ipex.py diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml index 82c9e8c7f7..8e02bd5510 100644 --- a/.github/workflows/test_ipex.yml +++ b/.github/workflows/test_ipex.yml @@ -30,6 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu pip install .[ipex,tests] - name: Test with Pytest run: | diff --git a/.github/workflows/test_offline.yaml b/.github/workflows/test_offline.yaml new file mode 100644 index 0000000000..a54ba20766 --- /dev/null +++ b/.github/workflows/test_offline.yaml @@ -0,0 +1,40 @@ +name: Offline usage / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: [3.9] + os: [ubuntu-latest] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install .[tests,openvino] + - name: Test + run: | + HF_HOME=/tmp/ huggingface-cli download hf-internal-testing/tiny-random-gpt2 + HF_HOME=/tmp/ HF_HUB_OFFLINE=1 optimum-cli export openvino --model hf-internal-testing/tiny-random-gpt2 gpt2_openvino --task text-generation + + huggingface-cli download hf-internal-testing/tiny-random-gpt2 + HF_HUB_OFFLINE=1 optimum-cli export openvino --model hf-internal-testing/tiny-random-gpt2 gpt2_openvino --task text-generation + + pytest tests/openvino/test_modeling.py -k "test_load_from_hub" -s -vvvvv + HF_HUB_OFFLINE=1 pytest tests/openvino/test_modeling.py -k "test_load_from_hub" -s -vvvvv diff --git a/.github/workflows/test_openvino_examples.yml b/.github/workflows/test_openvino_examples.yml index 61b411c967..747afa31b5 100644 --- a/.github/workflows/test_openvino_examples.yml +++ b/.github/workflows/test_openvino_examples.yml @@ -7,11 +7,11 @@ on: push: paths: - '.github/workflows/test_openvino_examples.yml' - - 'examples/openvino/*' + - 'examples/openvino/**' pull_request: paths: - '.github/workflows/test_openvino_examples.yml' - - 'examples/openvino/*' + - 'examples/openvino/**' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -22,9 +22,9 @@ jobs: 
strategy: fail-fast: false matrix: - python-version: ["3.8", "3.10"] + python-version: ["3.8", "3.11"] - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 @@ -35,12 +35,12 @@ jobs: - name: Install dependencies run: | - pip install optimum[openvino] jstyleson nncf pytest - pip install -r examples/openvino/audio-classification/requirements.txt - pip install -r examples/openvino/image-classification/requirements.txt - pip install -r examples/openvino/question-answering/requirements.txt - pip install -r examples/openvino/text-classification/requirements.txt + pip install .[openvino] jstyleson pytest + pip install -r examples/openvino/audio-classification/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -r examples/openvino/image-classification/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -r examples/openvino/question-answering/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -r examples/openvino/text-classification/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu - name: Test examples run: | - python -m pytest examples/openvino/test_examples.py \ No newline at end of file + python -m pytest examples/openvino/test_examples.py diff --git a/.github/workflows/test_openvino_notebooks.yml b/.github/workflows/test_openvino_notebooks.yml index 7b037d0565..ed77077e87 100644 --- a/.github/workflows/test_openvino_notebooks.yml +++ b/.github/workflows/test_openvino_notebooks.yml @@ -23,9 +23,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.10"] + python-version: ["3.8", "3.11"] - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py index 1799ad6782..e2496d5baf 100644 --- a/examples/neural_compressor/language-modeling/run_clm.py +++ b/examples/neural_compressor/language-modeling/run_clm.py @@ -57,13 +57,10 @@ from transformers.utils.versions import require_version from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer, INCTrainer -from optimum.intel.utils.import_utils import ( - INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR, - is_intel_extension_for_transformers_available, -) +from optimum.intel.utils.import_utils import ITREX_IMPORT_ERROR, is_itrex_available -if is_intel_extension_for_transformers_available(): +if is_itrex_available(): from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -658,8 +655,8 @@ def compute_metrics(eval_preds): else: recipes = {} if optim_args.quantization_approach == "weight_only": - if not is_intel_extension_for_transformers_available(): - raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("WeightOnly quantization")) + if not is_itrex_available(): + raise ImportError(ITREX_IMPORT_ERROR.format("WeightOnly quantization")) if optim_args.apply_pruning or optim_args.apply_distillation: raise ValueError("Weight only quantization and pruning or distillation cannot be combined.") diff --git a/examples/openvino/audio-classification/requirements.txt b/examples/openvino/audio-classification/requirements.txt index 60c66d8091..df77b9298b 100644 --- a/examples/openvino/audio-classification/requirements.txt +++ b/examples/openvino/audio-classification/requirements.txt @@ -1,4 +1,5 @@ datasets>=1.14.0 evaluate librosa 
-torchaudio \ No newline at end of file +torchaudio +accelerate diff --git a/examples/openvino/audio-classification/run_audio_classification.py b/examples/openvino/audio-classification/run_audio_classification.py index b8df86a575..30b95c1739 100644 --- a/examples/openvino/audio-classification/run_audio_classification.py +++ b/examples/openvino/audio-classification/run_audio_classification.py @@ -35,7 +35,7 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from optimum.intel.openvino import OVConfig, OVTrainer, OVTrainingArguments +from optimum.intel import OVConfig, OVTrainer, OVTrainingArguments logger = logging.getLogger(__name__) diff --git a/examples/openvino/image-classification/requirements.txt b/examples/openvino/image-classification/requirements.txt index c52a5f399b..a55c46ca4c 100644 --- a/examples/openvino/image-classification/requirements.txt +++ b/examples/openvino/image-classification/requirements.txt @@ -2,3 +2,4 @@ datasets >= 1.8.0 torch >= 1.9.0 torchvision>=0.6.0 evaluate +accelerate diff --git a/examples/openvino/image-classification/run_image_classification.py b/examples/openvino/image-classification/run_image_classification.py index 5f98d95cb5..04c2984d8b 100644 --- a/examples/openvino/image-classification/run_image_classification.py +++ b/examples/openvino/image-classification/run_image_classification.py @@ -52,7 +52,7 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from optimum.intel.openvino import OVConfig, OVTrainer, OVTrainingArguments +from optimum.intel import OVConfig, OVTrainer, OVTrainingArguments logger = logging.getLogger(__name__) diff --git a/examples/openvino/question-answering/requirements.txt b/examples/openvino/question-answering/requirements.txt index 3bd58b158b..0bb12723c2 100644 --- a/examples/openvino/question-answering/requirements.txt +++ b/examples/openvino/question-answering/requirements.txt @@ -1,3 +1,4 @@ datasets >= 1.8.0 torch >= 1.9.0 evaluate +accelerate diff --git a/examples/openvino/question-answering/run_qa.py b/examples/openvino/question-answering/run_qa.py index a86c7fb6d7..261fa839c9 100644 --- a/examples/openvino/question-answering/run_qa.py +++ b/examples/openvino/question-answering/run_qa.py @@ -49,7 +49,7 @@ from transformers.utils.versions import require_version from utils_qa import postprocess_qa_predictions -from optimum.intel.openvino import OVConfig, OVTrainingArguments +from optimum.intel import OVConfig, OVTrainingArguments # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
diff --git a/examples/openvino/question-answering/trainer_qa.py b/examples/openvino/question-answering/trainer_qa.py index bda91f99b5..c10466060b 100644 --- a/examples/openvino/question-answering/trainer_qa.py +++ b/examples/openvino/question-answering/trainer_qa.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from transformers.trainer_utils import PredictionOutput -from optimum.intel.openvino.trainer import OVTrainer +from optimum.intel import OVTrainer class QuestionAnsweringOVTrainer(OVTrainer): diff --git a/examples/openvino/text-classification/requirements.txt b/examples/openvino/text-classification/requirements.txt index 95655f80ec..660e820c3c 100644 --- a/examples/openvino/text-classification/requirements.txt +++ b/examples/openvino/text-classification/requirements.txt @@ -4,4 +4,5 @@ scipy scikit-learn protobuf torch >= 1.3 -evaluate \ No newline at end of file +evaluate +accelerate diff --git a/examples/openvino/text-classification/run_glue.py b/examples/openvino/text-classification/run_glue.py index 002f67232c..66670de77e 100644 --- a/examples/openvino/text-classification/run_glue.py +++ b/examples/openvino/text-classification/run_glue.py @@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from optimum.intel.openvino import OVConfig, OVTrainer, OVTrainingArguments +from optimum.intel import OVConfig, OVTrainer, OVTrainingArguments # Will error if the minimal version of Transformers is not installed. Remove at your own risks. diff --git a/notebooks/openvino/optimum_openvino_inference.ipynb b/notebooks/openvino/optimum_openvino_inference.ipynb index b94238d358..dcd7dc866f 100644 --- a/notebooks/openvino/optimum_openvino_inference.ipynb +++ b/notebooks/openvino/optimum_openvino_inference.ipynb @@ -76,7 +76,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForQuestionAnswering\n", + "from optimum.intel import OVModelForQuestionAnswering\n", "\n", "# Load PyTorch model from the Hub and export to OpenVINO in the background\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad\", export=True)\n", @@ -182,7 +182,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForQuestionAnswering\n", + "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")\n", @@ -240,7 +240,7 @@ ], "source": [ "import torch\n", - "from optimum.intel.openvino import OVModelForQuestionAnswering\n", + "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp32\")\n", @@ -324,7 +324,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForQuestionAnswering\n", + "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model = OVModelForQuestionAnswering.from_pretrained(\n", @@ -529,7 +529,7 @@ ], "source": [ "from IPython.display import Audio\n", - "from optimum.intel.openvino import OVModelForAudioClassification\n", + "from optimum.intel import OVModelForAudioClassification\n", "from transformers import AutoFeatureExtractor, pipeline\n", "from datasets import load_dataset\n", "\n", @@ -638,7 +638,7 @@ } ], "source": [ - "from 
optimum.intel.openvino import OVModelForCausalLM\n", + "from optimum.intel import OVModelForCausalLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model_id = \"helenai/gpt2-ov\"\n", @@ -704,7 +704,7 @@ ], "source": [ "from IPython.display import Image\n", - "from optimum.intel.openvino import OVModelForImageClassification\n", + "from optimum.intel import OVModelForImageClassification\n", "from transformers import AutoImageProcessor, pipeline\n", "\n", "model_id = \"helenai/microsoft-swin-tiny-patch4-window7-224-ov\"\n", @@ -766,7 +766,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForMaskedLM\n", + "from optimum.intel import OVModelForMaskedLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model_id = \"helenai/bert-base-uncased-ov\"\n", @@ -835,7 +835,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForQuestionAnswering\n", + "from optimum.intel import OVModelForQuestionAnswering\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "# Load the model and tokenizer saved in Part 1 of this notebook. Or use the line below to load them from the hub\n", @@ -890,7 +890,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForSeq2SeqLM\n", + "from optimum.intel import OVModelForSeq2SeqLM\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model_id = \"helenai/t5-small-ov\"\n", @@ -998,7 +998,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForSequenceClassification\n", + "from optimum.intel import OVModelForSequenceClassification\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model_id = \"helenai/papluca-xlm-roberta-base-language-detection-ov\"\n", @@ -1047,7 +1047,7 @@ } ], "source": [ - "from optimum.intel.openvino import OVModelForTokenClassification\n", + "from optimum.intel import OVModelForTokenClassification\n", "from transformers import AutoTokenizer, pipeline\n", "\n", "model_id = \"helenai/dslim-bert-base-NER-ov-fp32\"\n", diff --git a/notebooks/openvino/question_answering_quantization.ipynb b/notebooks/openvino/question_answering_quantization.ipynb index 196e3ba6a7..2481c9b904 100644 --- a/notebooks/openvino/question_answering_quantization.ipynb +++ b/notebooks/openvino/question_answering_quantization.ipynb @@ -51,7 +51,7 @@ "import transformers\n", "from evaluate import evaluator\n", "from openvino.runtime import Core\n", - "from optimum.intel.openvino import OVModelForQuestionAnswering, OVQuantizer, OVQuantizationConfig, OVConfig\n", + "from optimum.intel import OVModelForQuestionAnswering, OVQuantizer, OVQuantizationConfig, OVConfig\n", "from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline\n", "\n", "transformers.logging.set_verbosity_error()\n", @@ -286,7 +286,7 @@ "**NOTE:** if you notice very low accuracy after post-training quantization, it is likely caused by an overflow issue which affects processors that do not contain VNNI (Vector Neural Network Instruction). NNCF has an `overflow_fix` option to address this. It will effectively use 7-bits for quantizing instead of 8-bits to prevent the overflow. 
To use this option, modify the code in the next cell to add an explicit quantization configuration, and set `overflow_fix` to `\"enable\"`:\n", "\n", "```\n", - "from optimum.intel.openvino import OVConfig, OVQuantizationConfig\n", + "from optimum.intel import OVConfig, OVQuantizationConfig\n", "\n", "ov_config = OVConfig(quantization_config=OVQuantizationConfig(overflow_fix=\"enable\")\n", "quantizer = OVQuantizer.from_pretrained(model)\n", diff --git a/notebooks/openvino/requirements.txt b/notebooks/openvino/requirements.txt index 3432e949e9..bb7a517cff 100644 --- a/notebooks/openvino/requirements.txt +++ b/notebooks/openvino/requirements.txt @@ -1,4 +1,4 @@ -optimum-intel[openvino, nncf] +optimum-intel[openvino] datasets evaluate[evaluator] ipywidgets diff --git a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb index 41969b162a..142cde4923 100644 --- a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb +++ b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb @@ -52,7 +52,8 @@ "import transformers\n", "from pathlib import Path\n", "from openvino.runtime import Core\n", - "from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig\n", + "from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig\n", + "from optimum.intel.openvino.configuration import OVQuantizationMethod\n", "\n", "transformers.logging.set_verbosity_error()\n", "datasets.logging.set_verbosity_error()" @@ -198,9 +199,14 @@ }, "outputs": [], "source": [ - "quantization_config = OVWeightQuantizationConfig(bits=8, dataset=calibration_dataset, num_samples=NUM_SAMPLES)\n", - "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True, quantization_config=quantization_config)\n", - "int8_pipe.save_pretrained(int8_model_path)" + "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)\n", + "quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)\n", + "quantizer = OVQuantizer(int8_pipe)\n", + "quantizer.quantize(\n", + " ov_config=OVConfig(quantization_config=quantization_config),\n", + " calibration_dataset=calibration_dataset,\n", + " save_directory=int8_model_path\n", + ")" ] }, { diff --git a/notebooks/openvino/stable_diffusion_optimization.ipynb b/notebooks/openvino/stable_diffusion_optimization.ipynb index 6c79bc5df0..f2297b2151 100644 --- a/notebooks/openvino/stable_diffusion_optimization.ipynb +++ b/notebooks/openvino/stable_diffusion_optimization.ipynb @@ -14,7 +14,7 @@ "metadata": {}, "outputs": [], "source": [ - "from optimum.intel.openvino import OVStableDiffusionPipeline\n", + "from optimum.intel import OVStableDiffusionPipeline\n", "from diffusers.training_utils import set_seed\n", "from IPython.display import display" ] diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index cdae847468..ffd084d4e6 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -18,6 +18,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + from ...exporters import TasksManager from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available from ..base import BaseOptimumCLICommand, CommandInfo @@ -47,7 +49,9 @@ def parse_args_openvino(parser: "ArgumentParser"): f" 
{str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder." ), ) - optional_group.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.") + optional_group.add_argument( + "--cache_dir", type=str, default=HUGGINGFACE_HUB_CACHE, help="Path indicating where to store cache." + ) optional_group.add_argument( "--framework", type=str, @@ -115,6 +119,15 @@ def parse_args_openvino(parser: "ArgumentParser"): "or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models." ), ) + optional_group.add_argument( + "--all-layers", + action="store_true", + default=None, + help=( + "Whether embeddings and last MatMul layers should be compressed to INT4. If not provided, they are " + "compressed to INT8 when weight compression is applied." + ), + ) optional_group.add_argument( "--disable-stateful", action="store_true", @@ -194,6 +207,7 @@ def run(self): and self.args.ratio is None and self.args.group_size is None and self.args.sym is None + and self.args.all_layers is None and self.args.model in _DEFAULT_4BIT_CONFIGS ): quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model] @@ -203,6 +217,7 @@ def run(self): "ratio": 1 if is_int8 else (self.args.ratio or 0.8), "sym": self.args.sym or False, "group_size": -1 if is_int8 else self.args.group_size, + "all_layers": None if is_int8 else self.args.all_layers, } if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}: @@ -222,6 +237,9 @@ def run(self): ) library_name = "transformers" + if self.args.convert_tokenizer: + logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.") + if ( library_name == "diffusers" and ov_config @@ -257,10 +275,21 @@ def run(self): ) model.save_pretrained(self.args.output) - else: - if self.args.convert_tokenizer: - logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.") + if self.args.disable_convert_tokenizer: + return + # avoid import when using other exporters (IPEX, INC) + from ...exporters.openvino.convert import export_tokenizer + + output = Path(self.args.output) + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is not None: + export_tokenizer(tokenizer, output / "tokenizer") + + tokenizer_2 = getattr(model, "tokenizer_2", None) + if tokenizer_2 is not None: + export_tokenizer(tokenizer_2, output / "tokenizer_2") + else: # TODO : add input shapes main_export( model_name_or_path=self.args.model, diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index d7b29584d6..9db6719069 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -13,16 +13,18 @@ # limitations under the License.
import logging +import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from requests.exceptions import ConnectionError as RequestsConnectionError from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase from optimum.exporters import TasksManager from optimum.exporters.onnx.base import OnnxConfig from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED -from optimum.exporters.openvino.convert import export_from_model, export_tokenizer +from optimum.exporters.openvino.convert import export_from_model from optimum.intel.utils.import_utils import is_openvino_tokenizers_available, is_transformers_version from optimum.utils.save_utils import maybe_load_preprocessors @@ -48,7 +50,7 @@ def main_export( task: str = "auto", device: str = "cpu", framework: Optional[str] = None, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, trust_remote_code: bool = False, pad_token_id: Optional[int] = None, subfolder: str = "", @@ -56,6 +58,7 @@ def main_export( force_download: bool = False, local_files_only: bool = False, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, model_kwargs: Optional[Dict[str, Any]] = None, custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, @@ -106,9 +109,11 @@ def main_export( cached versions if they exist. local_files_only (`Optional[bool]`, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`Optional[str]`, defaults to `None`): + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + when running `huggingface-cli login` (stored in `~/.huggingface`). model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export. This argument should be used along the `custom_export_configs` argument @@ -137,6 +142,15 @@ def main_export( ``` """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if compression_option is not None: logger.warning( "The `compression_option` argument is deprecated and will be removed in optimum-intel v1.17.0. " @@ -195,7 +209,7 @@ def main_export( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -205,6 +219,10 @@ def main_export( model_type = config.model_type.replace("_", "-") if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: custom_architecture = True + if custom_export_configs is None: + raise ValueError( + f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. 
Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum-intel/issues if you would like the model type {model_type} to be supported natively in the OpenVINO export." + ) elif task not in TasksManager.get_supported_tasks_for_model_type( model_type, exporter="openvino", library_name=library_name ): @@ -218,6 +236,7 @@ def main_export( raise ValueError( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." ) + if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: loading_kwargs["attn_implementation"] = "eager" # there are some difference between remote and in library representation of past key values for some models, @@ -267,7 +286,7 @@ class StoreAttr(object): subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -341,6 +360,9 @@ class StoreAttr(object): **kwargs_shapes, ) + # hide openvino import when using other exporters + from optimum.exporters.openvino.convert import export_tokenizer + if convert_tokenizer and is_openvino_tokenizers_available(): if library_name != "diffusers": tokenizer = next( @@ -359,11 +381,11 @@ class StoreAttr(object): else: tokenizer = getattr(model, "tokenizer", None) if tokenizer is not None: - export_tokenizer(tokenizer, output) + export_tokenizer(tokenizer, output / "tokenizer") tokenizer_2 = getattr(model, "tokenizer_2", None) if tokenizer_2 is not None: - export_tokenizer(tokenizer_2, output, suffix="_2") + export_tokenizer(tokenizer_2, output / "tokenizer_2") elif convert_tokenizer and not is_openvino_tokenizers_available(): logger.warning("Tokenizer won't be converted.") diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 55e3318017..3b214f77e4 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -20,9 +20,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import onnx from transformers.utils import is_tf_available, is_torch_available -from openvino.runtime import PartialShape, save_model +from openvino.runtime import Model, PartialShape, save_model from openvino.runtime.exceptions import OVTypeError from openvino.runtime.utils.types import get_element_type from openvino.tools.ovc import convert_model @@ -32,6 +33,14 @@ from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx from optimum.exporters.utils import _get_submodels_and_export_configs +from optimum.intel.utils.import_utils import ( + _nncf_version, + _optimum_intel_version, + _optimum_version, + _timm_version, + _torch_version, + _transformers_version, +) from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available from optimum.utils.save_utils import maybe_save_preprocessors @@ -81,6 +90,8 @@ def _save_model(model, 
path: str, ov_config: Optional["OVConfig"] = None): compress_to_fp16 = ov_config.dtype == "fp16" + library_name = TasksManager.infer_library_from_model(Path(path).parent) + model = _add_version_info_to_model(model, library_name) save_model(model, path, compress_to_fp16) @@ -347,6 +358,7 @@ def ts_patched_forward(*args, **kwargs): with patcher: check_dummy_inputs_are_allowed(model, dummy_inputs) + sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) inputs = config.ordered_inputs(model) input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) @@ -376,7 +388,6 @@ def ts_patched_forward(*args, **kwargs): ov_config=ov_config, ) - sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} if not ordered_dummy_inputs: ordered_dummy_inputs = dummy_inputs @@ -392,7 +403,7 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_tensor().set_names({input_name}) inp_data = flatten_inputs[idx] static_shape = PartialShape(inp_data.shape) - dims = inputs[input_name] + dims = inputs.get(input_name, []) for dim in dims: static_shape[dim] = -1 inp_tensor.get_node().set_partial_shape(static_shape) @@ -536,7 +547,7 @@ def export_from_model( # TODO: support onnx_config.py in the model repo if custom_architecture and custom_export_configs is None: raise ValueError( - f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." + f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum-intel/issues if you would like the model type {model_type} to be supported natively in the OpenVINO export." 
) if task.startswith("text-generation") and model.config.is_encoder_decoder: @@ -603,7 +614,12 @@ def export_from_model( model.config.save_pretrained(output) generation_config = getattr(model, "generation_config", None) if generation_config is not None: - generation_config.save_pretrained(output) + try: + generation_config.save_pretrained(output) + except Exception as exception: + logger.warning( + f"The generation config will not be saved, saving failed with following error:\n{exception}" + ) model_name_or_path = model.config._name_or_path maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code) @@ -656,20 +672,21 @@ def export_tokenizer( output: Union[str, Path], suffix: Optional[str] = "", ): - from optimum.intel.openvino import OV_DETOKENIZER_NAME, OV_TOKENIZER_NAME # avoid circular imports + # avoid circular imports + from optimum.intel.openvino import OV_DETOKENIZER_NAME, OV_TOKENIZER_NAME + from optimum.intel.openvino.utils import maybe_convert_tokenizer_to_fast try: from openvino_tokenizers import convert_tokenizer except ModuleNotFoundError: - # avoid this message before tokenizers are part of the openvino dependencies - # logger.info( - # "Run `pip install openvino-tokenizers[transformers]` to get OpenVINO tokenizer/detokenizer models." - # ) return if not isinstance(output, Path): output = Path(output) + if output.exists(): + tokenizer = maybe_convert_tokenizer_to_fast(tokenizer, output) + try: converted = convert_tokenizer(tokenizer, with_detokenizer=True) except NotImplementedError: @@ -689,3 +706,34 @@ def export_tokenizer( for model, file_name in zip(converted, (OV_TOKENIZER_NAME, OV_DETOKENIZER_NAME)): save_model(model, output / file_name.format(suffix)) + + +def _add_version_info_to_model(model: Model, library_name: Optional[str] = None): + """ + Add dependency versions to OpenVINO model + """ + try: + model.set_rt_info(_transformers_version, ["optimum", "transformers_version"]) + model.set_rt_info(_torch_version, ["optimum", "pytorch_version"]) + model.set_rt_info(_optimum_intel_version, ["optimum", "optimum_intel_version"]) + model.set_rt_info(_optimum_version, ["optimum", "optimum_version"]) + + if any("token_embeddings" in output.get_names() for output in model.outputs): + import sentence_transformers + + model.set_rt_info(sentence_transformers.__version__, ["optimum", "sentence_transformers_version"]) + if library_name == "diffusers": + model.set_rt_info(_optimum_version, ["optimum", "diffusers_version"]) + elif library_name == "timm": + model.set_rt_info(_timm_version, ["optimum", "timm_version"]) + rt_info = model.get_rt_info() + if "nncf" in rt_info: + model.set_rt_info(_nncf_version, ["optimum", "nncf_version"]) + input_model = rt_info["conversion_parameters"].get("input_model", None) + if input_model is not None and "onnx" in input_model.value: + model.set_rt_info(onnx.__version__, ["optimum", "onnx_version"]) + + except Exception: + pass + + return model diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 90297c8fb3..d69adc9da3 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -19,24 +19,43 @@ from transformers.utils import is_tf_available from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig -from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig +from optimum.exporters.onnx.model_configs import ( + CodeGenOnnxConfig, + FalconOnnxConfig, + 
GemmaOnnxConfig, + LlamaOnnxConfig, + MPTOnnxConfig, + PhiOnnxConfig, + UNetOnnxConfig, + VaeDecoderOnnxConfig, + VaeEncoderOnnxConfig, +) from optimum.exporters.tasks import TasksManager from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.input_generators import ( DummyInputGenerator, DummyPastKeyValuesGenerator, DummyTextInputGenerator, + FalconDummyPastKeyValuesGenerator, MistralDummyPastKeyValuesGenerator, ) from optimum.utils.normalized_config import NormalizedTextConfig from .model_patcher import ( + AquilaModelPatcher, BaichuanModelPatcher, ChatGLMModelPatcher, + CodeGenModelPatcher, + DBRXModelPatcher, GemmaModelPatcher, + InternLM2Patcher, + InternLMModelPatcher, LlamaModelPatcher, MixtralModelPatcher, + MPTModelPatcher, + Phi3ModelPatcher, QwenModelPatcher, + XverseModelPatcher, ) @@ -96,6 +115,15 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +@register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers") +class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers") class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 @@ -429,6 +457,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLM2Patcher(self, model, model_kwargs=model_kwargs) + @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers") class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): @@ -437,3 +470,318 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers") +class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers" +) +class MPTOpenVINOConfig(MPTOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MPTModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "phi3", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class Phi3OpenVINOConfig(PhiOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + MistralDummyPastKeyValuesGenerator, + ) + 
TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs) + + +class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + **kwargs, + ) + if normalized_config.new_decoder_architecture: + self.num_kv_heads = normalized_config.num_attention_heads + else: + self.num_kv_heads = normalized_config.num_kv_heads if not normalized_config.multi_query else 1 + + self.head_dim = self.hidden_size // self.num_attention_heads + + +@register_in_tasks_manager( + "falcon", + *[ + "feature-extraction", + "feature-extraction-with-past", + "question-answering", + "text-generation", + "text-generation-with-past", + "token-classification", + ], + library_name="transformers", +) +class FalconOpenVINOConfig(FalconOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + OVFalconDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator + + +@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers") +class UNetOpenVINOConfig(UNetOnnxConfig): + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "sample": {0: "batch_size", 2: "height", 3: "width"}, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + } + + # TODO : add text_image, image and image_embeds + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + common_inputs["text_embeds"] = {0: "batch_size"} + common_inputs["time_ids"] = {0: "batch_size"} + + if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None: + common_inputs["timestep_cond"] = {0: "batch_size"} + return common_inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "out_sample": {0: "batch_size", 2: "height", 3: "width"}, + } + + +@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers") +class VaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig): + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return { + "sample": {0: "batch_size", 2: "height", 3: "width"}, + } + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, + } + + +@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers") +class VaeDecoderOpenVINOConfig(VaeDecoderOnnxConfig): + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return { + "latent_sample": {0: "batch_size", 2: 
"height_latent", 3: "width_latent"}, + } + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "sample": {0: "batch_size", 2: "height", 3: "width"}, + } + + +@register_in_tasks_manager( + "persimmon", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers") +class BioGPTOpenVINOConfig(TextDecoderOnnxConfig): + # BioGPT does not require position_ids input. + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "gpt-neox-japanese", *["text-generation", "text-generation-with-past"], library_name="transformers" +) +class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig): + # GPTNeoxJapanese does not require position_ids input. + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "cohere", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class CohereOpenVINOConfig(LlamaOpenVINOConfig): + pass + + +@register_in_tasks_manager("xglm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_attention_heads="attention_heads", hidden_size="d_model" + ) + + +class AquilaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task, + normalized_config, + batch_size, + sequence_length, + random_batch_size_range, + random_sequence_length_range, + **kwargs, + ) + self.num_key_value_heads = getattr( + normalized_config, "num_key_value_heads", normalized_config.num_attention_heads + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers") +class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, AquilaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = AquilaDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: 
Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return AquilaModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers") +class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return XverseModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLMModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "codegen", + *["feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past"], + library_name="transformers", +) +class CodeGenOpenVINOConfig(CodeGenOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "dbrx", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_attention_heads="n_heads", + hidden_size="d_model", + num_layers="n_layers", + num_key_value_heads="attn_config.kv_n_heads", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return DBRXModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 3649c163c6..0265b3a5fc 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import logging as log +import math import types from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F +from transformers.cache_utils import Cache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import is_tf_available @@ -41,6 +44,9 @@ from transformers.modeling_tf_utils import TFPreTrainedModel +BETTERTRANSFORMER_IGNORE = ("codegen",) + + def patch_model_with_bettertransformer(model): COLOR_RED = "\033[1;31m" COLOR_RESET = "\033[0m" @@ -79,6 +85,10 @@ def patch_model_with_bettertransformer(model): # model already has required SDPA implementation if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": return model + + if model.config.model_type in BETTERTRANSFORMER_IGNORE: + return model + try: model = model.to_bettertransformer() except Exception as e: @@ -291,23 +301,39 @@ def __exit__(self, exc_type, exc_value, traceback): # adopted from # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058 -def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs): +def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): from transformers.modeling_attn_mask_utils import AttentionMaskConverter - # for compatibility with https://github.com/huggingface/transformers/pull/30047 - current_length = kwargs.get("current_length", cache_position[-1]) + if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None: + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. 
+ if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling # using minimum from dtype with larger bandwidth (float32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] + # difference with original modeling if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + if past_seen_tokens is not None: + current_length = past_seen_tokens + sequence_length + 1 + # TODO : remove after support of transformers >= v4.40.0 + else: + current_length = cache_position[-1] + 1 + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + # difference with original modeling causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -344,6 +370,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po return causal_mask +# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036 +def _llama_gemma_update_causal_mask_latest( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values, + output_attentions, +): + from transformers.cache_utils import StaticCache + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling + # using minimum from dtype with larger bandwidth (float32) may lead to overflow + # during execution on platforms with default lower precision (bfloat16, float16) + min_dtype = torch.finfo(torch.float16).min + + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + # difference with original modeling + causal_mask = ( + torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + ) + + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0 +if is_transformers_version(">", "4.40.2"): + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest +else: + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy + + class GemmaModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() @@ -600,6 +724,123 @@ def __exit__(self, exc_type, exc_value, traceback): self._model.config.fp16 = self.original_fp16 +def _baichuan13b_atten_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + if attention_mask is not None: + attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :] + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + if not output_attentions: + past_key_value = (key_states, value_states) if use_cache else None + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +# Adapted from https://huggingface.co/baichuan-inc/Baichuan-7B/blob/262c8cb58b6d3615c208d9230baa869fddee2adb/modeling_baichuan.py#L181 +def _baichuan7b_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + 
def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + if not output_attentions: + attn_weights = None + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim) + ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + class BaichuanModelPatcher(DecoderModelPatcher): def __init__( self, @@ -611,3 +852,922 @@ def __init__( # model has first inference buffers initialization if hasattr(self._model.lm_head, "first_flag"): self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) + + def __enter__(self): + super().__enter__() + # override signature to have position_ids + if "position_ids" not in inspect.signature(self._model.forward).parameters: + self._model._orig_forward = self._model.forward + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + ): + return self._orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=past_key_values is not None, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=self.config.return_dict, + ) + + self._model.forward = types.MethodType(forward, self._model) + for layer in self._model.model.layers: + layer.self_attn._orig_forward = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn) + else: + for layer in self._model.model.layers: + layer.self_attn._orig_forward = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if hasattr(self._model, "_orig_forward"): + self._model.forward = self._model._orig_forward + + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _mpt_sdpa_attention_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, +): + batch_size, seq_length = hidden_states.shape[:2] + + mixed_qkv = self.Wqkv(hidden_states) + query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2) + query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if len(past_key_value) != 0: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) + else: + past_key_value = (key_states, value_states) + + key_length = key_states.shape[-2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + attention_mask_sdpa = torch.ones( + (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]), + dtype=query_states.dtype, + ) + if position_bias is not None: + position_bias_query_index = max(0, position_bias.size(1) - query_length) + position_bias_key_index = max(0, position_bias.size(2) - key_length) + + position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:] + attention_mask_sdpa += position_bias + attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min) + context_states = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask_sdpa, + dropout_p=self.attn_dropout_p, + scale=self.softmax_scale, + ) + + context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output = self.out_proj(context_states) + + return attn_output, None, past_key_value + + +def _mpt_block_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, +): + # hidden_states: [batch_size, seq_length, hidden_size] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.norm_1(hidden_states) + + residual = hidden_states + + if not output_attentions: + # Self attention. 
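+        # self.attn.forward is patched to the SDPA implementation here; the eager path in the else branch is kept only when attention weights are requested.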
+ attn_outputs, attn_weights, past_key_value = self.attn( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + else: + attn_outputs, attn_weights, past_key_value = self.attn._orig_forward( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + + hidden_states = self.resid_attn_dropout(attn_outputs) + residual + + layernorm_output = self.norm_2(hidden_states) + + # Get residual + residual = hidden_states + + # MLP. + output = self.ffn(layernorm_output, residual) + outputs = (output,) + + if use_cache: + outputs += (past_key_value,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class MPTModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + if is_torch_version(">=", "2.1.0"): + for block in self._model.transformer.blocks: + block._orig_forward = block.forward + block.forward = types.MethodType(_mpt_block_forward, block) + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.transformer.blocks: + if hasattr(block, "_orig_forward"): + block.forward = block._orig_forward + if hasattr(block.attn, "_orig_forward"): + block.attn.forward = block.attn._orig_forward + + +def _internlm2_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + from einops import rearrange + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + if not output_attentions: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, attn_weights, past_key_value + + +class InternLM2Patcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + if is_torch_version(">=", "2.1.0"): + for block in self._model.model.layers: + block.attention._orig_forward = block.attention.forward + block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.model.layers: + if hasattr(block.attention, "_orig_forward"): + block.attention.forward = block.attention._orig_forward + + +# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426 +def _phi3_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + return self._orig_forward( 
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    # TODO: remove the llama import fallback once a transformers release with phi3 support is available
+    try:
+        from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
+    except ImportError:
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask.
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class Phi3ModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113 + # init inv_freq for torchscript tracing + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + if layer.self_attn.rotary_emb.inv_freq is None: + rotary_emb = layer.self_attn.rotary_emb + layer.self_attn.rotary_emb.inv_freq = 1.0 / ( + rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _aquila_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if hasattr(self, "num_key_value_groups"): + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class AquilaModelPatcher(DecoderModelPatcher): + def 
__enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _xverse_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +def _internlm_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of 
the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class XverseModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +class InternLMModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +class CodeGenModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + # whole codegen bettertransformer patch include attn.forward and does not cover codegen2. + # For avoiding breaking model on tracing stage, we reduce area of bettertransformer patch only for _attn. 
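+        # codegen_wrapped_scaled_dot_product routes _attn through torch.nn.functional.scaled_dot_product_attention while leaving the rest of the attention forward untouched.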
+ from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product + + for layer in self._model.transformer.h: + if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions: + orig_self_attn_fwd = layer.attn._attn + layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn) + layer.attn._orig_attn = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.transformer.h: + if hasattr(layer.attn, "_orig_attn"): + layer.attn._attn = layer.attn._orig_attn + + +# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763 +def _dbrx_experts_forward( + self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor +): + bsz, q_len, hidden_size = x.shape + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) + + expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + # Chunk experts at once to avoid storing full parameter multiple times in autograd + w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked] + v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked] + w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked] + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + + # Difference with original: removal + # if token_idx.shape[0] == 0: + # continue + # loop interruption depends on input data and may affect torchscript tracing + + token_list = token_idx + topk_list = topk_idx + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + expert_out = ( + self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) + * top_weights[token_list, topk_list, None] + ) + + out.index_add_(0, token_idx, expert_out) + + out = out.reshape(bsz, q_len, hidden_size) + return out + + +# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228 +def _dbrx_update_causal_mask_legacy( + self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor +) -> Optional[torch.Tensor]: + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling + # using minimum from dtype with larger bandwith (floa32) may lead to overflow + # during execution on platforms with default lower precision (bfloat16, float16) + min_dtype = torch.finfo(torch.float16).min + sequence_length = input_tensor.shape[1] + if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] if 
isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + ) + # difference with original modeling + # removed target_length = int(target_length). + # Casting to int leads to constant folding during tracing that makes impossible to use model for sequence of different length + causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(input_tensor, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + if not is_tracing and torch.any(attention_mask != 1): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# adopted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204 +def _dbrx_update_causal_mask_latest( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, +): + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling + # using minimum from dtype with larger bandwith (floa32) may lead to overflow + # during execution on platforms with default lower precision (bfloat16, float16) + min_dtype = torch.finfo(torch.float16).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + # difference with original modeling + causal_mask = ( + torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +if is_transformers_version(">", "4.40.2"): + _dbrx_update_causal_mask = _dbrx_update_causal_mask_latest +else: + _dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy + + +class DBRXModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + # dbrx has some accuracy issues with bf16 with transformers >= 4.40 + # fill causal mask in slightly different way for avoid overflow on some platforms + self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask + self._model.transformer._update_causal_mask = types.MethodType( + _dbrx_update_causal_mask, self._model.transformer + ) + + for block in self._model.transformer.blocks: + rotary_emb = block.norm_attn_norm.attn.rotary_emb + # initialize inv_freq for torchscript tracing + if rotary_emb.inv_freq is None: + inv_freq = 1.0 / ( + rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) + ) + rotary_emb.inv_freq = inv_freq + # remove continue-operator from iteration loop over experts + block.ffn.experts._orig_forward = block.ffn.experts.forward + block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask + for block in self._model.transformer.blocks: + block.ffn.experts.forward = block.ffn.experts._orig_forward diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 3d9c657626..054ef44bfe 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -15,12 +15,14 @@ import inspect import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union import torch from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PretrainedConfig, PreTrainedModel from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast @@ -179,6 +181,8 @@ def _reorder_cache( """ if self.config.model_type == "bloom": return self._reorder_cache_bloom(past_key_values, beam_idx) + elif self.config.model_type == "gpt_bigcode": + return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx) # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache return tuple( @@ -186,6 +190,13 @@ def _reorder_cache( for layer_past in past_key_values ) + # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache + @staticmethod + def _reorder_cache_gpt_bigcode( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) + # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache def _reorder_cache_bloom( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor @@ -353,15 +364,25 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str, 
None]] = None, - revision: Optional[Union[str, None]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: Optional[str] = WEIGHTS_NAME, local_files_only: bool = False, use_cache: bool = True, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if not getattr(config, "torchscript", False): raise ValueError("`torchscript` should be set to True to load TorchScript model") @@ -375,7 +396,7 @@ def _from_pretrained( model_cache_path = hf_hub_download( repo_id=model_id, filename=file_name, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -398,22 +419,32 @@ def _from_transformers( model_id: str, config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, use_cache: bool = True, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if is_torch_version("<", "2.1.0"): raise ImportError("`torch>=2.0.0` is needed to trace your model") task = cls.export_feature model_kwargs = { "revision": revision, - "use_auth_token": use_auth_token, + "token": token, "cache_dir": cache_dir, "subfolder": subfolder, "local_files_only": local_files_only, @@ -435,7 +466,7 @@ def _from_transformers( model_id=save_dir_path, config=config, use_cache=use_cache, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py index ccf2da9d80..a628ebe12e 100644 --- a/optimum/intel/ipex/inference.py +++ b/optimum/intel/ipex/inference.py @@ -97,6 +97,10 @@ def __init__( jit (`boolean = False`, *optional*): Enable jit to accelerate inference speed """ + logger.warning( + "`inference_mode` is deprecated and will be removed in v1.18.0. Use `pipeline` to load and export your model to TorchScript instead." 
+ ) + if not is_ipex_available(): raise ImportError(IPEX_NOT_AVAILABLE_ERROR_MSG) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 8a7a4f2028..e929a4ddb8 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -15,6 +15,7 @@ import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union @@ -22,6 +23,7 @@ import intel_extension_for_pytorch as ipex import torch from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp from intel_extension_for_pytorch.transformers.optimize import get_dummy_input from transformers import ( @@ -37,6 +39,7 @@ GenerationConfig, GenerationMixin, PretrainedConfig, + is_torch_xpu_available, ) from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput @@ -50,7 +53,7 @@ from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model from ..generation.modeling import prepare_jit_inputs from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask +from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device logger = logging.getLogger(__name__) @@ -126,10 +129,14 @@ def __init__( **kwargs, ): OptimizedModel.__init__(self, model=model, config=config) - # To do: add XPU support - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 + if is_torch_xpu_available(check_device=True): + self._device = torch.device("xpu:0") + elif torch.cuda.is_available(): + self._device = torch.device("cuda:0") + else: + self._device = torch.device("cpu") self.model.to(self._device) + self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model_save_dir = model_save_dir self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) @@ -151,27 +158,41 @@ def _from_transformers( config: PretrainedConfig, use_cache: bool = True, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, trust_remote_code: bool = False, + _commit_hash: str = None, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." 
+ ) + token = use_auth_token + if is_torch_version("<", "2.1.0"): raise ImportError("`torch>=2.0.0` is needed to trace your model") task = cls.export_feature model_kwargs = { "revision": revision, - "use_auth_token": use_auth_token, + "token": token, "cache_dir": cache_dir, "subfolder": subfolder, "local_files_only": local_files_only, "force_download": force_download, "torch_dtype": torch_dtype, "trust_remote_code": trust_remote_code, + "_commit_hash": _commit_hash, } model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) @@ -187,15 +208,27 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str, None]] = None, - revision: Optional[Union[str, None]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: Optional[str] = WEIGHTS_NAME, local_files_only: bool = False, subfolder: str = "", **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." + ) + token = use_auth_token + if not getattr(config, "torchscript", False): raise ValueError( "`config.torchscript` should be set to `True`, if your model is not a TorchScript model and needs to be traced please set `export=True` when loading it with `.from_pretrained()`" @@ -210,7 +243,7 @@ def _from_pretrained( model_cache_path = hf_hub_download( repo_id=model_id, filename=file_name, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -293,6 +326,8 @@ def _init_warmup(self): if not self._is_ipex_exported: use_cache = "past_key_values" in self.input_names dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) + if self._device.type != "cpu": + dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) for _ in range(2): self(**dummy_inputs) diff --git a/optimum/intel/neural_compressor/__init__.py b/optimum/intel/neural_compressor/__init__.py index 2daecfbc93..a7170120b7 100644 --- a/optimum/intel/neural_compressor/__init__.py +++ b/optimum/intel/neural_compressor/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ..utils.import_utils import is_diffusers_available, is_intel_extension_for_transformers_available +from ..utils.import_utils import is_diffusers_available from .configuration import INCConfig from .modeling_base import ( INCModel, diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py index c46e3f41c5..bb3d2fe8c8 100644 --- a/optimum/intel/neural_compressor/modeling_base.py +++ b/optimum/intel/neural_compressor/modeling_base.py @@ -14,12 +14,15 @@ import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Union import torch from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import EntryNotFoundError from neural_compressor.utils.pytorch import load from transformers import ( AutoConfig, @@ -38,18 +41,15 @@ ) from transformers.modeling_utils import no_init_weights from transformers.models.auto.auto_factory import _get_model_class +from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from transformers.utils.generic import ContextManagers from optimum.intel.generation import BaseModelForCausalLM from ...modeling_base import OptimizedModel -from ..utils.import_utils import ( - _torch_version, - is_intel_extension_for_transformers_available, - is_torch_version, -) +from ..utils.import_utils import _torch_version, is_itrex_available, is_torch_version from .configuration import INCConfig -from .utils import WEIGHTS_NAME +from .utils import QUANTIZATION_CONFIG_NAME logger = logging.getLogger(__name__) @@ -101,66 +101,117 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str, None]] = None, - revision: Optional[Union[str, None]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: str = WEIGHTS_NAME, local_files_only: bool = False, subfolder: str = "", trust_remote_code: bool = False, **kwargs, ): - model_name_or_path = kwargs.pop("model_name_or_path", None) - if model_name_or_path is not None: - logger.warning("`model_name_or_path` is deprecated please use `model_id`") - model_id = model_id or model_name_or_path - - model_path = Path(model_id) - - if model_path.is_dir(): - model_cache_path = model_path / file_name - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token - model_save_dir = Path(model_cache_path).parent + model_path = Path(model_id) + is_local = model_path.is_dir() + model_cache_path = None inc_config = None msg = None - if is_intel_extension_for_transformers_available(): + if is_local: + if (model_path / subfolder / SAFE_WEIGHTS_NAME).is_file(): + file_name = SAFE_WEIGHTS_NAME + elif not (model_path / subfolder / file_name).is_file(): + raise EnvironmentError( + f"Error no file named {SAFE_WEIGHTS_NAME} or {file_name} found in directory {model_path / subfolder}" + ) + model_cache_path = model_path / subfolder / file_name + else: + # Try download safetensors if exist try: - quantization_config = PretrainedConfig.from_pretrained(model_save_dir / "quantize_config.json") + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=SAFE_WEIGHTS_NAME, + subfolder=subfolder, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + pass + + if model_cache_path is None: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=file_name, + subfolder=subfolder, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + + model_save_dir = Path(model_cache_path).parent + + if is_itrex_available(): + quantization_config_path = None + if is_local: + quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME + else: + try: + quantization_config_path = hf_hub_download( + repo_id=model_id, + filename=QUANTIZATION_CONFIG_NAME, + subfolder=subfolder, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + pass + + if quantization_config_path and Path(quantization_config_path).is_file(): + quantization_config = PretrainedConfig.from_pretrained(quantization_config_path) algorithm = getattr(quantization_config, "quant_method", None) - if algorithm in {"rtn", "gptq", "awq", "autoaround"}: + if algorithm in {"rtn", "gptq", "awq", "autoround"}: from intel_extension_for_transformers.transformers.modeling.modeling_auto import ( _BaseQBitsAutoModelClass, ) _BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class - return _BaseQBitsAutoModelClass.from_pretrained( + model = _BaseQBitsAutoModelClass.from_pretrained( pretrained_model_name_or_path=model_id, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, local_files_only=local_files_only, subfolder=subfolder, trust_remote_code=trust_remote_code, + use_neural_speed=False, **kwargs, ) - except EnvironmentError: - msg = "The model is not quantized with weight-only quantization." + + return cls( + model, config=config, model_save_dir=model_save_dir, q_config=quantization_config, **kwargs + ) + try: - inc_config = INCConfig.from_pretrained(model_id) + inc_config = INCConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision) if not is_torch_version("==", inc_config.torch_version): msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found." 
logger.warning(f"{msg}") @@ -201,15 +252,19 @@ def _from_pretrained( ) def _save_pretrained(self, save_directory: Union[str, Path]): - output_path = os.path.join(save_directory, WEIGHTS_NAME) - if isinstance(self.model, torch.nn.Module): - state_dict = self.model.state_dict() - if self._q_config: - state_dict["best_configure"] = self._q_config - torch.save(state_dict, output_path) + # For ITREX model + if isinstance(self._q_config, PretrainedConfig): + self._q_config.to_json_file(os.path.join(save_directory, QUANTIZATION_CONFIG_NAME)) + self.model.save_pretrained(save_directory) + # For INC model the state dictionary needs to be modified to include the quantization parameters + else: + state_dict = self.model.state_dict() + if isinstance(self._q_config, dict): + state_dict["best_configure"] = self._q_config + torch.save(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) else: - torch.jit.save(self.model, output_path) + torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME)) if self.inc_config: self.inc_config.save_pretrained(save_directory) diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index 09f651df05..57bc3ae7a1 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -16,14 +16,14 @@ import inspect import logging import types +import warnings from enum import Enum from itertools import chain from pathlib import Path -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import torch from datasets import Dataset, load_dataset -from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _propagate_qconfig from neural_compressor.config import PostTrainingQuantConfig from neural_compressor.experimental.export import torch_to_int8_onnx from neural_compressor.model.onnx_model import ONNXModel @@ -47,14 +47,14 @@ from ..utils.constant import _TASK_ALIASES, MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, WEIGHTS_NAME from ..utils.import_utils import ( - INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR, - _intel_extension_for_transformers_version, + ITREX_IMPORT_ERROR, _ipex_version, + _itrex_version, _neural_compressor_version, _torch_version, - is_intel_extension_for_transformers_available, - is_intel_extension_for_transformers_version, is_ipex_version, + is_itrex_available, + is_itrex_version, is_neural_compressor_version, is_torch_version, ) @@ -69,17 +69,25 @@ INCModelForTokenClassification, INCModelForVision2Seq, ) -from .utils import INCDataLoader, _cfgs_to_fx_cfgs +from .utils import ( + IPEX_MINIMUM_VERSION, + ITREX_MINIMUM_TORCH_VERSION, + ITREX_MINIMUM_VERSION, + NEURAL_COMPRESSOR_MINIMUM_VERSION, + NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION, + INCDataLoader, +) -INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION = "1.4.0" +_ITREX_EXCLUDED_VERSION = "1.4.2" -if is_intel_extension_for_transformers_available(): - if is_intel_extension_for_transformers_version("!=", INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION): +if is_itrex_available(): + if is_itrex_version("<", ITREX_MINIMUM_VERSION): raise ImportError( - f"Found an incompatible version of `intel-extension-for-transformers`. Found version {_intel_extension_for_transformers_version}, " - f"but only version {INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION} is supported." + f"Found an incompatible version of `intel-extension-for-transformers`. 
Found version {_itrex_version}, " + f"but only version {ITREX_MINIMUM_VERSION} or higher is supported." ) + from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_to_quantized_model from intel_extension_for_transformers.transformers.modeling.modeling_auto import save_low_bit from intel_extension_for_transformers.transformers.utils.config import ( @@ -92,10 +100,6 @@ logger = logging.getLogger(__name__) -NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0" -NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0" -IPEX_MINIMUM_VERSION = "2.1.0" -ITREX_MINIMUM_TORCH_VERSION = "2.2.0" if is_neural_compressor_version("<", NEURAL_COMPRESSOR_MINIMUM_VERSION): raise ImportError( @@ -225,14 +229,20 @@ def quantize( # ITREX Weight Only Quantization if not isinstance(quantization_config, PostTrainingQuantConfig): + if is_itrex_version("==", _ITREX_EXCLUDED_VERSION): + raise ImportError( + f"Found an incompatible version of `intel-extension-for-transformers`. Found version {_itrex_version}, " + f"but {_ITREX_EXCLUDED_VERSION} is not compatible." + ) + # check neural-compressor version if is_neural_compressor_version("<", NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION): raise ImportError( f"Found an incompatible version of neural-compressor. Found version {_neural_compressor_version}, " f"but only version {NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION} or higher supports weight-only quantization." ) - if not is_intel_extension_for_transformers_available(): - raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("Weight only quantization")) + if not is_itrex_available(): + raise ImportError(ITREX_IMPORT_ERROR.format("Weight only quantization")) if is_torch_version("<", ITREX_MINIMUM_TORCH_VERSION): raise ImportError( @@ -446,7 +456,8 @@ def get_calibration_dataset( dataset_split: str = "train", preprocess_function: Optional[Callable] = None, preprocess_batch: bool = True, - use_auth_token: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, ) -> Dataset: """ Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step. @@ -465,16 +476,28 @@ def get_calibration_dataset( Processing function to apply to each example after loading dataset. preprocess_batch (`bool`, defaults to `True`): Whether the `preprocess_function` should be batched. - use_auth_token (`bool`, defaults to `False`): - Whether to use the token generated when running `transformers-cli login`. + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). Returns: The calibration `datasets.Dataset` to use for the post-training static quantization calibration step. """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + calibration_dataset = load_dataset( dataset_name, name=dataset_config_name, split=dataset_split, - use_auth_token=use_auth_token, + token=token, ) if num_samples is not None: @@ -514,70 +537,3 @@ def _get_calibration_dataloader( def _remove_unused_columns(self, dataset: Dataset): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) return dataset.remove_columns(ignored_columns) - - -# Adapted from https://github.com/intel/neural-compressor/blob/master/neural_compressor/utils/pytorch.py#L96 -def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> torch.nn.Module: - """ - Apply Intel Neural Compressor quantization steps on the given model. - - Arguments: - q_config (`Dict`): - Dictionary containing all quantization information such as approach, dtype, scheme and granularity. - model (`torch.nn.Module`): - Model to quantize. - Returns: - q_model (`torch.nn.Module`): - Quantized model. - """ - from torch.quantization import add_observer_, convert - from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx - - approach = q_config.get("approach") - framework = q_config.get("framework") - - if approach not in SUPPORTED_QUANT_MODE: - raise ValueError( - "Unknown quantization approach. Supported approach are " + ", ".join(SUPPORTED_QUANT_MODE.keys()) - ) - - quant_mode = INCQuantizationMode(approach) - q_model = copy.deepcopy(model) - q_model.eval() - - if framework == "pytorch_fx": - op_cfgs = _cfg_to_qconfig(q_config, approach) - fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, approach) - - if not q_config["fx_sub_module_list"]: - if quant_mode == INCQuantizationMode.AWARE_TRAINING: - q_model.train() - q_model = prepare_qat_fx(q_model, fx_op_cfgs) - else: - q_model = prepare_fx(q_model, fx_op_cfgs) - q_model = convert_fx(q_model) - - else: - sub_module_list = q_config["fx_sub_module_list"] - if q_config["approach"] == "quant_aware_training": - q_model.train() - PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list, fx_op_cfgs, q_model, prefix="", is_qat=True) - else: - PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list, fx_op_cfgs, q_model, prefix="") - PyTorch_FXAdaptor.convert_sub_graph(sub_module_list, q_model, prefix="") - - else: - if quant_mode == INCQuantizationMode.DYNAMIC: - q_mapping = torch.quantization.quantization_mappings.get_default_dynamic_quant_module_mappings() - op_cfgs = _cfg_to_qconfig(q_config, approach) - else: - q_mapping = torch.quantization.quantization_mappings.get_default_static_quant_module_mappings() - op_cfgs = _cfg_to_qconfig(q_config) - - _propagate_qconfig(q_model, op_cfgs, approach=approach) - - if quant_mode != INCQuantizationMode.DYNAMIC: - add_observer_(q_model) - q_model = convert(q_model, mapping=q_mapping, inplace=True) - - return q_model diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py index 3e36065195..84c1d6dc29 100644 --- a/optimum/intel/neural_compressor/utils.py +++ b/optimum/intel/neural_compressor/utils.py @@ -16,11 +16,9 @@ import os import warnings from collections import UserDict -from typing import Dict import torch from neural_compressor.utils.pytorch import load -from packaging import version from torch.utils.data import DataLoader from ..utils.constant import WEIGHTS_NAME @@ -30,6 +28,13 @@ CONFIG_NAME = 
"best_configure.yaml" +QUANTIZATION_CONFIG_NAME = "quantize_config.json" + +NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0" +NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0" +IPEX_MINIMUM_VERSION = "2.1.0" +ITREX_MINIMUM_VERSION = "1.4.0" +ITREX_MINIMUM_TORCH_VERSION = "2.2.0" _HEAD_TO_AUTOMODELS = { @@ -45,10 +50,6 @@ } -parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) -is_torch_less_than_1_13 = parsed_torch_version_base < version.parse("1.13.0") - - class INCDataLoader(DataLoader): use_label = True @@ -73,44 +74,6 @@ def __iter__(self): yield input -def _cfgs_to_fx_cfgs(op_cfgs: Dict, observer_type: str = "post_training_static_quant") -> Dict: - """Inc function which convert a quantization config to a format that meets the requirements of torch.fx. - - Arguments: - op_cfgs (`dict`): - Dictionary of quantization configure for each op. - observer_type (`str`): - Specify observer type. - Returns: - fx_op_cfgs (`dict`): - Dictionary of quantization configure that meets the requirements of torch.fx. - """ - if not is_torch_less_than_1_13: - from torch.ao.quantization import QConfigMapping - - fx_op_cfgs = QConfigMapping() - else: - fx_op_cfgs = {} - op_tuple_cfg_list = [] - for key, value in op_cfgs.items(): - if key == "default_qconfig": - if not is_torch_less_than_1_13: - fx_op_cfgs.set_global(value) - else: - fx_op_cfgs[""] = value - continue - if not is_torch_less_than_1_13: - fx_op_cfgs.set_module_name(key, value) - else: - op_tuple = (key, value) - op_tuple_cfg_list.append(op_tuple) - - if is_torch_less_than_1_13: - fx_op_cfgs["module_name"] = op_tuple_cfg_list - - return fx_op_cfgs - - def load_quantized_model(checkpoint_dir_or_file: str, model: torch.nn.Module, **kwargs) -> torch.nn.Module: """ Returns the quantized model, which was quantized through neural_compressor. diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 739b5640a6..30dfe5ae6f 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -57,6 +57,7 @@ class OVQuantizationMethod(str, Enum): DEFAULT = "default" + HYBRID = "hybrid" @dataclass @@ -133,7 +134,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase): using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. dataset (`str or List[str]`, *optional*): The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset - in a list of strings or just use the one from the list ['wikitext','c4','c4-new','ptb','ptb-new'] for LLLMs + in a list of strings or just use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLLMs or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models. Alternatively, you can provide data objects via `calibration_dataset` argument of `OVQuantizer.quantize()` method. @@ -194,7 +195,7 @@ def post_init(self): f"If you wish to provide a custom dataset, please use the `OVQuantizer` instead." 
) if self.dataset is not None and isinstance(self.dataset, str): - llm_datasets = ["wikitext", "c4", "c4-new", "ptb", "ptb-new"] + llm_datasets = ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"] stable_diffusion_datasets = [ "conceptual_captions", "laion/220k-GPT4Vision-captions-from-LIVIS", @@ -310,7 +311,9 @@ def __init__( if isinstance(quantization_config, dict): quantization_config = self._quantization_config_from_dict(quantization_config) self.quantization_config = quantization_config - self.compression = None # A field for backward-compatability of training-time compression parameters + self.compression = kwargs.get( + "compression", None + ) # A field for backward-compatability of training-time compression parameters bits = self.quantization_config.bits if self.quantization_config else None self.dtype = "int" + str(bits) if isinstance(bits, int) else dtype diff --git a/optimum/intel/openvino/loaders.py b/optimum/intel/openvino/loaders.py index 61d5755cfa..fc5ae97495 100644 --- a/optimum/intel/openvino/loaders.py +++ b/optimum/intel/openvino/loaders.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import warnings from typing import Dict, List, Optional, Union import torch @@ -25,7 +26,7 @@ import safetensors import openvino -from huggingface_hub.constants import HF_HUB_OFFLINE +from huggingface_hub.constants import HF_HUB_OFFLINE, HUGGINGFACE_HUB_CACHE from openvino.runtime import Type from openvino.runtime import opset11 as ops from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType @@ -37,7 +38,7 @@ try: from diffusers.utils import DIFFUSERS_CACHE except ImportError: - DIFFUSERS_CACHE = None + DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE logger = logging.getLogger(__name__) @@ -188,9 +189,11 @@ def load_textual_inversion( local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. @@ -258,11 +261,21 @@ def load_textual_inversion( proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) + token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" @@ -319,7 +332,7 @@ def load_textual_inversion( resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, - use_auth_token=use_auth_token, + use_auth_token=token, # still uses use_auth_token revision=revision, subfolder=subfolder, user_agent=user_agent, @@ -340,7 +353,7 @@ def load_textual_inversion( resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, - use_auth_token=use_auth_token, + use_auth_token=token, # still uses use_auth_token revision=revision, subfolder=subfolder, user_agent=user_agent, diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 9c7c2b5258..1c907f2135 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -14,6 +14,7 @@ import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Union @@ -23,6 +24,7 @@ import torch import transformers from huggingface_hub import model_info +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers import ( AutoConfig, AutoModel, @@ -421,9 +423,10 @@ def _from_transformers( model_id: str, config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, @@ -432,6 +435,15 @@ def _from_transformers( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -449,7 +461,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -583,15 +595,25 @@ def from_pretrained( export: bool = False, config: Optional["PretrainedConfig"] = None, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, trust_remote_code: bool = False, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + # Fix the mismatch between timm_config and huggingface_config local_timm_model = _is_timm_ov_dir(model_id) if local_timm_model or (not os.path.isdir(model_id) and model_info(model_id).library_name == "timm"): @@ -620,7 +642,7 @@ def from_pretrained( model_id=model_id, config=config, export=export, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index a48cdf5c92..7937deea52 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -14,12 +14,14 @@ import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory, gettempdir from typing import Dict, Optional, Union import openvino from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino import Core, convert_model from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation from transformers import GenerationConfig, PretrainedConfig @@ -168,10 +170,11 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str, None]] = None, - revision: Optional[Union[str, None]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: Optional[str] = None, subfolder: str = "", from_onnx: bool = False, @@ -189,9 +192,11 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (`str` or `bool`): - The token to use as HTTP bearer authorization for remote files. Needed to load models from a private - repository. + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*): The specific model version to use. It can be a branch name, a tag name, or a commit id. cache_dir (`Union[str, Path]`, *optional*): @@ -208,13 +213,22 @@ def _from_pretrained( load_in_8bit (`bool`, *optional*, defaults to `False`): Whether or not to apply 8-bit weight quantization. """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name model_cache_path = cls._cached_file( model_path=model_path, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, @@ -260,6 +274,7 @@ def _set_ov_config_parameters(self): def _cached_file( model_path: Union[Path, str], use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, @@ -267,6 +282,15 @@ def _cached_file( subfolder: str = "", local_files_only: bool = False, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + # locates a file in a local folder and repo, downloads and cache it if necessary. model_path = Path(model_path) if model_path.is_dir(): @@ -282,7 +306,7 @@ def _cached_file( repo_id=model_path.as_posix(), filename=file_name.as_posix(), subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -298,9 +322,10 @@ def _from_transformers( model_id: str, config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, @@ -320,13 +345,25 @@ def _from_transformers( - The path to a directory containing the model weights. save_dir (`str` or `Path`): The directory where the exported ONNX model should be saved, default to `transformers.file_utils.default_cache_path`, which is the cache directory for transformers. - use_auth_token (`str` or `bool`): - Is needed to load models from a private repository + use_auth_token (`Optional[str]`, defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`): Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id kwargs (`Dict`, *optional*): kwargs will be passed to the model during initialization """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -343,7 +380,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -366,13 +403,23 @@ def _to_load( config: PretrainedConfig, onnx_config: OnnxConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, local_files_only: bool = False, stateful: bool = False, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -389,7 +436,7 @@ def _to_load( model_id=save_dir_path, config=config, from_onnx=False, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 78648e93d2..fb53f9b2e2 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -14,12 +14,14 @@ import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Union import openvino from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation from transformers import GenerationConfig, PretrainedConfig from transformers.file_utils import add_start_docstrings @@ -109,9 +111,10 @@ def _from_pretrained( model_id: Union[str, Path], config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, encoder_file_name: Optional[str] = None, decoder_file_name: Optional[str] = None, decoder_with_past_file_name: Optional[str] = None, @@ -131,9 +134,11 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (`str` or `bool`): - The token to use as HTTP bearer authorization for remote files. Needed to load models from a private - repository. + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id. force_download (`bool`, *optional*, defaults to `False`): @@ -154,6 +159,15 @@ def _from_pretrained( local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + default_encoder_file_name = ONNX_ENCODER_NAME if from_onnx else OV_ENCODER_NAME default_decoder_file_name = ONNX_DECODER_NAME if from_onnx else OV_DECODER_NAME default_decoder_with_past_file_name = ONNX_DECODER_WITH_PAST_NAME if from_onnx else OV_DECODER_WITH_PAST_NAME @@ -190,7 +204,7 @@ def _from_pretrained( model_cache_path = hf_hub_download( repo_id=model_id, filename=file_name, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -220,9 +234,10 @@ def _from_transformers( model_id: str, config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, @@ -244,13 +259,25 @@ def _from_transformers( save_dir (`str` or `Path`): The directory where the exported ONNX model should be saved, defaults to `transformers.file_utils.default_cache_path`, which is the cache directory for transformers. - use_auth_token (`str` or `bool`): - Is needed to load models from a private repository + use_auth_token (`Optional[str]`, defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`): Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id kwargs (`Dict`, *optional*): kwargs will be passed to the model during initialization """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -272,7 +299,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 39a7bee9a2..72cd1b6487 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -11,21 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
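In the seq2seq loader above, each missing OpenVINO artifact (encoder, decoder, decoder-with-past) is pulled individually with `hf_hub_download`, now authenticated via `token=` rather than the removed `use_auth_token=`. A rough usage sketch, assuming the default OpenVINO file names and a placeholder repo id (both illustrative, not taken from the patch):

from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Illustrative file names mirroring the OV_ENCODER_NAME / OV_DECODER_NAME defaults;
# the repo id below is a placeholder, not a reference to a real model.
for file_name in ("openvino_encoder_model.xml", "openvino_decoder_model.xml"):
    local_path = hf_hub_download(
        repo_id="my-org/my-ov-seq2seq-model",
        filename=file_name,
        token=True,  # reuse the token stored by `huggingface-cli login`
        cache_dir=HUGGINGFACE_HUB_CACHE,
    )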
- +import copy import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import numpy as np import openvino import torch +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino.preprocess import PrePostProcessor from openvino.runtime import Core, Tensor, Type -from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig +from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin +from transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.generation.utils import GenerateOutput, GenerationMode from transformers.modeling_outputs import CausalLMOutputWithPast from optimum.utils.normalized_config import NormalizedConfigManager @@ -36,7 +42,12 @@ from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel -from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE +from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from transformers.streamers import BaseStreamer logger = logging.getLogger(__name__) @@ -120,6 +131,8 @@ def __init__( self._pkv_precision = Type.f32 self.next_beam_idx = None self._past_length = 0 + self._first_iter_beam_search = False + self._second_iter_beam_search = False self.update_pkv_precision() if self.is_dynamic: self.model = self._reshape(self.model, -1, -1) @@ -219,9 +232,10 @@ def _from_transformers( model_id: str, config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, @@ -231,6 +245,15 @@ def _from_transformers( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -254,7 +277,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -363,7 +386,9 @@ def prepare_inputs( inputs = {} if not self.stateful: if past_key_values is not None: - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): if self._pkv_precision == Type.bf16: # numpy does not support bf16, pretending f16, should change to bf16 past_key_values = tuple( @@ -384,6 +409,7 @@ def prepare_inputs( elif self.use_cache: for input_name in self.key_value_input_names: model_inputs = self.model.input(input_name) + dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()] shape = model_inputs.get_partial_shape() if self.config.model_type == "chatglm": shape[0] = 0 @@ -394,7 +420,7 @@ def prepare_inputs( shape[2] = 0 else: shape[1] = 0 - inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape()) + inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype) else: # past_key_values are not used explicitly, instead they are handled inside the model if past_key_values is None: @@ -406,7 +432,6 @@ def prepare_inputs( self.next_beam_idx = np.arange(batch_size, dtype=int) self._past_length = 0 past_len = self._get_past_length(past_key_values) - inputs["input_ids"] = np.array(input_ids) # Add the attention_mask inputs when needed if "attention_mask" in self.input_names or "position_ids" in self.input_names: @@ -456,6 +481,8 @@ def forward( **kwargs, ) + if self._first_iter_beam_search: + inputs, duplication_indices = self._deduplicate_inputs(inputs) # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() @@ -471,7 +498,9 @@ def forward( if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) past_key_values = tuple( past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) @@ -479,6 +508,10 @@ def forward( else: past_key_values = None + if self._first_iter_beam_search: + logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values) + self._first_iter_beam_search = False + return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation @@ -508,7 +541,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - return { + 
model_inputs = { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache, @@ -516,12 +549,114 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "attention_mask": attention_mask, } + return model_inputs + + def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + batch_size = logits.shape[0] + if indicies.shape[0] != 1: + logits = logits[indicies] + if past_key_values and not self.stateful: + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): + past_key_values = tuple( + tuple( + past_state[indicies] + if not self.config.model_type == "chatglm" + else past_state[:, indicies, ...] + for past_state in layer_past + ) + for layer_past in past_key_values + ) + else: + past_key_values = tuple([past_state[indicies] for past_state in past_key_values]) + if self.stateful: + self.next_beam_idx = ( + self.next_beam_idx[indicies] + if self.next_beam_idx is not None + else np.arange(batch_size, dtype=int)[indicies] + ) + self._second_iter_beam_search = True + return logits, past_key_values + + def _deduplicate_inputs(self, model_inputs: Dict): + input_ids = model_inputs["input_ids"] + upd_model_inputs = {} + unique_input_ids, indicies, reverse_indicies = np.unique( + input_ids, axis=0, return_index=True, return_inverse=True + ) + for input_name, input_tensor in model_inputs.items(): + if input_name not in ["input_ids", "beam_idx"]: + if not isinstance(input_tensor, Tensor): + upd_model_inputs[input_name] = input_tensor[indicies] + else: + shape = input_tensor.shape + dtype = input_tensor.element_type + upd_batch_size = indicies.shape[0] + if self.config.model_type == "bloom": + upd_batch_size *= self.config.num_attention_heads + shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size + upd_model_inputs[input_name] = Tensor(dtype, shape) + upd_model_inputs["input_ids"] = unique_input_ids + if "beam_idx" in model_inputs: + beam_range = ( + unique_input_ids.shape[0] + if self.config.model_type != "bloom" + else unique_input_ids.shape[0] * self.config.num_attention_heads + ) + beam_idx = np.arange(beam_range, dtype=int) + upd_model_inputs["beam_idx"] = beam_idx + return upd_model_inputs, reverse_indicies + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) + generation_mode = _generation_config.get_generation_mode(assistant_model) + + is_beam_search = generation_mode in [ + GenerationMode.BEAM_SEARCH, + GenerationMode.BEAM_SAMPLE, + GenerationMode.GROUP_BEAM_SEARCH, + GenerationMode.CONSTRAINED_BEAM_SEARCH, + ] + if is_beam_search: + self._first_iter_beam_search = True + result = super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + 
synced_gpus, + assistant_model, + streamer, + negative_prompt_ids, + negative_prompt_attention_mask, + **kwargs, + ) + return result + def _get_past_length(self, past_key_values=None): if past_key_values is None: return 0 if self.stateful: return self._past_length - if self.config.model_type in MULTI_QUERY_ATTN_MODELS: + if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): return past_key_values[0].shape[-2] seq_length_dim = -2 if self.config.model_type == "chatglm": @@ -546,12 +681,20 @@ def _reorder_cache( if self.stateful: # TODO: Apply it differently based on model type # TODO: At least for bloom we need to replicate values for each attention head - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = ( + np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx + ) # save beam_idx to be used as an input in the next iteration + self._second_iter_beam_search = False return past_key_values else: - return tuple( - tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values - ) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): + return tuple( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) + for layer_past in past_key_values + ) + return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values) def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" @@ -562,10 +705,11 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str, None]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[Union[str, None]] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: Optional[str] = None, subfolder: str = "", from_onnx: bool = False, @@ -574,13 +718,22 @@ def _from_pretrained( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name model_cache_path = cls._cached_file( model_path=model_path, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, @@ -624,9 +777,8 @@ def _from_pretrained( raise ImportError( "Quantization of the weights requires nncf, please install it with `pip install nncf`" ) - import nncf - from .quantization import _weight_only_quantization + from optimum.intel.openvino.quantization import OVQuantizer default_config = _check_default_4bit_configs(config) @@ -635,18 +787,10 @@ def _from_pretrained( f"For the given model, we recommend the following `quantization_config` : {default_config}" ) - calibration_dataset = None - if isinstance(quantization_config.dataset, str): - tokenizer = quantization_config.tokenizer or AutoTokenizer.from_pretrained(model_id) - - from optimum.gptq.data import get_dataset, prepare_dataset - - nsamples = quantization_config.num_samples or 128 - dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples) - dataset = prepare_dataset(dataset) - calibration_dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x)) - - _weight_only_quantization(model, quantization_config, calibration_dataset) + quantizer = OVQuantizer(causal_model) + quantization_config_copy = copy.deepcopy(quantization_config) + quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id + quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy)) return causal_model @@ -671,11 +815,12 @@ def _reorder_cache( This is required to match `past_key_values` with the correct beam_idx at every generation step. 
""" if self.stateful: - beam_idx = np.array(beam_idx) batch_size = beam_idx.shape[0] + beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx indices = np.array(range(batch_size * self.config.num_attention_heads)) indices = indices.reshape([batch_size, self.config.num_attention_heads]) self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() + self._second_iter_beam_search = False return past_key_values else: standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) @@ -725,6 +870,24 @@ def _convert_to_standard_cache( for layer_past in past_key_value ) + def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + batch_size = logits.shape[0] + if indicies.shape[0] != 1: + logits = logits[indicies] + if past_key_values and not self.stateful: + pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size) + pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard) + past_key_values = self._convert_to_bloom_cache(pkv) + + if self.stateful: + self.next_beam_idx = ( + self.next_beam_idx[indicies] + if self.next_beam_idx is not None + else np.arange(batch_size, dtype=int)[indicies] + ) + self._second_iter_beam_search = True + return logits, past_key_values + class OVGPTBigCodeForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache @@ -732,7 +895,9 @@ def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: if self.stateful: - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx + self._second_iter_beam_search = False return past_key_values else: return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index fb9bec7a8e..1b880e736c 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -16,6 +16,7 @@ import logging import os import shutil +import warnings from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory, gettempdir @@ -35,6 +36,7 @@ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available from huggingface_hub import snapshot_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino._offline_transformations import compress_model_transformation from openvino.runtime import Core from transformers import CLIPFeatureExtractor, CLIPTokenizer @@ -55,14 +57,13 @@ ) from ...exporters.openvino import main_export -from .configuration import OVConfig, OVWeightQuantizationConfig +from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig from .loaders import OVTextualInversionLoaderMixin from .modeling_base import OVBaseModel from .utils import ( ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, - PREDEFINED_SD_DATASETS, _print_compiled_model_properties, ) @@ -207,8 +208,9 @@ def _from_pretrained( model_id: Union[str, Path], config: Dict[str, Any], use_auth_token: Optional[Union[bool, str]] = None, + token: 
Optional[Union[bool, str]] = None, revision: Optional[str] = None, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, vae_decoder_file_name: Optional[str] = None, text_encoder_file_name: Optional[str] = None, unet_file_name: Optional[str] = None, @@ -221,6 +223,15 @@ def _from_pretrained( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME vae_decoder_file_name = vae_decoder_file_name or default_file_name text_encoder_file_name = text_encoder_file_name or default_file_name @@ -259,7 +270,7 @@ def _from_pretrained( model_id, cache_dir=cache_dir, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, @@ -281,21 +292,7 @@ def _from_pretrained( else: kwargs[name] = load_method(new_model_save_dir) - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) - unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name - if quantization_config is not None and quantization_config.dataset is not None: - # load the UNet model uncompressed to apply hybrid quantization further - unet = cls.load_model(unet_path) - # Apply weights compression to other `components` without dataset - weight_quantization_params = { - param: value for param, value in quantization_config.__dict__.items() if param != "dataset" - } - weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params) - else: - weight_quantization_config = quantization_config - unet = cls.load_model(unet_path, weight_quantization_config) - components = { "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, "vae_decoder": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, @@ -303,13 +300,19 @@ def _from_pretrained( "text_encoder_2": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name, } - for key, value in components.items(): - components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None - if model_save_dir is None: model_save_dir = new_model_save_dir - if quantization_config is not None and quantization_config.dataset is not None: + quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + if quantization_config is None or quantization_config.dataset is None: + unet = cls.load_model(unet_path, quantization_config) + for key, value in components.items(): + components[key] = cls.load_model(value, quantization_config) if value.is_file() else None + else: + # Load uncompressed models to apply hybrid quantization further + unet = cls.load_model(unet_path) + for key, value in components.items(): + components[key] = cls.load_model(value) if value.is_file() else None sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs) supported_pipelines = ( @@ -320,12 +323,14 @@ def _from_pretrained( if not isinstance(sd_model, supported_pipelines): raise 
NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}") - nsamples = quantization_config.num_samples if quantization_config.num_samples else 200 - unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples) + from optimum.intel import OVQuantizer - from .quantization import _hybrid_quantization + hybrid_quantization_config = deepcopy(quantization_config) + hybrid_quantization_config.quant_method = OVQuantizationMethod.HYBRID + quantizer = OVQuantizer(sd_model) + quantizer.quantize(ov_config=OVConfig(quantization_config=hybrid_quantization_config)) - unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs) + return sd_model return cls( unet=unet, @@ -336,71 +341,16 @@ def _from_pretrained( **kwargs, ) - def _prepare_unet_inputs( - self, - dataset: Union[str, List[Any]], - num_samples: int, - height: Optional[int] = None, - width: Optional[int] = None, - seed: Optional[int] = 42, - **kwargs, - ) -> Dict[str, Any]: - self.compile() - - size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor - height = height or min(size, 512) - width = width or min(size, 512) - - if isinstance(dataset, str): - dataset = deepcopy(dataset) - available_datasets = PREDEFINED_SD_DATASETS.keys() - if dataset not in available_datasets: - raise ValueError( - f"""You have entered a string value for dataset. You can only choose between - {list(available_datasets)}, but the {dataset} was found""" - ) - - from datasets import load_dataset - - dataset_metadata = PREDEFINED_SD_DATASETS[dataset] - dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed) - input_names = dataset_metadata["inputs"] - dataset = dataset.select_columns(list(input_names.values())) - - def transform_fn(data_item): - return {inp_name: data_item[column] for inp_name, column in input_names.items()} - - else: - - def transform_fn(data_item): - return data_item if isinstance(data_item, (list, dict)) else [data_item] - - from .quantization import InferRequestWrapper - - calibration_data = [] - self.unet.request = InferRequestWrapper(self.unet.request, calibration_data) - - for inputs in dataset: - inputs = transform_fn(inputs) - if isinstance(inputs, dict): - self.__call__(**inputs, height=height, width=width) - else: - self.__call__(*inputs, height=height, width=width) - if len(calibration_data) >= num_samples: - break - - self.unet.request = self.unet.request.request - return calibration_data[:num_samples] - @classmethod def _from_transformers( cls, model_id: str, config: Dict[str, Any], use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, - cache_dir: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, local_files_only: bool = False, tokenizer: Optional["CLIPTokenizer"] = None, scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"] = None, @@ -410,6 +360,15 @@ def _from_transformers( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -427,7 +386,7 @@ def _from_transformers( no_post_process=True, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, ov_config=ov_config, @@ -437,7 +396,7 @@ def _from_transformers( model_id=save_dir_path, config=config, from_onnx=False, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 0959a8c2f9..6d72dc7b0e 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -224,7 +224,7 @@ ```python >>> from transformers import {processor_class} - >>> from optimum.intel.openvino import {model_class} + >>> from optimum.intel import {model_class} >>> from datasets import load_dataset >>> processor = {processor_class}.from_pretrained("{checkpoint}") @@ -241,7 +241,7 @@ ```python >>> from transformers import {processor_class}, pipeline - >>> from optimum.intel.openvino import {model_class} + >>> from optimum.intel import {model_class} >>> from datasets import load_dataset >>> processor = {processor_class}.from_pretrained("{checkpoint}") diff --git a/optimum/intel/openvino/modeling_timm.py b/optimum/intel/openvino/modeling_timm.py index a84f80c9f7..f2566c8c4a 100644 --- a/optimum/intel/openvino/modeling_timm.py +++ b/optimum/intel/openvino/modeling_timm.py @@ -19,6 +19,7 @@ import numpy as np import timm import torch +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from packaging import version from timm.layers.config import set_fused_attn from timm.models._hub import load_model_config_from_hf @@ -55,7 +56,7 @@ class TimmConfig(PretrainedConfig): def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], - cache_dir: Optional[Union[str, os.PathLike]] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 217e5e4056..43cf1dd93b 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import collections.abc import copy import inspect import logging import os +import warnings from collections import deque from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import datasets import nncf import openvino import torch import transformers +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from nncf import CompressWeightsMode, SensitivityMetric from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters, OverflowFix from nncf.torch import register_module @@ -36,6 +38,7 @@ from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator from transformers.pytorch_utils import Conv1D from transformers.utils import is_accelerate_available +from transformers.utils.quantization_config import QuantizationMethod from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.tasks import TasksManager @@ -45,15 +48,16 @@ from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available from ..utils.constant import _TASK_ALIASES -from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available +from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available, is_diffusers_available from ..utils.modeling_utils import get_model_device -from .configuration import OVConfig, OVQuantizationConfig, OVWeightQuantizationConfig +from .configuration import OVConfig, OVQuantizationConfig, OVQuantizationMethod, OVWeightQuantizationConfig from .modeling_base import OVBaseModel from .utils import ( MAX_ONNX_OPSET, MIN_ONNX_QDQ_OPSET, ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, + PREDEFINED_SD_DATASETS, ) @@ -197,8 +201,8 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs): def quantize( self, - calibration_dataset: Optional[Union[datasets.Dataset, nncf.Dataset, Iterable]] = None, - save_directory: Union[str, Path] = None, + calibration_dataset: Optional[Union["Dataset", nncf.Dataset, Iterable]] = None, + save_directory: Optional[Union[str, Path]] = None, ov_config: OVConfig = None, file_name: Optional[str] = None, batch_size: int = 1, @@ -214,7 +218,7 @@ def quantize( calibration_dataset (`datasets.Dataset` or `nncf.Dataset` or `Iterable`, *optional*): A collection of data samples to use for quantization calibration. Is optional for weight-only quantization and is required for full quantization. - save_directory (`Union[str, Path]`): + save_directory (`Union[str, Path]`, *optional*): The directory where the quantized model should be saved. ov_config (`OVConfig`, *optional*): The configuration containing the parameters related to quantization. 
If not provided, 8-bit symmetric @@ -234,7 +238,7 @@ def quantize( Examples: ```python - >>> from optimum.intel.openvino import OVQuantizer, OVModelForCausalLM + >>> from optimum.intel import OVQuantizer, OVModelForCausalLM >>> from transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b") >>> quantizer = OVQuantizer.from_pretrained(model, task="text-generation") @@ -244,7 +248,7 @@ def quantize( ``` ```python - >>> from optimum.intel.openvino import OVQuantizer, OVModelForSequenceClassification + >>> from optimum.intel import OVQuantizer, OVModelForSequenceClassification >>> from transformers import AutoModelForSequenceClassification >>> model = OVModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True) >>> # or @@ -262,10 +266,6 @@ def quantize( "as an instance of `OVWeightQuantizationConfig` for weight-only compression or as an instance of `OVQuantizationConfig` for full model quantization." ) - if save_directory is None: - # TODO : can be set to self.model.config.name_or_path for OVModels when not provided - raise ValueError("`save_directory` needs to be specified") - if ov_config is None: ov_config = OVConfig() if not isinstance(ov_config, OVConfig): @@ -318,61 +318,107 @@ def quantize( def _quantize_ovbasemodel( self, ov_config: OVConfig, - save_directory: Union[str, Path], - calibration_dataset: Optional[Union[datasets.Dataset, nncf.Dataset, Iterable]] = None, + save_directory: Union[str, Path] = None, + calibration_dataset: Optional[Union["Dataset", nncf.Dataset, Iterable]] = None, batch_size: int = 1, data_collator: Optional[DataCollator] = None, remove_unused_columns: bool = True, **kwargs, ): - save_directory = Path(save_directory) - save_directory.mkdir(parents=True, exist_ok=True) + if is_diffusers_available(): + from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipelineBase + if save_directory is not None: + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) quantization_config = ov_config.quantization_config + + if calibration_dataset is not None: + # Process custom calibration dataset + + if is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): + calibration_dataset = self._prepare_unet_dataset( + quantization_config.num_samples, dataset=calibration_dataset + ) + elif is_datasets_available() and isinstance(calibration_dataset, Dataset): + calibration_dataloader = self._get_calibration_dataloader( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + remove_unused_columns=remove_unused_columns, + data_collator=data_collator, + ) + if self.model.export_feature == "text-generation" and self.model.use_cache: + calibration_dataset = self._prepare_text_generation_dataset( + quantization_config, calibration_dataloader + ) + else: + calibration_dataset = nncf.Dataset(calibration_dataloader) + elif isinstance(calibration_dataset, collections.abc.Iterable): + calibration_dataset = nncf.Dataset(calibration_dataset) + elif not isinstance(calibration_dataset, nncf.Dataset): + raise ValueError( + "`calibration_dataset` must be either an `Iterable` object or an instance of " + f"`nncf.Dataset` or `datasets.Dataset`. Found: {type(calibration_dataset)}." 
+ ) + if isinstance(quantization_config, OVWeightQuantizationConfig): - _weight_only_quantization(self.model.model, quantization_config, calibration_dataset) - self.model.save_pretrained(save_directory) - ov_config.save_pretrained(save_directory) - return - if not isinstance(quantization_config, OVQuantizationConfig): - raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") + if quantization_config.dataset is not None and calibration_dataset is not None: + logger.info( + "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " + "quantization. Will rely on `calibration_dataset`." + ) - if isinstance(calibration_dataset, nncf.Dataset): - quantization_dataset = calibration_dataset - elif isinstance(calibration_dataset, datasets.Dataset): - calibration_dataloader = self._get_calibration_dataloader( - calibration_dataset=calibration_dataset, - batch_size=batch_size, - remove_unused_columns=remove_unused_columns, - data_collator=data_collator, - ) + if calibration_dataset is None and isinstance(quantization_config.dataset, str): + from optimum.intel import OVModelForCausalLM - if self.model.export_feature == "text-generation" and self.model.use_cache: - # Prefetch past_key_values - self.model.update_pkv_precision(True) - self.model.compile() - collected_inputs = [] + if isinstance(self.model, OVModelForCausalLM): + calibration_dataset = self._prepare_builtin_dataset(quantization_config) + elif is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): + calibration_dataset = self._prepare_unet_dataset( + quantization_config.num_samples, dataset_name=quantization_config.dataset + ) + else: + raise ValueError( + f"Can't create weight compression calibration dataset from string for {type(self.model)}" + ) - self.model.request = InferRequestWrapper(self.model.request, collected_inputs) - try: - for data in calibration_dataloader: - self.model.generate(**data, max_new_tokens=1) - if len(collected_inputs) >= quantization_config.num_samples: - break - finally: - self.model.request = self.model.request.request - quantization_dataset = nncf.Dataset(collected_inputs) + if quantization_config.quant_method == OVQuantizationMethod.HYBRID: + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run hybrid quantization.") + if is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): + # Apply weight-only quantization to all SD submodels except UNet + quantization_config_copy = copy.deepcopy(quantization_config) + quantization_config_copy.dataset = None + quantization_config_copy.quant_method = OVQuantizationMethod.DEFAULT + for sd_submodel_name in ["vae_encoder", "vae_decoder", "text_encoder", "text_encoder_2"]: + sd_submodel = getattr(self.model, sd_submodel_name) + if sd_submodel is not None: + _weight_only_quantization(sd_submodel.model, quantization_config_copy) + + # Apply hybrid quantization to UNet + self.model.unet.model = _hybrid_quantization( + self.model.unet.model, quantization_config, calibration_dataset + ) + else: + # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc. 
+ self.model.model = _hybrid_quantization(self.model.model, quantization_config, calibration_dataset) else: - quantization_dataset = nncf.Dataset(calibration_dataloader) - else: - if calibration_dataset is None: - raise ValueError("Calibration dataset is required to run quantization.") - quantization_dataset = nncf.Dataset(calibration_dataset) + _weight_only_quantization(self.model.model, quantization_config, calibration_dataset) + if save_directory is not None: + self.model.save_pretrained(save_directory) + ov_config.save_pretrained(save_directory) + return + + if not isinstance(quantization_config, OVQuantizationConfig): + raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") + + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run quantization.") # Actual model quantization quantized_model = nncf.quantize( self.model.model, - quantization_dataset, + calibration_dataset, subset_size=quantization_config.num_samples, ignored_scope=quantization_config.get_ignored_scope_instance(), model_type=nncf.ModelType(quantization_config.model_type), @@ -383,21 +429,27 @@ def _quantize_ovbasemodel( ), **kwargs, ) + self.model.model = quantized_model - self.model.save_pretrained(save_directory) - ov_config.save_pretrained(save_directory) + if save_directory is not None: + self.model.save_pretrained(save_directory) + ov_config.save_pretrained(save_directory) def _quantize_torchmodel( self, ov_config: OVConfig, save_directory: Union[str, Path], - calibration_dataset: Optional[Union[datasets.Dataset, nncf.Dataset, Iterable]] = None, + calibration_dataset: Optional[Union["Dataset", nncf.Dataset, Iterable]] = None, file_name: Optional[str] = None, batch_size: int = 1, data_collator: Optional[DataCollator] = None, remove_unused_columns: bool = True, **kwargs, ): + if save_directory is None: + # TODO : can be set to self.model.config.name_or_path for OVModels when not provided + raise ValueError("`save_directory` needs to be specified") + self._set_task() save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) @@ -458,7 +510,7 @@ def _quantize_torchmodel( if isinstance(calibration_dataset, nncf.Dataset): quantization_dataset = calibration_dataset - elif isinstance(calibration_dataset, datasets.Dataset): + elif isinstance(calibration_dataset, Dataset): calibration_dataloader = self._get_calibration_dataloader( calibration_dataset=calibration_dataset, batch_size=batch_size, @@ -476,9 +528,9 @@ def _quantize_torchmodel( subset_size=quantization_config.num_samples, ignored_scope=quantization_config.get_ignored_scope_instance(), model_type=nncf.ModelType(quantization_config.model_type), - preset=nncf.QuantizationPreset.PERFORMANCE - if quantization_config.sym - else nncf.QuantizationPreset.MIXED, + preset=( + nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED + ), fast_bias_correction=quantization_config.fast_bias_correction, advanced_parameters=nncf.AdvancedQuantizationParameters( overflow_fix=OverflowFix(quantization_config.overflow_fix) @@ -540,9 +592,10 @@ def get_calibration_dataset( dataset_split: str = "train", preprocess_function: Optional[Callable] = None, preprocess_batch: bool = True, - use_auth_token: bool = False, - cache_dir: Optional[str] = None, - ) -> datasets.Dataset: + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + ) -> "Dataset": """ Create the 
calibration `datasets.Dataset` to use for the post-training static quantization calibration step. @@ -560,13 +613,25 @@ def get_calibration_dataset( Processing function to apply to each example after loading dataset. preprocess_batch (`bool`, defaults to `True`): Whether the `preprocess_function` should be batched. - use_auth_token (`bool`, defaults to `False`): - Whether to use the token generated when running `transformers-cli login`. + use_auth_token (Optional[Union[bool, str]], defaults to `None`): + Deprecated. Please use `token` instead. + token (Optional[Union[bool, str]], defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). cache_dir (`str`, *optional*): Caching directory for a calibration dataset. Returns: The calibration `datasets.Dataset` to use for the post-training static quantization calibration step. """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if not is_datasets_available(): raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset")) from datasets import load_dataset @@ -575,7 +640,7 @@ def get_calibration_dataset( dataset_name, name=dataset_config_name, split=dataset_split, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, ) @@ -617,6 +682,104 @@ def _remove_unused_columns(self, dataset: "Dataset"): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) return dataset.remove_columns(ignored_columns) + def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConfig): + from optimum.gptq.data import get_dataset, prepare_dataset + + tokenizer = AutoTokenizer.from_pretrained(quantization_config.tokenizer) + nsamples = quantization_config.num_samples if quantization_config.num_samples else 128 + calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples) + calibration_dataset = prepare_dataset(calibration_dataset) + calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)) + + return calibration_dataset + + def _prepare_text_generation_dataset( + self, quantization_config: OVQuantizationConfig, calibration_dataloader: OVDataLoader + ) -> nncf.Dataset: + # Prefetch past_key_values + self.model.update_pkv_precision(True) + self.model.compile() + collected_inputs = [] + + num_samples = quantization_config.num_samples or 200 + + self.model.request = InferRequestWrapper(self.model.request, collected_inputs) + try: + for data in calibration_dataloader: + self.model.generate(**data, max_new_tokens=1) + if len(collected_inputs) >= num_samples: + break + finally: + self.model.request = self.model.request.request + calibration_dataset = nncf.Dataset(collected_inputs) + + return calibration_dataset + + def _prepare_unet_dataset( + self, + num_samples: Optional[int] = None, + dataset_name: Optional[str] = None, + dataset: Optional[Union[Iterable, "Dataset"]] = None, + ) -> nncf.Dataset: + self.model.compile() + + size = self.model.unet.config.get("sample_size", 64) * self.model.vae_scale_factor + height, width = 2 * (min(size, 512),) + num_samples = num_samples or 200 + + if dataset is 
not None: + if isinstance(dataset, nncf.Dataset): + return dataset + if is_datasets_available() and isinstance(dataset, Dataset): + dataset = dataset.select_columns(["caption"]) + + def transform_fn(data_item): + return data_item if isinstance(data_item, (list, dict)) else [data_item] + + elif isinstance(dataset_name, str): + available_datasets = PREDEFINED_SD_DATASETS.keys() + if dataset_name not in available_datasets: + raise ValueError( + f"""You have entered a string value for dataset. You can only choose between + {list(available_datasets)}, but the {dataset_name} was found""" + ) + + from datasets import load_dataset + + dataset_metadata = PREDEFINED_SD_DATASETS[dataset_name] + dataset = load_dataset(dataset_name, split=dataset_metadata["split"], streaming=True).shuffle( + seed=self.seed + ) + input_names = dataset_metadata["inputs"] + dataset = dataset.select_columns(list(input_names.values())) + + def transform_fn(data_item): + return {inp_name: data_item[column] for inp_name, column in input_names.items()} + + else: + raise ValueError( + "For UNet inputs collection either quantization_config.dataset or custom " + "calibration_dataset must be provided." + ) + + calibration_data = [] + try: + self.model.unet.request = InferRequestWrapper(self.model.unet.request, calibration_data) + + for inputs in dataset: + inputs = transform_fn(inputs) + if isinstance(inputs, dict): + self.model(**inputs, height=height, width=width) + else: + self.model(*inputs, height=height, width=width) + if len(calibration_data) >= num_samples: + break + finally: + self.model.unet.request = self.model.unet.request.request + + calibration_dataset = nncf.Dataset(calibration_data[:num_samples]) + return calibration_dataset + def _weight_only_quantization( model: openvino.runtime.Model, @@ -627,14 +790,9 @@ def _weight_only_quantization( if isinstance(config, dict): config = OVWeightQuantizationConfig.from_dict(quantization_config) - if config.dataset is not None and calibration_dataset is not None: - logger.info( - "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " - "quantization. Will rely on `calibration_dataset`." - ) dataset = None if calibration_dataset is not None: - if isinstance(calibration_dataset, datasets.Dataset): + if is_datasets_available() and isinstance(calibration_dataset, Dataset): raise ValueError( "Providing calibration dataset as an instance of `datasets.Dataset` for OV weight-only " "quantization is not supported. 
Please provide it as `nncf.Dataset` or as iterable of " @@ -644,14 +802,6 @@ def _weight_only_quantization( dataset = calibration_dataset else: dataset = nncf.Dataset(calibration_dataset) - elif config.dataset is not None and isinstance(config.dataset, str): - tokenizer = AutoTokenizer.from_pretrained(config.tokenizer) - - from optimum.gptq.data import get_dataset, prepare_dataset - - nsamples = config.num_samples if config.num_samples else 128 - dataset = get_dataset(config.dataset, tokenizer, seqlen=32, nsamples=nsamples) - dataset = prepare_dataset(dataset) sensitivity_metric = None if isinstance(config.sensitivity_metric, str): @@ -669,10 +819,10 @@ def _weight_only_quantization( group_size=config.group_size, all_layers=config.all_layers, sensitivity_metric=sensitivity_metric, - # awq=config.quant_method == QuantizationMethod.AWQ, # TODO : enable from nncf v2.9.0 + awq=config.quant_method == QuantizationMethod.AWQ or None, ignored_scope=config.get_ignored_scope_instance(), dataset=dataset, - # subset_size=config.num_samples if config.num_samples else 128, # TODO : enable from nncf v2.9.0 + subset_size=config.num_samples if config.num_samples else 128, ) @@ -722,7 +872,7 @@ def _collect_ops_with_weights(model): def _hybrid_quantization( - model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: Dict[str, Any] + model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: nncf.Dataset ) -> openvino.runtime.Model: """ Quantize a model in hybrid mode with NNCF which means that we quantize: @@ -734,7 +884,7 @@ def _hybrid_quantization( The OpenVINO Runtime model for applying hybrid quantization. quantization_config (`OVWeightQuantizationConfig`): The configuration containing the parameters related to quantization. - dataset (`Dict[str, Any]`): + dataset (`nncf.Dataset`): The dataset used for hybrid quantization. Returns: The OpenVINO Runtime model with applied hybrid quantization. 
@@ -751,7 +901,7 @@ def _hybrid_quantization( subset_size = quantization_config.num_samples if quantization_config.num_samples else 200 quantized_model = nncf.quantize( model=compressed_model, - calibration_dataset=nncf.Dataset(dataset), + calibration_dataset=dataset, model_type=nncf.ModelType.TRANSFORMER, ignored_scope=ptq_ignored_scope, # SQ algo should be disabled for MatMul nodes because their weights are already compressed diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 0745a1cd79..c8b29800fa 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -906,7 +906,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): output_path = os.path.join(output_dir, OV_XML_FILE_NAME) self.compression_controller.prepare_for_export() model_type = self.model.config.model_type.replace("_", "-") - onnx_config_class = TasksManager.get_exporter_config_constructor( + exporter_config_class = TasksManager.get_exporter_config_constructor( exporter="onnx", model=self.model, task=self.task, @@ -914,9 +914,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): ) if self.task == "text-generation": - onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache) + onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache) else: - onnx_config = onnx_config_class(self.model.config) + onnx_config = exporter_config_class(self.model.config) num_parameters = self.model.num_parameters() save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index 4d1479f733..69a750fb65 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -17,10 +17,13 @@ import logging import os from glob import glob +from pathlib import Path +from typing import Tuple, Union import numpy as np from huggingface_hub import model_info from openvino.runtime import Core, Type, properties +from transformers import AutoTokenizer, CLIPTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size @@ -107,6 +110,24 @@ } +NEED_CONVERT_TO_FAST_TOKENIZER: Tuple[type(PreTrainedTokenizer)] = (CLIPTokenizer,) + + +def maybe_convert_tokenizer_to_fast( + hf_tokenizer: PreTrainedTokenizer, tokenizer_path: Path +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if isinstance(hf_tokenizer, PreTrainedTokenizerFast): + return hf_tokenizer + + if isinstance(hf_tokenizer, NEED_CONVERT_TO_FAST_TOKENIZER): + try: + return AutoTokenizer.from_pretrained(tokenizer_path) + except Exception: + return hf_tokenizer + + return hf_tokenizer + + def use_external_data_format(num_parameters: int) -> bool: """ Returns whether or not the model requires using external data format for the ONNX export diff --git a/optimum/intel/pipelines/__init__.py b/optimum/intel/pipelines/__init__.py new file mode 100644 index 0000000000..40a1e3ca56 --- /dev/null +++ b/optimum/intel/pipelines/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline_base import pipeline diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py new file mode 100644 index 0000000000..65e6cfb782 --- /dev/null +++ b/optimum/intel/pipelines/pipeline_base.py @@ -0,0 +1,290 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +import torch +from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer +from transformers import pipeline as transformers_pipeline +from transformers.feature_extraction_utils import PreTrainedFeatureExtractor +from transformers.pipelines import ( + AudioClassificationPipeline, + FillMaskPipeline, + ImageClassificationPipeline, + QuestionAnsweringPipeline, + TextClassificationPipeline, + TextGenerationPipeline, + TokenClassificationPipeline, +) +from transformers.pipelines.base import Pipeline +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +from optimum.intel.utils import is_ipex_available + + +if is_ipex_available(): + from ..ipex.modeling_base import ( + IPEXModel, + IPEXModelForAudioClassification, + IPEXModelForCausalLM, + IPEXModelForImageClassification, + IPEXModelForMaskedLM, + IPEXModelForQuestionAnswering, + IPEXModelForSequenceClassification, + IPEXModelForTokenClassification, + ) + + IPEX_SUPPORTED_TASKS = { + "text-generation": { + "impl": TextGenerationPipeline, + "class": (IPEXModelForCausalLM,), + "default": "gpt2", + "type": "text", + }, + "fill-mask": { + "impl": FillMaskPipeline, + "class": (IPEXModelForMaskedLM,), + "default": "bert-base-cased", + "type": "text", + }, + "question-answering": { + "impl": QuestionAnsweringPipeline, + "class": (IPEXModelForQuestionAnswering,), + "default": "distilbert-base-cased-distilled-squad", + "type": "text", + }, + "image-classification": { + "impl": ImageClassificationPipeline, + "class": (IPEXModelForImageClassification,), + "default": "google/vit-base-patch16-224", + "type": "image", + }, + "text-classification": { + "impl": TextClassificationPipeline, + "class": (IPEXModelForSequenceClassification,), + "default": "distilbert-base-uncased-finetuned-sst-2-english", + "type": "text", + }, + "token-classification": { + "impl": TokenClassificationPipeline, + "class": (IPEXModelForTokenClassification,), + "default": "dbmdz/bert-large-cased-finetuned-conll03-english", + "type": "text", + }, + "audio-classification": { + "impl": AudioClassificationPipeline, + "class": (IPEXModelForAudioClassification,), 
+            "default": "superb/hubert-base-superb-ks",
+            "type": "audio",
+        },
+    }
+else:
+    IPEX_SUPPORTED_TASKS = {}
+
+
+def load_ipex_model(
+    model,
+    targeted_task,
+    SUPPORTED_TASKS,
+    model_kwargs: Optional[Dict[str, Any]] = None,
+    hub_kwargs: Optional[Dict[str, Any]] = None,
+):
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
+
+    if model is None:
+        model_id = SUPPORTED_TASKS[targeted_task]["default"]
+        model = ipex_model_class.from_pretrained(model_id, export=True, **model_kwargs, **hub_kwargs)
+    elif isinstance(model, str):
+        model_id = model
+        try:
+            config = AutoConfig.from_pretrained(model)
+            export = not getattr(config, "torchscript", False)
+        except RuntimeError:
+            logger.warning("We will use IPEXModel with export=True to export the model")
+            export = True
+        model = ipex_model_class.from_pretrained(model, export=export, **model_kwargs, **hub_kwargs)
+    elif isinstance(model, IPEXModel):
+        model_id = getattr(model.config, "name_or_path", None)
+    else:
+        raise ValueError(
+            f"""Model {model} is not supported. Please provide a valid model name or path or an IPEXModel.
+            You can also provide no model, in which case a default one will be used."""
+        )
+
+    return model, model_id
+
+
+MAPPING_LOADING_FUNC = {
+    "ipex": load_ipex_model,
+}
+
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+
+logger = logging.get_logger(__name__)
+
+
+def pipeline(
+    task: str = None,
+    model: Optional[Union[str, "PreTrainedModel"]] = None,
+    tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
+    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
+    use_fast: bool = True,
+    token: Optional[Union[str, bool]] = None,
+    accelerator: Optional[str] = "ipex",
+    revision: Optional[str] = None,
+    trust_remote_code: Optional[bool] = None,
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
+    commit_hash: Optional[str] = None,
+    **model_kwargs,
+) -> Pipeline:
+    """
+    Utility factory method to build a [`Pipeline`].
+
+    Pipelines are made of:
+
+        - A [tokenizer](tokenizer) in charge of mapping raw textual input to tokens.
+        - A [model](model) to make predictions from the inputs.
+        - Some (optional) post processing for enhancing model's output.
+
+    Args:
+        task (`str`):
+            The task defining which pipeline will be returned. Currently accepted tasks are:
+
+            - `"text-generation"`: will return a [`TextGenerationPipeline`].
+
+        model (`str` or [`PreTrainedModel`], *optional*):
+            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
+            actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch).
+
+            If not provided, the default for the `task` will be loaded.
+        tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
+            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
+            identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
+
+            If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
+            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
+            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
+            will be loaded.
+ accelerator (`str`, *optional*, defaults to `"ipex"`): + The optimization backends, choose from ["ipex", "inc", "openvino"]. + use_fast (`bool`, *optional*, defaults to `True`): + Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]). + torch_dtype (`str` or `torch.dtype`, *optional*): + Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model + (`torch.float16`, `torch.bfloat16`, ... or `"auto"`). + model_kwargs (`Dict[str, Any]`, *optional*): + Additional dictionary of keyword arguments passed along to the model's `from_pretrained(..., + **model_kwargs)` function. + + Returns: + [`Pipeline`]: A suitable pipeline for the task. + + Examples: + + ```python + >>> import torch + >>> from optimum.intel.pipelines import pipeline + + >>> pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16) + >>> pipe("Describe a real-world application of AI in sustainable energy.") + ```""" + if model_kwargs is None: + model_kwargs = {} + + if task is None and model is None: + raise RuntimeError( + "Impossible to instantiate a pipeline without either a task or a model " + "being specified. " + "Please provide a task class or a model" + ) + + if model is None and tokenizer is not None: + raise RuntimeError( + "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" + " may not be compatible with the default model. Please provide a PreTrainedModel class or a" + " path/identifier to a pretrained model when providing tokenizer." + ) + + if accelerator not in MAPPING_LOADING_FUNC: + raise ValueError( + f'Accelerator {accelerator} is not supported. Supported accelerator is {", ".join(MAPPING_LOADING_FUNC)}.' + ) + + if accelerator == "ipex": + if task not in list(IPEX_SUPPORTED_TASKS.keys()): + raise ValueError( + f"Task {task} is not supported for the IPEX pipeline. Supported tasks are { list(IPEX_SUPPORTED_TASKS.keys())}" + ) + + supported_tasks = IPEX_SUPPORTED_TASKS if accelerator == "ipex" else None + + no_feature_extractor_tasks = set() + no_tokenizer_tasks = set() + for _task, values in supported_tasks.items(): + if values["type"] == "text": + no_feature_extractor_tasks.add(_task) + elif values["type"] in {"image", "video"}: + no_tokenizer_tasks.add(_task) + elif values["type"] in {"audio"}: + no_tokenizer_tasks.add(_task) + elif values["type"] not in ["multimodal", "audio", "video"]: + raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}") + + load_tokenizer = task not in no_tokenizer_tasks + load_feature_extractor = task not in no_feature_extractor_tasks + + hub_kwargs = { + "revision": revision, + "token": token, + "trust_remote_code": trust_remote_code, + "_commit_hash": commit_hash, + } + + if isinstance(model, Path): + model = str(model) + + if torch_dtype is not None: + if "torch_dtype" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... 
torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' + " arguments might conflict, use only one.)" + ) + model_kwargs["torch_dtype"] = torch_dtype + + # Load the correct model if possible + # Infer the framework from the model if not already defined + model, model_id = MAPPING_LOADING_FUNC[accelerator](model, task, supported_tasks, model_kwargs, hub_kwargs) + + if load_tokenizer and tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs, **model_kwargs) + if load_feature_extractor and feature_extractor is None: + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **model_kwargs) + + return transformers_pipeline( + task, + model=model, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + use_fast=use_fast, + torch_dtype=torch_dtype, + ) diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index fcdf932a28..ac6306923d 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -33,6 +33,7 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} _optimum_version = importlib_metadata.version("optimum") +_optimum_intel_version = importlib_metadata.version("optimum-intel") _transformers_available = importlib.util.find_spec("transformers") is not None _transformers_version = "N/A" @@ -61,14 +62,14 @@ _neural_compressor_available = False -_intel_extension_for_transformers_available = importlib.util.find_spec("intel_extension_for_transformers") is not None -_intel_extension_for_transformers_version = "N/A" -if _intel_extension_for_transformers_available: +_itrex_available = importlib.util.find_spec("intel_extension_for_transformers") is not None +_itrex_version = "N/A" +if _itrex_available: try: - _intel_extension_for_transformers_version = importlib_metadata.version("intel_extension_for_transformers") + _itrex_version = importlib_metadata.version("intel_extension_for_transformers") logging.warn("`transformers` version >= 4.31 is requirements by intel-extension-for-transformers.") except importlib_metadata.PackageNotFoundError: - _intel_extension_for_transformers_available = False + _itrex_available = False _ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None @@ -158,8 +159,8 @@ def is_neural_compressor_available(): return _neural_compressor_available -def is_intel_extension_for_transformers_available(): - return _intel_extension_for_transformers_available +def is_itrex_available(): + return _itrex_available def is_ipex_available(): @@ -314,13 +315,13 @@ def is_neural_compressor_version(operation: str, version: str): return compare_versions(parse(_neural_compressor_version), operation, version) -def is_intel_extension_for_transformers_version(operation: str, version: str): +def is_itrex_version(operation: str, version: str): """ Compare the current intel_extension_for_transformers version to a given reference with an operation. """ - if not _intel_extension_for_transformers_available: + if not _itrex_available: return False - return compare_versions(parse(_intel_extension_for_transformers_version), operation, version) + return compare_versions(parse(_itrex_version), operation, version) def is_openvino_version(operation: str, version: str): @@ -396,7 +397,7 @@ def is_timm_version(operation: str, version: str): `pip install neural-compressor`. Please note that you may need to restart your runtime after installation. 
""" -INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR = """ +ITREX_IMPORT_ERROR = """ {0} requires the intel-extension-for-transformers library but it was not found in your environment. You can install it with pip: `pip install intel-extension-for-transformers` and `pip install peft`. Please note that you may need to restart your runtime after installation. """ @@ -418,10 +419,7 @@ def is_timm_version(operation: str, version: str): ("nncf", (is_nncf_available, NNCF_IMPORT_ERROR)), ("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)), ("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)), - ( - "intel_extension_for_transformers", - (is_intel_extension_for_transformers_available, INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR), - ), + ("itrex", (is_itrex_available, ITREX_IMPORT_ERROR)), ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), ] ) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 99ad42aafa..a2cd728354 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device: # The model had no parameters at all, doesn't matter which device to choose device = torch.device("cpu") return device + + +def recursive_to_device(value, device): + """ + Recursivley move the tensor element in `value` to `device` + """ + if isinstance(value, (tuple, list)): + return type(value)(recursive_to_device(v, device) for v in value) + elif isinstance(value, dict): + return {k: recursive_to_device(v, device) for k, v in value.items()} + elif isinstance(value, torch.Tensor): + return value.to(device) + return value diff --git a/setup.py b/setup.py index 8b94c2e8ad..fb5807b4f8 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ INSTALL_REQUIRE = [ "torch>=1.11", - "transformers>=4.36.0,<4.40.0", + "transformers>=4.36.0,<4.41.0", "optimum~=1.19", "datasets>=1.4.0", "sentencepiece", @@ -38,7 +38,7 @@ TESTS_REQUIRE = [ "accelerate", - "pytest", + "pytest<8.2", "parameterized", "Pillow", "evaluate", @@ -60,8 +60,8 @@ EXTRAS_REQUIRE = { "neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"], - "openvino": ["openvino>=2023.3", "nncf>=2.8.1", "openvino-tokenizers[transformers]"], - "nncf": ["nncf>=2.8.1"], + "openvino": ["openvino>=2023.3", "nncf>=2.10.0", "openvino-tokenizers[transformers]"], + "nncf": ["nncf>=2.10.0"], "ipex": ["intel-extension-for-pytorch", "transformers>=4.36.0,<4.39.0"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py index 22a9cac661..20381aa92b 100644 --- a/tests/generation/test_modeling.py +++ b/tests/generation/test_modeling.py @@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase): "mistral", "llama", "llama2", - # "gpt_bigcode", + "gpt_bigcode", ) GENERATION_LENGTH = 100 diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py index e120514506..1a452fe408 100644 --- a/tests/ipex/test_inference.py +++ b/tests/ipex/test_inference.py @@ -16,8 +16,6 @@ import torch from parameterized import parameterized - -# TODO : add more tasks from transformers import ( AutoModelForCausalLM, AutoModelForQuestionAnswering, @@ -26,60 +24,51 @@ AutoTokenizer, pipeline, ) +from utils_tests import MODEL_NAMES from optimum.intel import inference_mode as ipex_inference_mode from optimum.intel.ipex.modeling_base import IPEXModel -MODEL_NAMES = { - "bert": 
"hf-internal-testing/tiny-random-bert", - "bloom": "hf-internal-testing/tiny-random-BloomModel", - "distilbert": "hf-internal-testing/tiny-random-distilbert", - "roberta": "hf-internal-testing/tiny-random-roberta", - "gptj": "hf-internal-testing/tiny-random-gptj", - "gpt2": "hf-internal-testing/tiny-random-gpt2", - "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", - "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", - "opt": "hf-internal-testing/tiny-random-OPTModel", - "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", -} - _CLASSIFICATION_TASK_TO_AUTOMODELS = { "text-classification": AutoModelForSequenceClassification, "token-classification": AutoModelForTokenClassification, } -class IPEXIntegrationTest(unittest.TestCase): - CLASSIFICATION_SUPPORTED_ARCHITECTURES = ( +class IPEXClassificationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( "bert", "distilbert", "roberta", ) - TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ( - "bloom", - "gptj", - "gpt2", - "gpt_neo", - # "gpt_bigcode", - "llama", - "llama2", - "opt", - "mpt", - ) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = "This is a sample input" + for task, auto_model_class in _CLASSIFICATION_TASK_TO_AUTOMODELS.items(): + model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32) + pipe = pipeline(task, model=model, tokenizer=tokenizer) - QA_SUPPORTED_ARCHITECTURES = ( + with torch.inference_mode(): + outputs = pipe(inputs) + with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe: + outputs_ipex = ipex_pipe(inputs) + self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule)) + self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"]) + + +class IPEXQuestionAnsweringTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( "bert", "distilbert", "roberta", ) - @parameterized.expand(QA_SUPPORTED_ARCHITECTURES) - def test_question_answering_pipeline_inference(self, model_arch): + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=torch.float32) @@ -95,24 +84,22 @@ def test_question_answering_pipeline_inference(self, model_arch): self.assertEqual(outputs["start"], outputs_ipex["start"]) self.assertEqual(outputs["end"], outputs_ipex["end"]) - @parameterized.expand(CLASSIFICATION_SUPPORTED_ARCHITECTURES) - def test_classification_pipeline_inference(self, model_arch): - model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = "This is a sample input" - for task, auto_model_class in _CLASSIFICATION_TASK_TO_AUTOMODELS.items(): - model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32) - pipe = pipeline(task, model=model, tokenizer=tokenizer) - with torch.inference_mode(): - outputs = pipe(inputs) - with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe: - outputs_ipex = ipex_pipe(inputs) - self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule)) - 
self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"]) +class IPEXTextGenerationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "bloom", + "gptj", + "gpt2", + "gpt_neo", + "gpt_bigcode", + "llama", + "llama2", + "opt", + "mpt", + ) - @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES) - def test_text_generation_pipeline_inference(self, model_arch): + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False) model = model.eval() diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 94a5ca9e16..2a2f18f6f8 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -45,53 +45,11 @@ ) from optimum.intel.utils.import_utils import is_ipex_version from optimum.utils.testing_utils import grid_parameters +from utils_tests import MODEL_NAMES SEED = 42 -MODEL_NAMES = { - "albert": "hf-internal-testing/tiny-random-albert", - "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", - "bert": "hf-internal-testing/tiny-random-bert", - "bart": "hf-internal-testing/tiny-random-bart", - "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", - "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", - "bloom": "hf-internal-testing/tiny-random-BloomModel", - "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", - "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", - "convnext": "hf-internal-testing/tiny-random-convnext", - "distilbert": "hf-internal-testing/tiny-random-distilbert", - "electra": "hf-internal-testing/tiny-random-electra", - "flaubert": "hf-internal-testing/tiny-random-flaubert", - "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", - "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", - "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gptj": "hf-internal-testing/tiny-random-GPTJModel", - "levit": "hf-internal-testing/tiny-random-LevitModel", - "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", - "marian": "sshleifer/tiny-marian-en-de", - "mbart": "hf-internal-testing/tiny-random-mbart", - "mistral": "echarlaix/tiny-random-mistral", - "mobilenet_v1": "google/mobilenet_v1_0.75_192", - "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", - "mobilevit": "hf-internal-testing/tiny-random-mobilevit", - "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", - "mt5": "stas/mt5-tiny-random", - "opt": "hf-internal-testing/tiny-random-OPTModel", - "phi": "echarlaix/tiny-random-PhiForCausalLM", - "resnet": "hf-internal-testing/tiny-random-resnet", - "roberta": "hf-internal-testing/tiny-random-roberta", - "roformer": "hf-internal-testing/tiny-random-roformer", - "squeezebert": "hf-internal-testing/tiny-random-squeezebert", - "t5": "hf-internal-testing/tiny-random-t5", - "unispeech": "hf-internal-testing/tiny-random-unispeech", - "vit": "hf-internal-testing/tiny-random-vit", - "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", - "xlm": "hf-internal-testing/tiny-random-xlm", -} - class Timer(object): def __enter__(self): diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py new file mode 100644 index 0000000000..c4ae471a0f --- /dev/null +++ b/tests/ipex/test_pipelines.py @@ -0,0 +1,222 @@ +# Copyright 2024 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from tempfile import TemporaryDirectory + +import numpy as np +import torch +from parameterized import parameterized +from transformers import AutoTokenizer +from transformers.pipelines import pipeline as transformers_pipeline +from utils_tests import MODEL_NAMES + +from optimum.intel.ipex.modeling_base import ( + IPEXModelForAudioClassification, + IPEXModelForCausalLM, + IPEXModelForImageClassification, + IPEXModelForMaskedLM, + IPEXModelForQuestionAnswering, + IPEXModelForSequenceClassification, + IPEXModelForTokenClassification, +) +from optimum.intel.pipelines import pipeline as ipex_pipeline + + +class PipelinesIntegrationTest(unittest.TestCase): + COMMON_SUPPORTED_ARCHITECTURES = ( + "albert", + "bert", + "distilbert", + "electra", + "flaubert", + "roberta", + "roformer", + "squeezebert", + "xlm", + ) + TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ( + "bart", + "gpt_bigcode", + "blenderbot", + "blenderbot-small", + "bloom", + "codegen", + "gpt2", + "gpt_neo", + "gpt_neox", + "llama", + "llama2", + "mistral", + "mpt", + "opt", + ) + QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = ( + "bert", + "distilbert", + "roberta", + ) + AUDIO_CLASSIFICATION_SUPPORTED_ARCHITECTURES = ( + "unispeech", + "wav2vec2", + ) + IMAGE_CLASSIFICATION_SUPPORTED_ARCHITECTURES = ( + "beit", + "mobilenet_v1", + "mobilenet_v2", + "mobilevit", + "resnet", + "vit", + ) + + @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_token_classification_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + transformers_generator = transformers_pipeline("token-classification", model_id) + ipex_generator = ipex_pipeline("token-classification", model_id, accelerator="ipex") + inputs = "Hello I'm Omar and I live in Zürich." 
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertEqual(len(transformers_output), len(ipex_output))
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForTokenClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        for i in range(len(transformers_output)):
+            self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4)
+
+    @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES)
+    def test_sequence_classification_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("text-classification", model_id)
+        ipex_generator = ipex_pipeline("text-classification", model_id, accelerator="ipex")
+        inputs = "This restaurant is awesome"
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertEqual(transformers_output[0]["label"], ipex_output[0]["label"])
+        self.assertAlmostEqual(transformers_output[0]["score"], ipex_output[0]["score"], delta=1e-4)
+
+    @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES)
+    def test_fill_mask_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        inputs = "The Milky Way is a <mask> galaxy."
+        transformers_generator = transformers_pipeline("fill-mask", model_id)
+        ipex_generator = ipex_pipeline("fill-mask", model_id, accelerator="ipex")
+        mask_token = transformers_generator.tokenizer.mask_token
+        inputs = inputs.replace("<mask>", mask_token)
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertEqual(len(transformers_output), len(ipex_output))
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForMaskedLM))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        for i in range(len(transformers_output)):
+            self.assertEqual(transformers_output[i]["token"], ipex_output[i]["token"])
+            self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4)
+
+    @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
+    def test_text_generation_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("text-generation", model_id)
+        ipex_generator = ipex_pipeline("text-generation", model_id, accelerator="ipex")
+        inputs = "Describe a real-world application of AI."
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
+
+    @parameterized.expand(QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES)
+    def test_question_answering_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("question-answering", model_id)
+        ipex_generator = ipex_pipeline("question-answering", model_id, accelerator="ipex")
+        question = "How many programming languages does BLOOM support?"
+        context = "BLOOM has 176 billion parameters and can generate text in 46 natural languages and 13 programming languages."
+        with torch.inference_mode():
+            transformers_output = transformers_generator(question=question, context=context)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(question=question, context=context)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForQuestionAnswering))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertAlmostEqual(transformers_output["score"], ipex_output["score"], delta=1e-4)
+        self.assertEqual(transformers_output["start"], ipex_output["start"])
+        self.assertEqual(transformers_output["end"], ipex_output["end"])
+
+    @parameterized.expand(AUDIO_CLASSIFICATION_SUPPORTED_ARCHITECTURES)
+    def test_audio_classification_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("audio-classification", model_id)
+        ipex_generator = ipex_pipeline("audio-classification", model_id, accelerator="ipex")
+        inputs = [np.random.random(16000)]
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForAudioClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertAlmostEqual(transformers_output[0][0]["score"], ipex_output[0][0]["score"], delta=1e-2)
+        self.assertAlmostEqual(transformers_output[0][1]["score"], ipex_output[0][1]["score"], delta=1e-2)
+
+    @parameterized.expand(IMAGE_CLASSIFICATION_SUPPORTED_ARCHITECTURES)
+    def test_image_classification_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("image-classification", model_id)
+        ipex_generator = ipex_pipeline("image-classification", model_id, accelerator="ipex")
+        inputs = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertEqual(len(transformers_output), len(ipex_output))
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForImageClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        for i in range(len(transformers_output)):
+            self.assertEqual(transformers_output[i]["label"], ipex_output[i]["label"])
+            self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4)
+
+
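+    # The two tests below cover loading paths other than a model id string:
+    # passing an already exported IPEXModel instance directly, and passing the
+    # path of a directory containing a saved (TorchScript) model.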
@parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_pipeline_load_from_ipex_model(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + ipex_generator = ipex_pipeline("text-classification", model, tokenizer=tokenizer, accelerator="ipex") + inputs = "This restaurant is awesome" + with torch.inference_mode(): + ipex_output = ipex_generator(inputs) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) + self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) + self.assertGreaterEqual(ipex_output[0]["score"], 0.0) + + @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_pipeline_load_from_jit_model(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + save_dir = TemporaryDirectory().name + model.save_pretrained(save_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + ipex_generator = ipex_pipeline("text-classification", save_dir, tokenizer=tokenizer, accelerator="ipex") + inputs = "This restaurant is awesome" + with torch.inference_mode(): + ipex_output = ipex_generator(inputs) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) + self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) + self.assertGreaterEqual(ipex_output[0]["score"], 0.0) diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py new file mode 100644 index 0000000000..a14f0bf7ca --- /dev/null +++ b/tests/ipex/utils_tests.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
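+
+# Tiny random checkpoints from the Hugging Face Hub used as fixtures by the IPEX pipeline tests.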
+ + +MODEL_NAMES = { + "albert": "hf-internal-testing/tiny-random-albert", + "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", + "bert": "hf-internal-testing/tiny-random-bert", + "bart": "hf-internal-testing/tiny-random-bart", + "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", + "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", + "bloom": "hf-internal-testing/tiny-random-BloomModel", + "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", + "convnext": "hf-internal-testing/tiny-random-convnext", + "distilbert": "hf-internal-testing/tiny-random-distilbert", + "electra": "hf-internal-testing/tiny-random-electra", + "flaubert": "hf-internal-testing/tiny-random-flaubert", + "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", + "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", + "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "levit": "hf-internal-testing/tiny-random-LevitModel", + "llama": "fxmarty/tiny-llama-fast-tokenizer", + "llama2": "Jiqing/tiny_random_llama2", + "marian": "sshleifer/tiny-marian-en-de", + "mbart": "hf-internal-testing/tiny-random-mbart", + "mistral": "echarlaix/tiny-random-mistral", + "mobilenet_v1": "google/mobilenet_v1_0.75_192", + "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", + "mobilevit": "hf-internal-testing/tiny-random-mobilevit", + "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", + "mt5": "stas/mt5-tiny-random", + "opt": "hf-internal-testing/tiny-random-OPTModel", + "phi": "echarlaix/tiny-random-PhiForCausalLM", + "resnet": "hf-internal-testing/tiny-random-resnet", + "roberta": "hf-internal-testing/tiny-random-roberta", + "roformer": "hf-internal-testing/tiny-random-roformer", + "squeezebert": "hf-internal-testing/tiny-random-squeezebert", + "t5": "hf-internal-testing/tiny-random-t5", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "vit": "hf-internal-testing/tiny-random-vit", + "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", + "xlm": "hf-internal-testing/tiny-random-xlm", +} diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py index e6ce4763f2..0c3e60969b 100644 --- a/tests/neural_compressor/test_modeling.py +++ b/tests/neural_compressor/test_modeling.py @@ -16,10 +16,12 @@ import os import tempfile import unittest +from pathlib import Path import torch from parameterized import parameterized from transformers import AutoTokenizer, pipeline, set_seed +from transformers.utils import SAFE_WEIGHTS_NAME from optimum.exporters import TasksManager from optimum.intel import ( # noqa @@ -37,7 +39,8 @@ INCStableDiffusionPipeline, INCTrainer, ) -from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME +from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, QUANTIZATION_CONFIG_NAME, WEIGHTS_NAME +from optimum.intel.utils.import_utils import is_itrex_available os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -52,7 +55,7 @@ MODEL_NAMES_TO_TASK = ( - ("hf-internal-testing/tiny-random-gpt2", "text-generation"), + ("hf-internal-testing/tiny-random-GPT2LMHeadModel", "text-generation"), ("hf-internal-testing/tiny-random-BertForMaskedLM", "fill-mask"), ("hf-internal-testing/tiny-random-DistilBertForSequenceClassification", "text-classification"), 
("hf-internal-testing/tiny-random-DebertaV2Model", "feature-extraction"), @@ -86,7 +89,7 @@ def test_compare_to_transformers(self, model_id, task): outputs = inc_model(**model_inputs) with tempfile.TemporaryDirectory() as tmpdirname: inc_model.save_pretrained(tmpdirname) - loaded_model = model_class.from_pretrained(tmpdirname, file_name=WEIGHTS_NAME) + loaded_model = model_class.from_pretrained(tmpdirname) outputs_loaded = loaded_model(**model_inputs) if task == "feature-extraction": @@ -143,3 +146,57 @@ def test_compare_with_and_without_past_key_values(self): self.assertEqual(outputs_with_pkv.shape[1], self.GENERATION_LENGTH) self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH) self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv)) + + @unittest.skipIf(not is_itrex_available(), reason="ITREX not available") + def test_saving_loading_woq_itrex_model(self): + model_name = "echarlaix/tiny-random-PhiForCausalLM" + subfolder = "itrex" + model = INCModelForCausalLM.from_pretrained(model_name, revision="itrex", subfolder=subfolder) + tokenizer = AutoTokenizer.from_pretrained(model_name, revision="itrex") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokens = tokenizer("This is a sample output", return_tensors="pt") + + with tempfile.TemporaryDirectory() as tmp_dir: + model_save_dir = Path(tmp_dir) / subfolder + model.save_pretrained(model_save_dir) + folder_contents = os.listdir(model_save_dir) + self.assertIn(SAFE_WEIGHTS_NAME, folder_contents) + self.assertIn(QUANTIZATION_CONFIG_NAME, folder_contents) + loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder) + + with torch.no_grad(): + outputs = model(**tokens) + loaded_outputs = loaded_model(**tokens) + + self.assertTrue("logits" in loaded_outputs) + self.assertIsInstance(loaded_outputs.logits, torch.Tensor) + self.assertTrue("past_key_values" in loaded_outputs) + self.assertIsInstance(loaded_outputs.past_key_values, tuple) + self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5)) + + def test_saving_loading_inc_model(self): + model_name = "echarlaix/tiny-random-PhiForCausalLM" + subfolder = "inc" + model = INCModelForCausalLM.from_pretrained(model_name, revision="inc", subfolder=subfolder) + tokenizer = AutoTokenizer.from_pretrained(model_name, revision="inc") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokens = tokenizer("This is a sample output", return_tensors="pt") + + with tempfile.TemporaryDirectory() as tmp_dir: + model_save_dir = Path(tmp_dir) / subfolder + model.save_pretrained(model_save_dir) + folder_contents = os.listdir(model_save_dir) + self.assertIn(WEIGHTS_NAME, folder_contents) + self.assertIn("inc_config.json", folder_contents) + loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder) + self.assertIsInstance(loaded_model.inc_config, INCConfig) + + with torch.no_grad(): + outputs = model(**tokens) + loaded_outputs = loaded_model(**tokens) + + self.assertTrue("logits" in loaded_outputs) + self.assertIsInstance(loaded_outputs.logits, torch.Tensor) + self.assertTrue("past_key_values" in loaded_outputs) + self.assertIsInstance(loaded_outputs.past_key_values, tuple) + self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5)) diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index e38ba8e327..56f2a5bac3 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -45,8 +45,7 
@@ set_seed, ) from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset -from optimum.intel.utils.import_utils import is_torch_version, is_intel_extension_for_transformers_available - +from optimum.intel.utils.import_utils import is_torch_version, is_itrex_available from optimum.intel import ( INCConfig, @@ -70,12 +69,13 @@ class QuantizationTest(INCTestMixin): - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + SUPPORTED_ARCHITECTURES_STATIC = ( + ("text-generation", "gpt_neo", 17), ("text-classification", "bert", 21), - # ("text-generation", "bloom", 21), + ("text-generation", "bloom", 21), ) - SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + ( + SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_STATIC + ( ("fill-mask", "bert", 22), ("token-classification", "albert", 26), ) @@ -88,12 +88,14 @@ class QuantizationTest(INCTestMixin): @parameterized.expand(SUPPORTED_ARCHITECTURES_DYNAMIC) def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls): model_name = MODEL_NAMES[model_arch] - quantization_config = PostTrainingQuantConfig(approach="dynamic") model_class = ORT_SUPPORTED_TASKS[task]["class"][0] tokenizer = AutoTokenizer.from_pretrained(model_name) - save_onnx_model = False + quantized_model = None + save_onnx_model = False model_kwargs = {"use_cache": False, "use_io_binding": False} if task == "text-generation" else {} + quantization_config = PostTrainingQuantConfig(approach="dynamic") + with tempfile.TemporaryDirectory() as tmp_dir: for backend in ["torch", "ort"]: if backend == "torch": @@ -104,8 +106,8 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls quantizer = INCQuantizer.from_pretrained(model, task=task) quantizer.quantize( quantization_config=quantization_config, - save_directory=tmp_dir, save_onnx_model=save_onnx_model, + save_directory=tmp_dir, ) if backend == "torch": quantized_model = quantizer._quantized_model @@ -121,7 +123,7 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls load_inc_model=True, ) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_STATIC) def test_static_quantization(self, task, model_arch, expected_quantized_matmuls): num_samples = 10 model_name = MODEL_NAMES[model_arch] @@ -130,28 +132,26 @@ def test_static_quantization(self, task, model_arch, expected_quantized_matmuls) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - save_onnx_model = False - op_type_dict = ( - {"Embedding": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}} - if save_onnx_model - else None - ) - quantization_config = PostTrainingQuantConfig(approach="static", op_type_dict=op_type_dict) quantized_model = None + save_onnx_model = False + quantization_config = PostTrainingQuantConfig(approach="static") + model_kwargs = {"use_cache": False, "use_io_binding": False} if task == "text-generation" else {} with tempfile.TemporaryDirectory() as tmp_dir: for backend in ["torch", "ort"]: if backend == "torch": model = model_class.auto_model_class.from_pretrained(model_name) else: - model = model_class.from_pretrained(model_name, export=True) + model = model_class.from_pretrained(model_name, export=True, **model_kwargs) + quantizer = INCQuantizer.from_pretrained(model, task=task) calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=num_samples) + quantizer.quantize( 
quantization_config=quantization_config, calibration_dataset=calibration_dataset, - save_directory=tmp_dir, save_onnx_model=save_onnx_model, + save_directory=tmp_dir, ) if backend == "torch": quantized_model = quantizer._quantized_model @@ -511,7 +511,7 @@ class WeightOnlyQuantizationTest(INCTestMixin): ) @parameterized.expand(WEIGHT_ONLY_CONFIG) - @unittest.skipIf(not is_intel_extension_for_transformers_available(), reason="ITREX not available") + @unittest.skipIf(not is_itrex_available(), reason="ITREX not available") def test_weight_only_quantization(self, methodology, weight_dtype): model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" diff --git a/tests/neural_compressor/utils_tests.py b/tests/neural_compressor/utils_tests.py index c91270355a..2106237589 100644 --- a/tests/neural_compressor/utils_tests.py +++ b/tests/neural_compressor/utils_tests.py @@ -47,6 +47,7 @@ from optimum.intel.utils.constant import ONNX_WEIGHTS_NAME from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification from optimum.pipelines import ORT_SUPPORTED_TASKS +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS if is_ipex_available(): from optimum.intel import ( @@ -80,7 +81,7 @@ "electra": "hf-internal-testing/tiny-random-electra", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", @@ -135,6 +136,13 @@ def _generate_dataset(quantizer, tokenizer, num_samples=10): num_samples=num_samples, dataset_split="train", ) + model_type = quantizer._original_model.config.model_type.replace("_", "-") + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + dataset = dataset.map( + lambda x: { + "position_ids": np.arange(len(x["input_ids"])), + } + ) return dataset @@ -187,6 +195,9 @@ def check_model_outputs( self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) ort_model = ORT_SUPPORTED_TASKS[task]["class"][0].from_pretrained(save_directory, **model_kwargs) + model_type = ort_model.config.model_type.replace("_", "-") + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + tokens["position_ids"] = torch.arange(len(tokens["input_ids"])).unsqueeze(0) ort_outputs = ort_model(**tokens) self.assertTrue("logits" in ort_outputs) # self.assertTrue(torch.allclose(ort_outputs.logits, outputs, atol=1e-2)) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 95e31b481b..8f61d9a36d 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -47,6 +47,7 @@ OVStableDiffusionXLPipeline, ) from optimum.intel.openvino.modeling_base import OVBaseModel +from optimum.intel.utils.import_utils import _transformers_version from optimum.utils.save_utils import maybe_load_preprocessors @@ -115,6 +116,9 @@ def _openvino_export( if task == "text-generation": self.assertEqual(ov_model.stateful, stateful and use_cache) + self.assertEqual( + ov_model.model.get_rt_info()["optimum"]["transformers_version"], _transformers_version + ) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_export(self, model_type: str): diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 09fad5d773..cce25bbae1 100644 --- a/tests/openvino/test_exporters_cli.py +++ 
b/tests/openvino/test_exporters_cli.py @@ -18,7 +18,6 @@ from parameterized import parameterized from utils_tests import ( - _ARCHITECTURES_TO_EXPECTED_INT4_INT8, _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes, @@ -74,8 +73,8 @@ class OVCLIExportTestCase(unittest.TestCase): "wav2vec2": 0, # no tokenizer "bert": 1, # no detokenizer "blenderbot": 2, - "stable-diffusion": 0, # not supported - "stable-diffusion-xl": 0, # not supported + "stable-diffusion": 2, + "stable-diffusion-xl": 4, } SUPPORTED_SD_HYBRID_ARCHITECTURES = ( @@ -84,14 +83,13 @@ class OVCLIExportTestCase(unittest.TestCase): ("latent-consistency", 50, 135), ) - SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),) - - SUPPORTED_4BIT_OPTIONS = ["int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"] - - TEST_4BIT_CONFIGURATONS = [] - for arch in SUPPORTED_4BIT_ARCHITECTURES: - for option in SUPPORTED_4BIT_OPTIONS: - TEST_4BIT_CONFIGURATONS.append([arch[0], arch[1], option]) + TEST_4BIT_CONFIGURATONS = [ + ("text-generation-with-past", "opt125m", "int4_sym_g128", 62, 86), + ("text-generation-with-past", "opt125m", "int4_asym_g128", 62, 86), + ("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86), + ("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86), + ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32), + ] def _openvino_export( self, model_name: str, task: str, compression_option: str = None, compression_ratio: float = None @@ -197,17 +195,16 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in self.assertEqual(exp_num_fq, num_fq) @parameterized.expand(TEST_4BIT_CONFIGURATONS) - def test_exporters_cli_int4(self, task: str, model_type: str, option: str): + def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int): with TemporaryDirectory() as tmpdir: subprocess.run( - f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", shell=True, check=True, ) model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) - expected_int8, expected_int4 = _ARCHITECTURES_TO_EXPECTED_INT4_INT8[model_type] _, num_int8, num_int4 = get_num_quantized_nodes(model) self.assertEqual(expected_int8, num_int8) self.assertEqual(expected_int4, num_int4) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f84cac8161..cb5ac52ed7 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -14,6 +14,7 @@ import gc import os +import subprocess import tempfile import time import unittest @@ -247,6 +248,15 @@ def test_load_from_hub_and_save_stable_diffusion_model(self): del pipeline gc.collect() + def test_load_model_from_hub_private_with_token(self): + subprocess.run("huggingface-cli logout", shell=True) + + # a fine-grained read-only token of private repo "IlyasMoutawwakil/test-hub-bert" + token = "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" + + loaded_model = OVModelForMaskedLM.from_pretrained("IlyasMoutawwakil/test-hub-bert", use_auth_token=token) + self.assertIsInstance(loaded_model.config, PretrainedConfig) + class OVModelForSequenceClassificationIntegrationTest(unittest.TestCase): 
SUPPORTED_ARCHITECTURES = ( @@ -510,12 +520,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", "baichuan2", + "baichuan2-13b", "gpt_bigcode", "blenderbot", "blenderbot-small", "bloom", "chatglm", "codegen", + "codegen2", # "data2vec-text", # TODO : enable when enabled in exporters "gemma", "gpt2", @@ -528,6 +540,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "mistral", "mixtral", "mpt", + "olmo", "opt", "pegasus", "qwen", @@ -535,22 +548,50 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "stablelm", "starcoder2", "phi", + "phi3", "internlm2", "orion", "falcon", + "falcon-40b", + "persimmon", + "biogpt", + "gpt_neox_japanese", + "cohere", + "xglm", + "aquila", + "aquila2", + "xverse", + "internlm", + "dbrx", + "qwen2-moe", ) GENERATION_LENGTH = 100 - REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion") + REMOTE_CODE_MODELS = ( + "chatglm", + "minicpm", + "baichuan2", + "baichuan2-13b", + "jais", + "qwen", + "internlm2", + "orion", + "phi3", + "aquila", + "aquila2", + "xverse", + "internlm", + "codegen2", + ) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] - not_stateful = ["gpt_bigcode"] + not_stateful = [] if is_openvino_version("<", "2024.0"): not_stateful.append("mixtral") if is_openvino_version("<", "2024.1"): - not_stateful.extend(["llama", "gemma"]) + not_stateful.extend(["llama", "gemma", "gpt_bigcode"]) if "gptq" in model_arch: self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM") @@ -567,6 +608,7 @@ def test_compare_to_transformers(self, model_arch): self.assertEqual(ov_model.stateful, ov_model.config.model_type not in not_stateful) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) tokens = tokenizer("This is a sample output", return_tensors="pt") + tokens.pop("token_type_ids", None) ov_outputs = ov_model(**tokens) self.assertTrue("logits" in ov_outputs) @@ -593,11 +635,15 @@ def test_compare_to_transformers(self, model_arch): if model_arch == "qwen": return - if model_arch != "chatglm": + if model_arch not in ["chatglm", "persimmon"]: tokenizer.pad_token_id = tokenizer.eos_token_id + + if model_arch == "persimmon": + tokenizer.pad_token_id = tokenizer.bos_token_id # Compare batched generation tokenizer.padding_side = "left" tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) ov_model.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None ov_model.config.eos_token_id = None @@ -754,6 +800,94 @@ def test_default_filling_attention_mask_and_position_ids(self): del model_with_cache gc.collect() + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @pytest.mark.run_slow + @slow + def test_beam_search(self, model_arch): + model_kwargs = {} + model_id = MODEL_NAMES[model_arch] + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search + if model_arch in ["qwen", "chatglm"]: + return + + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + if model_arch == "persimmon": + 
tokenizer.pad_token_id = tokenizer.bos_token_id + tokenizer.eos_token_id = tokenizer.bos_token_id + + beam_search_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + ) + beam_sample_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=True, + eos_token_id=None, + top_k=1, + ) + + group_beam_search_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + num_beam_groups=2, + diversity_penalty=0.0000001, + ) + force_word = "cat" + force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids] + constrained_beam_search_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + force_words_ids=force_words_ids, + ) + + gen_configs = [ + beam_search_gen_config, + beam_sample_gen_config, + group_beam_search_gen_config, + constrained_beam_search_gen_config, + ] + ov_model_stateful = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=True, **model_kwargs + ) + ov_model_stateless = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=False, **model_kwargs + ) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + tokenizer.pad_token_id = tokenizer.eos_token_id + tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) + ov_model_stateful.generation_config.eos_token_id = None + ov_model_stateless.generation_config.eos_token_id = None + transformers_model.generation_config.eos_token_id = None + ov_model_stateful.config.eos_token_id = None + ov_model_stateless.config.eos_token_id = None + transformers_model.config.eos_token_id = None + + for gen_config in gen_configs: + if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]: + continue + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + class OVModelForMaskedLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 26dfc658a5..09b395ea12 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -21,21 +21,17 @@ from collections import defaultdict from enum import Enum from functools import partial -from typing import List, Union +from typing import Union import evaluate import numpy as np import torch from datasets import load_dataset -from nncf.quantization.advanced_parameters import OverflowFix from parameterized import parameterized -import openvino.runtime as ov import nncf from transformers import ( AutoModelForQuestionAnswering, AutoModelForSequenceClassification, - AutoModelForCausalLM, - AutoModelForTokenClassification, AutoTokenizer, AutoProcessor, TrainingArguments, @@ -77,12 +73,16 @@ class OVQuantizerTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + SUPPORTED_ARCHITECTURES_TORCH_MODEL = ( 
(OVModelForSequenceClassification, "bert", 32, 35), - # (OVModelForCausalLM, "gpt2", 41, 23), + (OVModelForCausalLM, "gpt2", 41, 3), + ) + SUPPORTED_ARCHITECTURES_OV_MODEL = ( + (OVModelForSequenceClassification, "bert", 32, 35), + (OVModelForCausalLM, "gpt2", 31, 22), ) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL) def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): model_id = MODEL_NAMES[model_name] task = model_cls.export_feature @@ -127,23 +127,21 @@ def preprocess_function(examples, tokenizer): loaded_config = OVConfig.from_pretrained(tmp_dir) self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict()) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL) def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): model_id = MODEL_NAMES[model_name] task = model_cls.export_feature dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task] - if "gpt2" in model_id: - expected_int8 -= 1 def preprocess_function(examples, tokenizer): return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True) with tempfile.TemporaryDirectory() as tmp_dir: - transformers_model = model_cls.from_pretrained(model_id, export=True) + ov_model = model_cls.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + quantizer = OVQuantizer.from_pretrained(ov_model, task=task) calibration_dataset = quantizer.get_calibration_dataset( dataset_name, @@ -221,17 +219,17 @@ class OVWeightCompressionTest(unittest.TestCase): ), ( OVModelForCausalLM, - "opt", + "llama_awq", dict( bits=4, sym=True, - group_size=-1, + group_size=16, ratio=0.8, sensitivity_metric="mean_activation_magnitude", dataset="ptb", quant_method=QuantizationMethod.AWQ, ), - 14, + 16, ), ) @@ -413,8 +411,12 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset( model = model_cls.from_pretrained( model_id, export=True, - quantization_config=OVWeightQuantizationConfig(bits=8, dataset=dataset, num_samples=3), ) + quantizer = OVQuantizer(model) + quantization_config = OVWeightQuantizationConfig( + bits=8, num_samples=3, quant_method=OVQuantizationMethod.HYBRID + ) + quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset) num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet) self.assertEqual(expected_num_fake_quantize, num_fake_quantize) self.assertEqual(expected_ov_int8, num_int8) @@ -452,6 +454,10 @@ def test_ovmodel_4bit_auto_compression_with_config( with tempfile.TemporaryDirectory() as tmp_dir: quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config) model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config) + if quantization_config.quant_method == QuantizationMethod.AWQ: + # TODO: Check that AWQ was actually applied + pass + tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -548,6 +554,8 @@ def test_ovmodel_load_large_model_with_additional_quantization_config(self): "sensitivity_metric": None, 
"dataset": None, "ignored_scope": nncf.IgnoredScope(), + "awq": None, + "subset_size": 128, } compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params) @@ -657,7 +665,7 @@ def preprocess_function(examples, tokenizer): class OVTrainerTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("distilbert-base-uncased", 50, 38),) + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("distilbert-base-uncased", 49, 38),) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8): @@ -709,7 +717,7 @@ class OVQuantizationConfigTest(unittest.TestCase): (OVWeightQuantizationConfig(bits=8, sym=True),), ( OVWeightQuantizationConfig( - dataset="wikitext", + dataset="wikitext2", bits=4, ignored_scope={"names": ["op_name"]}, sym=False, @@ -741,7 +749,7 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(bits=8, sym=True), OVWeightQuantizationConfig, None), ( dict( - dataset="wikitext", + dataset="wikitext2", bits=4, ignored_scope={"names": ["op_name"]}, sym=False, @@ -765,7 +773,7 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(num_samples=100), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), (dict(abc="def"), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), ( - dict(bits=8, fast_bias_correction=True, dataset="wikitext"), + dict(bits=8, fast_bias_correction=True, dataset="wikitext2"), OVWeightQuantizationConfig, "Can't determine type of OV quantization config", ), @@ -787,7 +795,7 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(abc="def", weight_only=False), OVQuantizationConfig, None), (dict(abc="def", weight_only=True), OVWeightQuantizationConfig, None), ( - dict(bits=8, fast_bias_correction=True, dataset="wikitext", weight_only=True), + dict(bits=8, fast_bias_correction=True, dataset="wikitext2", weight_only=True), OVWeightQuantizationConfig, None, ), diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index db443c6de2..c998d00d8b 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -322,7 +322,7 @@ def tearDown(self): "default_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=DEFAULT_QUANTIZATION_CONFIG, - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, compression_metrics=["compression_loss"], ), @@ -330,14 +330,14 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=DEFAULT_QUANTIZATION_CONFIG, - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), "customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, compression_metrics=["compression_loss"], ), @@ -345,7 +345,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), @@ -418,7 +418,7 @@ 
def tearDown(self): "default_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -429,7 +429,7 @@ def tearDown(self): CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -438,7 +438,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -450,7 +450,7 @@ def tearDown(self): CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=44, + expected_fake_quantize=34, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -730,7 +730,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): "quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2], - expected_fake_quantize=48, + expected_fake_quantize=40, expected_int8=30, compression_metrics=["compression_loss"], ), @@ -757,7 +757,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): "quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], - expected_fake_quantize=48, + expected_fake_quantize=40, expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss"], @@ -775,7 +775,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", teacher_model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], - expected_fake_quantize=48, + expected_fake_quantize=40, expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index ca56f6d552..91500cfc63 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -18,23 +18,30 @@ MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", + "aquila": "katuni4ka/tiny-random-aquilachat", + "aquila2": "katuni4ka/tiny-random-aquila2", "audio_spectrogram_transformer": "Ericwang/tiny-random-ast", "bge": "BAAI/bge-small-en-v1.5", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", "bert": "hf-internal-testing/tiny-random-bert", "bart": "hf-internal-testing/tiny-random-bart", "baichuan2": "katuni4ka/tiny-random-baichuan2", + "baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", + "biogpt": 
"hf-tiny-model-private/tiny-random-BioGptForCausalLM", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "cohere": "hf-internal-testing/tiny-random-CohereForCausalLM", "chatglm": "katuni4ka/tiny-random-chatglm2", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", + "codegen2": "katuni4ka/tiny-random-codegen2", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel", + "dbrx": "katuni4ka/tiny-random-dbrx", "deberta": "hf-internal-testing/tiny-random-deberta", "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model", "deit": "hf-internal-testing/tiny-random-deit", @@ -44,18 +51,22 @@ "electra": "hf-internal-testing/tiny-random-electra", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "falcon": "fxmarty/really-tiny-falcon-testing", + "falcon-40b": "katuni4ka/tiny-random-falcon-40b", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", + "gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-ibert", + "internlm": "katuni4ka/tiny-random-internlm", "internlm2": "katuni4ka/tiny-random-internlm2", "levit": "hf-internal-testing/tiny-random-LevitModel", "longt5": "hf-internal-testing/tiny-random-longt5", "llama": "fxmarty/tiny-llama-fast-tokenizer", + "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM", "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ", "m2m_100": "hf-internal-testing/tiny-random-m2m_100", "opt": "hf-internal-testing/tiny-random-OPTModel", @@ -72,14 +83,17 @@ "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "stas/mt5-tiny-random", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", - "olmo": "katuni4ka/tiny-random-olmo", + "olmo": "katuni4ka/tiny-random-olmo-hf", "orion": "katuni4ka/tiny-random-orion", "pegasus": "hf-internal-testing/tiny-random-pegasus", + "persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM", "pix2struct": "fxmarty/pix2struct-tiny-random", "phi": "echarlaix/tiny-random-PhiForCausalLM", + "phi3": "katuni4ka/tiny-random-phi3", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", "qwen2": "Qwen/Qwen1.5-0.5B", + "qwen2-moe": "katuni4ka/tiny-random-qwen1.5-moe", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", @@ -111,6 +125,8 @@ "whisper": "openai/whisper-tiny.en", "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", + "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM", + "xverse": "katuni4ka/tiny-random-xverse", } @@ -136,8 +152,6 @@ "stable-diffusion-xl-refiner": (366, 34, 42, 66), } 
-_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (62, 86)} - def get_num_quantized_nodes(ov_model): num_fake_quantize = 0