add patching for update_causal_mask to falcon for >= 4.45 (#989)
* add patching for update_causal_mask to falcon and gpt-like models for >=4.45

* fix falcon

* re-enable codegen2

* Apply suggestions from code review

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>

* Update optimum/exporters/openvino/model_patcher.py

---------

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
eaidova and nikita-savelyevv authored Nov 12, 2024
1 parent 12783ee commit 0447ae2
Showing 3 changed files with 237 additions and 16 deletions.
20 changes: 20 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -29,6 +29,7 @@
CodeGenOnnxConfig,
FalconOnnxConfig,
GemmaOnnxConfig,
GPTJOnnxConfig,
GPTNeoXOnnxConfig,
IBertOnnxConfig,
LlamaOnnxConfig,
@@ -66,6 +67,7 @@
FalconModelPatcher,
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
GptJModelPatcher,
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
IBertModelPatcher,
@@ -726,6 +728,24 @@ def patch_model_for_export(
return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"gptj",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class GPTJOpenVINOConfig(GPTJOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GptJModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"cohere",
*[
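With `gptj` registered above, the existing export entry points should pick up `GPTJOpenVINOConfig` (and its `GptJModelPatcher`) automatically. A minimal usage sketch, assuming `optimum-intel` with the OpenVINO extras is installed; the checkpoint name is illustrative and not part of this commit:

```python
# Usage sketch: the registered "gptj" config is looked up by the exporter and the
# corresponding patcher is applied transparently while the model is traced for OpenVINO.
from optimum.intel import OVModelForCausalLM  # requires optimum-intel[openvino]
from transformers import AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gptj"  # illustrative tiny checkpoint (assumed)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(output_ids[0]))
```

The same flow applies to the falcon, gpt_neox, and gpt_neox_japanese patchers extended in model_patcher.py below.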
228 changes: 216 additions & 12 deletions optimum/exporters/openvino/model_patcher.py
@@ -109,11 +109,20 @@ def patch_model_with_bettertransformer(model):
return model


def patch_update_causal_mask(model, transformers_version):
def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
if is_transformers_version(">=", transformers_version):
inner_model = getattr(model, "model", getattr(model, "transformer", None))
inner_model = getattr(model, inner_model_name, None)
if inner_model is not None:
inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
if hasattr(inner_model, "_update_causal_mask"):
inner_model._orig_update_causal_mask = inner_model._update_causal_mask
patch_fn = patch_fn or _llama_gemma_update_causal_mask
inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)


def unpatch_update_causal_mask(model, inner_model_name="model"):
inner_model = getattr(model, inner_model_name, None)
if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
inner_model._update_causal_mask = inner_model._orig_update_causal_mask


# initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
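The two helpers above generalize the earlier llama/gemma-only patching: the inner module attribute name and the replacement function are now parameters, the original bound method is stashed as `_orig_update_causal_mask`, and `unpatch_update_causal_mask` restores it on exit. A simplified, self-contained sketch of the same rebinding pattern on toy classes (not the exporter code):

```python
import types


class _ToyInner:
    def _update_causal_mask(self, mask):
        return mask


class _ToyModel:
    def __init__(self):
        self.transformer = _ToyInner()  # plays the role of e.g. model.transformer


def _patched_update_causal_mask(self, mask):
    # stand-in for _llama_gemma_update_causal_mask / _falcon_update_causal_mask
    return None


model = _ToyModel()
inner = getattr(model, "transformer", None)

# patch: stash the original bound method, then rebind the replacement onto the instance
inner._orig_update_causal_mask = inner._update_causal_mask
inner._update_causal_mask = types.MethodType(_patched_update_causal_mask, inner)
assert inner._update_causal_mask("m") is None

# unpatch: restore the original bound method
inner._update_causal_mask = inner._orig_update_causal_mask
assert inner._update_causal_mask("m") == "m"
```

Binding with `types.MethodType` keeps the override per instance, so the class definition in transformers is left untouched and other loaded models are unaffected.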
@@ -579,13 +588,11 @@ def __enter__(self):

# llama/gemma have some accuracy issues with bf16 with transformers >= 4.39
# fill the causal mask in a slightly different way to avoid overflow on some platforms
patch_update_causal_mask(self._model, "4.39.0")
patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
inner_model = getattr(self._model, "model", getattr(self._model, "transformer", None))
if hasattr(inner_model, "_orig_update_causal_mask"):
inner_model._update_causal_mask = inner_model._orig_update_causal_mask
unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer")


# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
@@ -1865,6 +1872,67 @@ def __exit__(self, exc_type, exc_value, traceback):
layer.self_attn.forward = layer.self_attn._orig_forward


# copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168
# kept for backward compatibility between outdated codegen remote code and newer transformers
def _codegen_wrapped_scaled_dot_product_legacy(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
from optimum.bettertransformer.models.attention import raise_on_head_mask

raise_on_head_mask(head_mask)
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)

# causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary
if query_length > 1:
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

return sdpa_result, None


class CodeGenModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
@@ -1873,14 +1941,23 @@ def __enter__(self):
# To avoid breaking the model at the tracing stage, we limit the bettertransformer patch to _attn only.
from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product

attn_fn = codegen_wrapped_scaled_dot_product
if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
# in transformers 4.45 the causal_mask const buffer was removed from the model;
# if it still exists, legacy remote code was loaded
if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
attn_fn = _codegen_wrapped_scaled_dot_product_legacy

for layer in self._model.transformer.h:
if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
orig_self_attn_fwd = layer.attn._attn
layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
layer.attn._attn = types.MethodType(attn_fn, layer.attn)
layer.attn._orig_attn = orig_self_attn_fwd
patch_update_causal_mask(self._model, "4.45.0", "transformer")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")
for layer in self._model.transformer.h:
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
@@ -2275,8 +2352,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
unpatch_update_causal_mask(self._model)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
@@ -2413,8 +2489,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
unpatch_update_causal_mask(self._model)


class RotaryEmbPatcher(DecoderModelPatcher):
@@ -2425,12 +2500,119 @@ def __enter__(self):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


def _falcon_update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
output_attentions: bool,
head_mask: torch.Tensor,
alibi: torch.Tensor,
):
# copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
_prepare_4d_causal_attention_mask_with_cache_position = (
self._prepare_4d_causal_attention_mask_with_cache_position
)
else:
from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
and head_mask is None
and alibi is None
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
# difference from the original: use the float16 minimum instead of torch.finfo(dtype).min to prevent overflow during fp16/bf16 execution
min_dtype = torch.finfo(torch.float16).min
batch_size, sequence_length, _ = input_tensor.shape
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length
)

# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

# We take care to integrate alibi bias in the causal_mask here
if head_mask is None and alibi is not None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
causal_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
causal_mask < -1,
min_dtype,
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask


class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.transformer.h:
_reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")


class GptNeoxModelPatcher(DecoderModelPatcher):
Expand All @@ -2439,6 +2621,22 @@ def __enter__(self):
if is_transformers_version("<", "4.44.99"):
for layer in self._model.gpt_neox.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "gpt_neox")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "gpt_neox")


class GptJModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.45.0", "transformer")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")


class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
Expand All @@ -2447,6 +2645,12 @@ def __enter__(self):
if is_transformers_version("<", "4.44.99"):
for layer in self._model.gpt_neox_japanese.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


class Gemma2ModelPatcher(LlamaModelPatcher):
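A small standalone check of the rationale behind `_falcon_update_causal_mask` filling the mask with the float16 minimum instead of `torch.finfo(dtype).min` (see the comment in the hunk above): the fp32/bf16 minimum lies far outside the float16 range, so a mask built with it overflows to `-inf` as soon as it is cast to or combined in float16, which is one concrete way the overflow mentioned in the comment shows up. Assumes only `torch` is installed:

```python
import torch

fp32_min = torch.finfo(torch.float32).min   # ~ -3.4e38
fp16_min = torch.finfo(torch.float16).min   # -65504.0

# Casting the fp32 minimum to float16 overflows to -inf ...
print(torch.tensor(fp32_min).to(torch.float16))   # tensor(-inf, dtype=torch.float16)
# ... while the float16 minimum survives the cast and still masks effectively.
print(torch.tensor(fp16_min).to(torch.float16))   # tensor(-65504., dtype=torch.float16)
```

Note that `torch.finfo(torch.bfloat16).min` has the same magnitude as the fp32 minimum, so the float16 minimum is the conservative fill value for both half-precision dtypes.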
5 changes: 1 addition & 4 deletions tests/openvino/test_modeling.py
@@ -773,6 +773,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"bloom",
"chatglm",
"codegen",
"codegen2",
"gpt2",
"gpt_neo",
"gpt_neox",
@@ -821,10 +822,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"mistral-nemo",
)

# custom modeling defined in https://huggingface.co/katuni4ka/tiny-random-codegen2 differs from transformers after v4.45 resulting in unadapted patching
if is_transformers_version("<", "4.45.0"):
SUPPORTED_ARCHITECTURES += ("codegen2",)

GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = (
"chatglm",
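To exercise the re-enabled codegen2 case locally, the parametrized causal-LM tests can be filtered by name. A hedged example invocation, assuming the repository's test dependencies are installed and that the parametrized test ids contain the architecture name:

```python
# Equivalent shell command:  pytest tests/openvino/test_modeling.py -k codegen2
import subprocess

subprocess.run(
    ["pytest", "tests/openvino/test_modeling.py", "-k", "codegen2"],
    check=True,  # raise if any selected test fails
)
```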
