Fix transformers v4.42.0 compatibility (#793)
* fix transformers v4.42.0 compatibility

* fix inc modeling

* update setup

* add missing argument

* fix patching

* format

* fix num quant op

* remove incompatible transformers generation

* update setup
echarlaix committed Jul 8, 2024
1 parent 1b2c29b commit b1c1900
Showing 12 changed files with 92 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_openvino.yml
@@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.12"]
- transformers-version: ["4.36.0", "4.41.*"]
+ transformers-version: ["4.36.0", "4.42.*"]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
37 changes: 23 additions & 14 deletions optimum/exporters/openvino/model_patcher.py
@@ -1161,7 +1161,7 @@ def __exit__(self, exc_type, exc_value, traceback):
block.attention.forward = block.attention._orig_forward


- # Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426
+ # Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729
def _phi3_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
@@ -1170,6 +1170,7 @@ def _phi3_self_attn_sdpa_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
return self._orig_forward(
@@ -1181,10 +1182,9 @@ def _phi3_self_attn_sdpa_forward(
use_cache=use_cache,
)

- # TO DO: remove llama imports when transformers with phi3 support will be released
- try:
+ if is_transformers_version(">=", "4.41.0"):
from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
- except ImportError:
+ else:
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

bsz, q_len, _ = hidden_states.size()
@@ -1206,17 +1206,15 @@ def _phi3_self_attn_sdpa_forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

+ causal_mask = attention_mask
if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -1229,7 +1227,7 @@ def _phi3_self_attn_sdpa_forward(
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
+ attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
@@ -1561,7 +1559,7 @@ def __exit__(self, exc_type, exc_value, traceback):
layer.attn._attn = layer.attn._orig_attn


- # adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
def _dbrx_experts_forward(
self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
):
@@ -1606,7 +1604,7 @@ def _dbrx_experts_forward(
return out


- # adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
def _dbrx_update_causal_mask_legacy(
self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
) -> Optional[torch.Tensor]:
@@ -1803,6 +1801,7 @@ def __exit__(self, exc_type, exc_value, traceback):
block.ffn.experts.forward = block.ffn.experts._orig_forward


+ # Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
def _persimmon_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
@@ -1811,6 +1810,7 @@ def _persimmon_self_attn_sdpa_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb

@@ -1865,14 +1865,23 @@ def _persimmon_self_attn_sdpa_forward(

if past_key_value is not None:
# Specific to RoPE models with partial rotation
- cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_emb.dim,
+ "cache_position": cache_position,
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+ causal_mask = attention_mask
+ if attention_mask is not None:  # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
- attention_mask,
+ causal_mask,
scale=1 / math.sqrt(self.head_dim),
dropout_p=self.attention_dropout.p,
)
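For readers tracking the phi3 and persimmon changes above: both patched SDPA forwards stop hard-checking the 4D attention mask shape and instead slice the mask to the current key length before calling scaled_dot_product_attention. A minimal runnable sketch of that pattern follows; the tensor shapes are invented for illustration and this is not code from the repository.

# Sketch of the causal-mask slicing used by the patched SDPA forwards.
# Shapes below are arbitrary examples, not values from optimum-intel.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_len, head_dim = 1, 4, 3, 8, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# A 4D additive mask prepared for a longer maximum length (here 16 positions).
attention_mask = torch.zeros(bsz, 1, q_len, 16)

causal_mask = attention_mask
if attention_mask is not None:
    # Whatever the prepared length, keep only the positions present in the KV cache.
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([1, 4, 3, 16])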
1 change: 1 addition & 0 deletions optimum/intel/generation/modeling.py
@@ -90,6 +90,7 @@ class BaseModelForCausalLM(OptimizedModel, GenerationMixin):
export_feature = "text-generation"
main_input_name = "input_ids"
base_model_prefix = "torch_script_model"
+ _supports_cache_class = False

def __init__(
self,
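The _supports_cache_class = False flag added here (and to the other base classes below) signals to the transformers 4.42 generation utilities that these wrapped models still expect the legacy tuple-format past_key_values rather than the newer Cache objects. A rough sketch of the distinction, using the public DynamicCache API; the dispatch function is an illustration, not transformers internals.

# Rough sketch (illustrative dispatch, not transformers internals): models that set
# _supports_cache_class = False keep the legacy tuple-of-tuples KV-cache format.
from transformers import DynamicCache

def prepare_past_key_values(model, past_key_values):
    if getattr(model, "_supports_cache_class", False):
        # Cache-class path: wrap the legacy tuples in a DynamicCache object.
        return DynamicCache.from_legacy_cache(past_key_values)
    # Legacy path kept by the exported/compiled models patched in this commit.
    return past_key_values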
1 change: 1 addition & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -128,6 +128,7 @@ class IPEXModel(OptimizedModel):
base_model_prefix = "ipex_model"
main_input_name = "input_ids"
output_name = "last_hidden_state"
+ _supports_cache_class = False

def __init__(
self,
1 change: 1 addition & 0 deletions optimum/intel/neural_compressor/modeling_base.py
@@ -71,6 +71,7 @@ class INCModel(OptimizedModel):
auto_model_class = AutoModel
export_feature = "feature-extraction"
base_model_prefix = "inc_model"
+ _supports_cache_class = False

def __init__(
self,
2 changes: 2 additions & 0 deletions optimum/intel/neural_compressor/trainer.py
@@ -682,6 +682,7 @@ def _inner_training_loop(
def save_model(
self,
output_dir: Optional[str] = None,
+ _internal_call: bool = False,
save_onnx_model: bool = False,
):
"""
@@ -696,6 +697,7 @@ def save_model(
output_dir=output_dir,
save_onnx_model=save_onnx_model,
)
+ # TODO: push to hub if self.args.push_to_hub and not _internal_call

def _save(
self,
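The new _internal_call parameter mirrors the signature of transformers.Trainer.save_model, which recent releases invoke with _internal_call=True when the training loop saves checkpoints itself, so an override that does not accept the argument breaks under 4.42. A simplified sketch of a compatible override; illustration only, not the INCTrainer implementation.

# Simplified sketch of a save_model override that accepts the Trainer's
# _internal_call keyword; not the actual INCTrainer code.
from typing import Optional

class CompatTrainer:  # stand-in for a transformers Trainer subclass
    def save_model(
        self,
        output_dir: Optional[str] = None,
        _internal_call: bool = False,
        save_onnx_model: bool = False,
    ):
        # Accepting _internal_call keeps internal checkpoint saves from the
        # training loop from raising a TypeError; it can be forwarded or ignored.
        self._save(output_dir, save_onnx_model=save_onnx_model)

    def _save(self, output_dir, save_onnx_model=False):
        print(f"saving to {output_dir} (onnx={save_onnx_model})")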
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -50,6 +50,7 @@
class OVBaseModel(OptimizedModel):
auto_model_class = None
export_feature = None
+ _supports_cache_class = False

def __init__(
self,
5 changes: 3 additions & 2 deletions setup.py
@@ -28,8 +28,9 @@

INSTALL_REQUIRE = [
"torch>=1.11",
- "transformers>=4.36.0,<4.42.0",
- "optimum~=1.20",
+ "transformers>=4.36.0,<4.43.0",
+ "optimum~=1.21",
+ # "optimum>=1.21.2,<1.22.0",
"datasets>=1.4.0",
"sentencepiece",
"setuptools",
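To check that an existing environment falls inside the relaxed pins (transformers >=4.36.0,<4.43.0 and optimum ~=1.21), the installed versions can be read with the standard library; a small sketch:

# Quick environment check against the new pins; uses only the standard library.
from importlib.metadata import version

print("transformers:", version("transformers"))  # expected >=4.36.0,<4.43.0
print("optimum:", version("optimum"))            # expected ~=1.21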
1 change: 0 additions & 1 deletion tests/openvino/test_modeling.py
@@ -697,7 +697,6 @@ def test_compare_to_transformers(self, model_arch):
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.use_cache)
- self.assertEqual(ov_model.stateful, ov_model.config.model_type not in not_stateful)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
tokens = tokenizer("This is a sample output", return_tensors="pt")
tokens.pop("token_type_ids", None)
7 changes: 2 additions & 5 deletions tests/openvino/test_quantization.py
@@ -75,8 +75,8 @@

class OVQuantizerTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_TORCH_MODEL = (
- (OVModelForSequenceClassification, "bert", 22, 35),
- (OVModelForCausalLM, "gpt2", 41, 3),
+ (OVModelForSequenceClassification, "bert", 32 if is_transformers_version("<", "4.41.0") else 22, 35),
+ (OVModelForCausalLM, "gpt2", 41 if is_transformers_version("<", "4.42.0") else 21, 3),
)
SUPPORTED_ARCHITECTURES_OV_MODEL = (
(OVModelForSequenceClassification, "bert", 32, 35),
@@ -90,9 +90,6 @@ def test_automodel_static_quantization(self, model_cls, model_name, expected_fak
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
file_name = "openvino_quantized_model.xml"

- if model_name == "bert" and is_transformers_version("<", "4.41.0"):
- expected_fake_quantize = 32

def preprocess_function(examples, tokenizer):
return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)

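The updated expectations follow the version-gating pattern already used throughout the test suite. A minimal sketch; the import path and the variable name below are assumptions for illustration, not lines from the tests.

# Minimal sketch of a version-gated expected value, mirroring the test change above.
# Assumes is_transformers_version comes from optimum.intel.utils.import_utils.
from optimum.intel.utils.import_utils import is_transformers_version

expected_fake_quantize_ops = 32 if is_transformers_version("<", "4.41.0") else 22
print(expected_fake_quantize_ops)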