fix olmo accuracy for bf16, add sdpa for persimmon, support jais #726

Merged: 4 commits, Jun 6, 2024
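For orientation, here is a minimal sketch of how one of the newly supported architectures could be exported and run through optimum-intel once this change is in. The model id is the tiny JAIS test checkpoint added in tests/openvino/utils_tests.py below; the exact flags are illustrative assumptions rather than a prescribed workflow.

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-jais"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# export=True routes through optimum.exporters.openvino, which resolves JaisOpenVINOConfig
# from the tasks manager and applies JaisModelPatcher while tracing the model.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))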
36 changes: 35 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -43,16 +43,19 @@

from .model_patcher import (
AquilaModelPatcher,
ArcticModelPatcher,
BaichuanModelPatcher,
ChatGLMModelPatcher,
CodeGenModelPatcher,
DBRXModelPatcher,
GemmaModelPatcher,
InternLM2Patcher,
InternLMModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
PersimmonModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
XverseModelPatcher,
@@ -473,7 +476,7 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):


@register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers")
-class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+class OlmoOpenVINOConfig(LlamaOpenVINOConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@@ -630,6 +633,11 @@ class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return PersimmonModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers")
class BioGPTOpenVINOConfig(TextDecoderOnnxConfig):
@@ -785,3 +793,29 @@ def patch_model_for_export
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return DBRXModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"jais",
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class JaisOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return JaisModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("arctic", *["text-generation", "text-generation-with-past"], library_name="transformers")
class ArcticOpenVINOConfig(MixtralOpenVINOConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)
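As a rough sketch of how these registration classes fit together (not the actual export pipeline, and assuming `model` is a JAIS checkpoint already loaded with AutoModelForCausalLM and trust_remote_code=True): the config produces dummy inputs and hands back a patcher, which the exporter applies as a context manager around tracing.

import torch
from optimum.exporters.openvino.model_configs import JaisOpenVINOConfig

export_config = JaisOpenVINOConfig(model.config, task="text-generation")
dummy_inputs = export_config.generate_dummy_inputs(framework="pt")
patcher = export_config.patch_model_for_export(model)  # returns the JaisModelPatcher defined below
with patcher:
    # inside the context, layer.attn._attn is swapped for the SDPA-based _jais_attn
    with torch.no_grad():
        outputs = model(**dummy_inputs)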
230 changes: 230 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -153,6 +153,17 @@ def __exit__(self, exc_type, exc_value, traceback):
layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward


class ArcticModelPatcher(MixtralModelPatcher):
def __enter__(self):
# the model initializes some weights for matrix multiplication in bfloat16, which leads to dtype inconsistency
try:
self._model.to(torch.float32)
except Exception:
pass

super().__enter__()
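
As a standalone illustration of the dtype issue mentioned in the comment above (not part of the patch): mixing bfloat16 parameters with float32 activations fails in a plain matmul, which is why the patcher casts the whole model to float32 before export.

import torch

layer = torch.nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(1, 4)  # float32 activations
try:
    layer(x)
except RuntimeError as err:
    print(err)  # dtype mismatch between the float32 input and bfloat16 weights
layer.to(torch.float32)
print(layer(x).dtype)  # torch.float32 once everything is cast consistently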


def _chatglm_transformer_forward(
self,
input_ids,
@@ -1771,3 +1782,222 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
for block in self._model.transformer.blocks:
block.ffn.experts.forward = block.ffn.experts._orig_forward


def _persimmon_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb

if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)

bsz, q_len, _ = hidden_states.size()

# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_states, key_states, value_states) = self._split_heads(fused_qkv)

if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)

# [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
query_states = query_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim],
query_states[..., self.rotary_emb.dim :],
)
key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim],
key_states[..., self.rotary_emb.dim :],
)
# [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

# [batch_size, num_heads, seq_length, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)

if past_key_value is not None:
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
scale=1 / math.sqrt(self.head_dim),
dropout_p=self.attention_dropout.p,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.dense(attn_output)

return attn_output, None, past_key_value


class PersimmonModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


def _jais_attn_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
position_bias: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = (key, value)
else:
present = None

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query, key, value, attention_mask, head_mask, position_bias
)
else:
# Difference from the original: use the SDPA-based attention path when output_attentions is not requested
if not output_attentions:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias)
else:
attn_output, attn_weights = self._orig_attn(query, key, value, attention_mask, head_mask, position_bias)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs


def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None):
scale = 1.0
if self.scale_attn_weights:
scale = 1 / self.head_dim**self.attn_scale_power

# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
scale = scale / float(self.layer_idx + 1)

query_length = query.size(-2)
attention_mask_sdpa = torch.ones(
(query.shape[0], query.shape[1], query.shape[2], key.shape[2]),
dtype=query.dtype,
)

if not self.is_cross_attention:
# only the "normal" attention layer implements a causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(torch.float16).min
attention_mask_sdpa.masked_fill_(~causal_mask, mask_value)

if attention_mask is not None:
# Apply the attention mask
attention_mask_sdpa = attention_mask_sdpa + attention_mask

if position_bias is not None:
attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0)

# Mask heads if we want to
if head_mask is not None:
attention_mask_sdpa = attention_mask_sdpa * head_mask

attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale
)

return attn_output, None


class JaisModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()

for layer in self._model.transformer.h:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.attn._attn
layer.attn._attn = types.MethodType(_jais_attn, layer.attn)
layer.attn._orig_attn = orig_self_attn_fwd
layer.attn._orig_forward = layer.attn.forward
layer.attn.forward = types.MethodType(_jais_attn_forward, layer.attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.transformer.h:
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
layer.attn.forward = layer.attn._orig_forward
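
The Persimmon and JAIS patchers above both replace the eager attention math with torch.nn.functional.scaled_dot_product_attention driven by an additive mask. A small standalone sanity check (illustrative shapes, not part of the PR) shows the two formulations agree; the scale keyword requires torch >= 2.1, the same bound the patchers check for.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# additive causal mask: 0 where attention is allowed, a large negative value where it is not
mask = torch.zeros(q_len, q_len)
mask.masked_fill_(torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1), torch.finfo(torch.float32).min)

scale = 1 / math.sqrt(head_dim)
eager = torch.softmax(q @ k.transpose(-1, -2) * scale + mask, dim=-1) @ v
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale)
assert torch.allclose(eager, sdpa, atol=1e-6)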
8 changes: 7 additions & 1 deletion tests/openvino/test_modeling.py
@@ -564,6 +564,8 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"internlm",
"dbrx",
"qwen2-moe",
"jais",
"arctic",
)
GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = (
@@ -581,6 +583,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"xverse",
"internlm",
"codegen2",
"arctic",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -622,7 +625,7 @@ def test_compare_to_transformers(self, model_arch):

set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
if model_arch == "qwen":
if model_arch in ["qwen", "arctic"]:
transformers_model.to(torch.float32)

with torch.no_grad():
@@ -869,6 +872,9 @@ def test_beam_search(self, model_arch):
model_id, export=True, use_cache=True, stateful=False, **model_kwargs
)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

if model_arch == "arctic":
transformers_model.to(torch.float32)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
tokens.pop("token_type_ids", None)
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -63,6 +63,7 @@
"ibert": "hf-internal-testing/tiny-random-ibert",
"internlm": "katuni4ka/tiny-random-internlm",
"internlm2": "katuni4ka/tiny-random-internlm2",
"jais": "katuni4ka/tiny-random-jais",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"longt5": "hf-internal-testing/tiny-random-longt5",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
@@ -109,6 +110,7 @@
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "asapp/sew-d-tiny-100k-ft-ls100h",
"arctic": "katuni4ka/tiny-random-snowflake",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"t5": "hf-internal-testing/tiny-random-t5",
"trocr": "microsoft/trocr-small-handwritten",