Skip to content

Commit

Permalink
Remove useless transformers version check
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 8, 2024
1 parent 1c14957 commit db66c13
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 37 deletions.
17 changes: 4 additions & 13 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,18 @@
from huggingface_hub import hf_hub_download
from openvino import Core, convert_model
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
from transformers import PretrainedConfig
from transformers import GenerationConfig, PretrainedConfig
from transformers.file_utils import add_start_docstrings
from transformers.generation import GenerationMixin

from optimum.exporters.onnx import OnnxConfig
from optimum.modeling_base import OptimizedModel

from ...exporters.openvino import export, main_export
from ..utils.import_utils import is_nncf_available, is_transformers_version
from ..utils.import_utils import is_nncf_available
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, _print_compiled_model_properties


if is_transformers_version("<", "4.25.0"):
from transformers.generation_utils import GenerationMixin
else:
from transformers.generation import GenerationMixin

core = Core()

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,12 +88,7 @@ def __init__(
if enable_compilation:
self.compile()

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
else:
from transformers import GenerationConfig

self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

@staticmethod
def load_model(file_name: Union[str, Path], load_in_8bit: bool = False):
Expand Down
11 changes: 2 additions & 9 deletions optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import openvino
from huggingface_hub import hf_hub_download
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
from transformers import PretrainedConfig
from transformers import GenerationConfig, PretrainedConfig
from transformers.file_utils import add_start_docstrings

from ...exporters.openvino import main_export
from ..utils.import_utils import is_transformers_version
from .modeling_base import OVBaseModel
from .utils import (
ONNX_DECODER_NAME,
Expand Down Expand Up @@ -75,13 +74,7 @@ def __init__(
self.encoder_model = encoder
self.decoder_model = decoder
self.decoder_with_past_model = decoder_with_past

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
else:
from transformers import GenerationConfig

self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Expand Down
8 changes: 1 addition & 7 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,19 @@
from openvino.runtime import Core, Tensor, Type
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.utils.normalized_config import NormalizedConfigManager

from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
from ...exporters.openvino.stateful import model_has_state
from ..utils.import_utils import is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
from .weight_quantization import OVWeightQuantizationConfig, compress_decoder_weights


if is_transformers_version("<", "4.25.0"):
from transformers.generation_utils import GenerationMixin
else:
from transformers.generation import GenerationMixin


logger = logging.getLogger(__name__)

core = Core()
Expand Down
7 changes: 1 addition & 6 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,15 @@
WhisperForConditionalGeneration,
)
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE

from ..utils.import_utils import is_transformers_version
from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
from .utils import _print_compiled_model_properties


if is_transformers_version("<", "4.25.0"):
from transformers.generation_utils import GenerationMixin
else:
from transformers.generation import GenerationMixin

if TYPE_CHECKING:
from transformers import PretrainedConfig

Expand Down
2 changes: 1 addition & 1 deletion tests/ipex/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class IPEXIntegrationTest(unittest.TestCase):
"gptj",
"gpt2",
"gpt_neo",
"gpt_bigcode",
# "gpt_bigcode",
"llama",
"opt",
"mpt",
Expand Down
2 changes: 1 addition & 1 deletion tests/ipex/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
"opt",
)
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.1
SPEEDUP_CACHE = 1.0

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
Expand Down

0 comments on commit db66c13

Please sign in to comment.