Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Nov 3, 2023
1 parent 598cba4 commit 4273932
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions optimum/intel/neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
import inspect
import logging
import warnings
from enum import Enum
from itertools import chain
from pathlib import Path
Expand All @@ -30,16 +31,25 @@
from neural_compressor.quantization import fit
from torch.utils.data import DataLoader, RandomSampler
from transformers import (
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
DataCollator,
PretrainedConfig,
PreTrainedModel,
XLNetLMHeadModel,
default_data_collator,
)

from optimum.exporters import TasksManager
from optimum.exporters.onnx import OnnxConfig
from optimum.onnxruntime import ORTModel
from optimum.onnxruntime.modeling_decoder import ORTModelDecoder
from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM
from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration
from optimum.onnxruntime.utils import ONNX_DECODER_NAME
from optimum.quantization_base import OptimumQuantizer
Expand Down Expand Up @@ -256,7 +266,7 @@ def quantize(
if isinstance(self._original_model, ORTModelForConditionalGeneration):
raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")

if isinstance(self._original_model, ORTModelDecoder):
if isinstance(self._original_model, ORTModelForCausalLM):
model_or_path = self._original_model.onnx_paths
if len(model_or_path) > 1:
raise RuntimeError(
Expand Down Expand Up @@ -528,3 +538,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> t
q_model = convert(q_model, mapping=q_mapping, inplace=True)

return q_model


class IncQuantizedModel(INCModel):
@classmethod
def from_pretrained(cls, *args, **kwargs):
warnings.warn(
f"The class `{cls.__name__}` has been depreciated and will be removed in optimum-intel v1.12, please use "
f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead."
)
return super().from_pretrained(*args, **kwargs)


class IncQuantizedModelForQuestionAnswering(IncQuantizedModel):
auto_model_class = AutoModelForQuestionAnswering


class IncQuantizedModelForSequenceClassification(IncQuantizedModel):
auto_model_class = AutoModelForSequenceClassification


class IncQuantizedModelForTokenClassification(IncQuantizedModel):
auto_model_class = AutoModelForTokenClassification


class IncQuantizedModelForMultipleChoice(IncQuantizedModel):
auto_model_class = AutoModelForMultipleChoice


class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel):
auto_model_class = AutoModelForSeq2SeqLM


class IncQuantizedModelForCausalLM(IncQuantizedModel):
auto_model_class = AutoModelForCausalLM


class IncQuantizedModelForMaskedLM(IncQuantizedModel):
auto_model_class = AutoModelForMaskedLM


class IncQuantizedModelForXLNetLM(IncQuantizedModel):
auto_model_class = XLNetLMHeadModel


class IncQuantizedModelForVision2Seq(IncQuantizedModel):
auto_model_class = AutoModelForVision2Seq

0 comments on commit 4273932

Please sign in to comment.