Fix bf16 inference accuracy for mistral, phi3, dbrx (#833)
* Fix bf16 inference accuracy for mistral, phi3, dbrx

* reuse inv_freq

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* make dim and base optional

* fix model patcher for dbrx and add bitwise fix for mistral

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
2 people authored and IlyasMoutawwakil committed Aug 6, 2024
1 parent 5c1c81c commit 55fec6a
Showing 2 changed files with 207 additions and 20 deletions.
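For context, the user-facing path this commit fixes is ordinary export and generation through optimum-intel. A minimal sketch of that flow, assuming optimum-intel is installed with its OpenVINO extras; the checkpoint id is illustrative, and export=True is what triggers the conversion that applies the patchers changed below.

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # illustrative; any mistral, phi3 or dbrx checkpoint goes through the patched path
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch model to OpenVINO IR, applying the model patchers from this commit
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("The weather in Paris is", return_tensors="pt")
outputs = ov_model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))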
20 changes: 20 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -24,6 +24,7 @@
FalconOnnxConfig,
GemmaOnnxConfig,
LlamaOnnxConfig,
MistralOnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
UNetOnnxConfig,
@@ -53,6 +54,7 @@
InternLMModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
MistralModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
PersimmonModelPatcher,
@@ -839,3 +841,21 @@ def patch_model_for_export(
)

return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"mistral",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class MistralOpenVINOConfig(MistralOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)
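For orientation, the decorator above makes this config discoverable by model type and task during OpenVINO export. A rough sketch of driving it by hand, under the assumption that the constructor follows the usual OnnxConfig(config, task=...) signature and that the returned patcher is used as a context manager; this is not the actual export entry point.

import torch
from transformers import AutoModelForCausalLM
from optimum.exporters.openvino.model_configs import MistralOpenVINOConfig

model_id = "mistralai/Mistral-7B-v0.1"  # illustrative
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

export_config = MistralOpenVINOConfig(model.config, task="text-generation-with-past")
patcher = export_config.patch_model_for_export(model)

with patcher:
    # while the patcher is active, the causal-mask fix (transformers >= 4.42) and the
    # precomputed rotary sin/cos buffers are in place; tracing would normally happen here
    pass
# on exit, the original forward/_update_causal_mask methods are restored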
207 changes: 187 additions & 20 deletions optimum/exporters/openvino/model_patcher.py
@@ -510,6 +510,39 @@ def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
return cos, sin


def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
    # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
if inv_freq is None:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))

sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
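As a standalone illustration (shapes and dtypes here are illustrative, not taken from the patch) of why the table is precomputed in float32: forming the position-times-frequency angles in bfloat16 at every step loses far more precision than rounding a float32 sin table once.

import torch

num_pos, dim, base = 4096, 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
pos = torch.arange(num_pos, dtype=torch.int64).float()

# float32 precompute, as create_sinusoidal_positions does above
sin_ref = torch.sin(pos[:, None] * inv_freq[None, :])

# per-step bf16 math: the angles themselves are formed in bfloat16 before taking sin
sin_bf16 = torch.sin((pos.bfloat16()[:, None] * inv_freq.bfloat16()[None, :]).float())

print((sin_ref - sin_bf16).abs().max())                    # large at high positions
print((sin_ref - sin_ref.bfloat16().float()).abs().max())  # small: a single rounding of the table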


def register_sin_cos_buffer(model):
max_positions = model.config.max_position_embeddings

    # cos/sin tables for rotary position embeddings also suffer from bf16 accuracy and efficiency issues
    # when recomputed at each step, so use a precomputed buffer instead

rotary_emb = model.model.layers[0].self_attn.rotary_emb
dim, base = None, None
inv_freq = getattr(rotary_emb, "inv_freq", None)
if inv_freq is None:
base = rotary_emb.base
dim = rotary_emb.dim
embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)

for layer in model.model.layers:
layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward

layer.self_attn.rotary_emb.forward = types.MethodType(
llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
)


class LlamaModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
@@ -521,39 +554,148 @@ def __enter__(self):
self._model.model._update_causal_mask = types.MethodType(
_llama_gemma_update_causal_mask, self._model.model
)
register_sin_cos_buffer(self._model)

max_positions = self._model.config.max_position_embeddings
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

# cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
# use precomputed
def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
# adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
for layer in self._model.model.layers:
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward

sinusoid_inp = torch.einsum(
"i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
).float()
emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)

base = self._model.model.layers[0].self_attn.rotary_emb.base
dim = self._model.model.layers[0].self_attn.rotary_emb.dim
embed_positions = create_sinusoidal_positions(max_positions, dim, base)
# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
def _mistral_update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
use_cache: bool,
output_attentions: bool,
):
from transformers.cache_utils import SlidingWindowCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

for layer in self._model.model.layers:
layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.

# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None

layer.self_attn.rotary_emb.forward = types.MethodType(
llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
exclude_mask = exclude_mask.bitwise_or(
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
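To make the sliding-window branch above concrete, a tiny standalone example (sizes chosen for readability, not taken from any model config): row i of the resulting mask lets a query attend only to positions j with i - sliding_window < j <= i.

import torch

target_length, sliding_window = 8, 3
cache_position = torch.arange(target_length)  # prefill: one query row per position

# future positions are excluded...
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ...and so are positions that have slid out of the attention window
exclude_mask = exclude_mask.bitwise_or(
    torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
)

# 1 = attended; each row holds at most `sliding_window` ones, ending at the diagonal
print((~exclude_mask).int())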


class MistralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.42.0"):
# apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

        # mistral has accuracy issues with bf16 and transformers >= 4.42
        # precompute the rotary embedding sin/cos to avoid this issue
register_sin_cos_buffer(self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

        for layer in self._model.model.layers:
if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


@@ -1283,11 +1425,15 @@ def __enter__(self):
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

        # phi3 has accuracy issues with bf16 inference; precompute sin/cos for the rotary position embedding to avoid them
register_sin_cos_buffer(self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


def _aquila_self_attn_sdpa_forward(
@@ -1807,6 +1953,18 @@ def __enter__(self):
_dbrx_update_causal_mask, self._model.transformer
)

        # starting from transformers 4.41, the same issue is also observable in the sin/cos calculation of rotary_emb
patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")

inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq")
dim, base = None, None
if inv_freq is None:
dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim
base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base
max_positions = self._model.config.max_seq_len
if patch_rope_sin_cos:
embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)

for block in self._model.transformer.blocks:
rotary_emb = block.norm_attn_norm.attn.rotary_emb
# initialize inv_freq for torchscript tracing
@@ -1815,6 +1973,12 @@
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)
rotary_emb.inv_freq = inv_freq

if patch_rope_sin_cos:
rotary_emb.register_buffer("embed_positions", embed_positions)
rotary_emb._orig_forward = rotary_emb.forward
rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)

# remove continue-operator from iteration loop over experts
block.ffn.experts._orig_forward = block.ffn.experts.forward
block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
Expand All @@ -1825,6 +1989,9 @@ def __exit__(self, exc_type, exc_value, traceback):
for block in self._model.transformer.blocks:
block.ffn.experts.forward = block.ffn.experts._orig_forward

if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"):
block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
def _persimmon_self_attn_sdpa_forward(
