From 5351f4a1d411b6813e8ed47a826ffdc667491518 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 23 May 2024 13:05:07 -0400 Subject: [PATCH 01/23] ipex 2.3 released --- optimum/exporters/ipex/model_patcher.py | 8 ++++---- optimum/exporters/ipex/modeling_utils.py | 17 ++++++++++------- optimum/intel/ipex/modeling_base.py | 2 +- tests/ipex/test_modeling.py | 16 ---------------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 60ff3b721b..3996c2b23f 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -62,10 +62,10 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - if is_ipex_version("<", "2.5.0"): - raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache") + if is_ipex_version("<", "2.3.0"): + raise ImportError("Only ipex version >= 2.3.0 supports RotaryEmbedding and IndirectAccessKVCacheAttention") - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding ipex_rope = RotaryEmbedding( model.config.max_position_embeddings, @@ -73,7 +73,7 @@ def _patch_llama_model(model): model.config.rope_theta, model.config.architectures[0], ) - ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) + ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings) patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f75e559eaf..9d53caf4fc 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -219,7 +219,7 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayerRef(nn.Module): def __init__(self, module, config, distributed=False): - if is_ipex_version("<", "2.5.0"): + if is_ipex_version("<", "2.3.0"): raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd @@ -278,7 +278,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, ) - if not self.distributed: + if hasattr(self, "mha_linear_add"): hidden_states = self.mha_linear_add(hidden_states, residual) else: hidden_states = self.self_attn.o_proj(hidden_states) @@ -288,12 +288,15 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - mlp_gate = self.linear_silu_mul(hidden_states) - - if not self.distributed: - hidden_states = self.mlp_linear_add(mlp_gate, residual) + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.mlp.down_proj(mlp_gate) + hidden_states = residual + hidden_states else: - hidden_states = self.mlp.down_proj(mlp_gate) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) diff --git a/optimum/intel/ipex/modeling_base.py 
b/optimum/intel/ipex/modeling_base.py index e929a4ddb8..c9d43e3dc0 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -63,7 +63,7 @@ def _is_patched_with_ipex(model, task): - if is_ipex_version("<", "2.5.0"): + if is_ipex_version("<", "2.3.0"): return False if isinstance(model, torch.jit.ScriptModule): diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 2a2f18f6f8..7eb34ef47c 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -219,22 +219,6 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_assisted_decoding(self, model_arch): - model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") - ipex_output = ipex_model.generate(**tokens, do_sample=False) - ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) - transformers_output = transformers_model.generate(**tokens, do_sample=False) - transformers_output_assisted = transformers_model.generate( - **tokens, do_sample=False, assistant_model=ipex_model - ) - self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) - self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) - @parameterized.expand( grid_parameters( { From d1d0ca0a50ed697271c912a266b10ac71e4b5892 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sat, 25 May 2024 14:22:33 -0400 Subject: [PATCH 02/23] refactor IPEXLlamaAttention --- optimum/exporters/ipex/model_patcher.py | 18 +- optimum/exporters/ipex/modeling_utils.py | 295 ++++++++++++++--------- 2 files changed, 184 insertions(+), 129 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 3996c2b23f..7f99dbddd1 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -13,7 +13,6 @@ # limitations under the License. 
from transformers.models.llama.modeling_llama import ( - LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, @@ -24,7 +23,6 @@ from .modeling_utils import ( _IPEXLlamaDecoderLayerRef, - _llama_attn_forward, _llama_layer_norm_forward, _llama_model_forward, ) @@ -63,24 +61,10 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version >= 2.3.0 supports RotaryEmbedding and IndirectAccessKVCacheAttention") - - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding - - ipex_rope = RotaryEmbedding( - model.config.max_position_embeddings, - model.config.hidden_size // model.config.num_attention_heads, - model.config.rope_theta, - model.config.architectures[0], - ) - ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings) - patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) - patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) + raise ImportError("Only ipex version >= 2.3.0 supports llama model patching") convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 9d53caf4fc..04a6fe5082 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -29,90 +29,6 @@ def _llama_layer_norm_forward(self, hidden_states): return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 -def _llama_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len - - query = query.view(bsz, q_len, self.num_heads, self.head_dim) - key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - # Use ipex op to rotary position embedding more efficient. - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - - if use_cache: - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. 
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) - else: - value_states = value.transpose(1, 2) - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -216,12 +132,147 @@ def _llama_model_forward( ) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 -class _IPEXLlamaDecoderLayerRef(nn.Module): - def __init__(self, module, config, distributed=False): +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 +class _IPEXLlamaAttentionRef(nn.Module): + def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") + raise ImportError( + "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" + ) + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + + super().__init__() + for k, v in module.__dict__.items(): + setattr(self, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(self.__class__, k, getattr(module.__class__, k)) + self.config = config + self.distributed = distributed + if not self.distributed: + self.mha_linear_add = LinearAdd(self.o_proj) + del self.__dict__["_modules"]["o_proj"] + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=module.config.max_position_embeddings + ) + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states 
(`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + residual (`torch.Tensor`): residual tensor to the layer of shape ` + """ + bsz, seq_len, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len + + query = query.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + # Use ipex op to rotary position embedding more efficient. + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + + if use_cache: + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. 
+ (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + else: + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + kv_seq_len = key_states.shape[-2] + + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size) + + if hasattr(self, "mha_linear_add"): + attn_output = self.mha_linear_add(attn_output, residual) + else: + attn_output = self.o_proj(attn_output) + attn_output = residual + attn_output + + if not output_attentions: + attn_weights = None + + return attn_output, past_key_value, attn_weights + + +class _IPEXLlamaMLP(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + if is_ipex_version("<", "2.3.0"): + raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd super().__init__() @@ -231,15 +282,47 @@ def __init__(self, module, config, distributed=False): if k.startswith("__") or k.startswith("forward"): continue setattr(self.__class__, k, getattr(module.__class__, k)) + self.config = config self.distributed = distributed if not self.distributed: - self.mha_linear_add = LinearAdd(module.self_attn.o_proj) - self.mlp_linear_add = LinearAdd(module.mlp.down_proj) - del self.__dict__["_modules"]["self_attn"].o_proj - del self.__dict__["_modules"]["mlp"].down_proj - self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj) - del self.__dict__["_modules"]["mlp"].gate_proj - del self.__dict__["_modules"]["mlp"].up_proj + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj(mlp_gate) + hidden_states = residual + hidden_states + else: + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = residual + hidden_states + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 +class _IPEXLlamaDecoderLayerRef(nn.Module): + def __init__(self, 
module, config, distributed=False): + super().__init__() + for k, v in module.__dict__.items(): + setattr(self, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(self.__class__, k, getattr(module.__class__, k)) + self.distributed = distributed + self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) def forward( self, @@ -270,34 +353,22 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, present_key_value, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=None, + residual=residual, + **kwargs, ) - if hasattr(self, "mha_linear_add"): - hidden_states = self.mha_linear_add(hidden_states, residual) - else: - hidden_states = self.self_attn.o_proj(hidden_states) - hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - - if hasattr(self, "linear_silu_mul"): - mlp_gate = self.linear_silu_mul(hidden_states) - if hasattr(self, "mlp_linear_add"): - hidden_states = self.mlp_linear_add(mlp_gate, residual) - else: - hidden_states = self.mlp.down_proj(mlp_gate) - hidden_states = residual + hidden_states - else: - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = (hidden_states,) From 48b205edc05c2b699ddd06fb2355d19844b427d1 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Sun, 26 May 2024 06:18:23 -0400 Subject: [PATCH 03/23] change to Ref --- optimum/exporters/ipex/modeling_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 04a6fe5082..36d0972006 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -139,8 +139,6 @@ def __init__(self, module, config, distributed=False) -> None: raise ImportError( "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" ) - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding - super().__init__() for k, v in module.__dict__.items(): setattr(self, k, v) @@ -150,6 +148,8 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding + if not self.distributed: self.mha_linear_add = LinearAdd(self.o_proj) del self.__dict__["_modules"]["o_proj"] @@ -196,11 +196,11 @@ def forward( key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len - - query = query.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + query = query.view(bsz, seq_len, self.num_heads, self.head_dim) key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + 
kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len # Use ipex op to rotary position embedding more efficient. key = self.ipex_rope( key, @@ -269,11 +269,10 @@ def forward( return attn_output, past_key_value, attn_weights -class _IPEXLlamaMLP(nn.Module): +class _IPEXLlamaMLPRef(nn.Module): def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") - from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd super().__init__() for k, v in module.__dict__.items(): @@ -284,6 +283,8 @@ def __init__(self, module, config, distributed=False) -> None: setattr(self.__class__, k, getattr(module.__class__, k)) self.config = config self.distributed = distributed + from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + if not self.distributed: self.mlp_linear_add = LinearAdd(module.down_proj) del self.__dict__["_modules"]["down_proj"] @@ -322,7 +323,7 @@ def __init__(self, module, config, distributed=False): setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) + self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) def forward( self, From 4ea8a478cea259f836b60e6df00eeb9db114fc7b Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Mon, 27 May 2024 09:32:28 -0400 Subject: [PATCH 04/23] remove Ref --- optimum/exporters/ipex/model_patcher.py | 4 ++-- optimum/exporters/ipex/modeling_utils.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 7f99dbddd1..88212363e8 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -22,7 +22,7 @@ from optimum.intel.utils.import_utils import is_ipex_version from .modeling_utils import ( - _IPEXLlamaDecoderLayerRef, + _IPEXLlamaDecoderLayer, _llama_layer_norm_forward, _llama_model_forward, ) @@ -65,7 +65,7 @@ def _patch_llama_model(model): convert_functions(model, LlamaModel, "forward", _llama_model_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 36d0972006..047a7b8d10 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -133,7 +133,7 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 -class _IPEXLlamaAttentionRef(nn.Module): +class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): raise ImportError( @@ -266,10 +266,10 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, past_key_value, attn_weights + return attn_output, attn_weights, past_key_value -class _IPEXLlamaMLPRef(nn.Module): +class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): raise ImportError("Only ipex version > 2.3.0 
supports Linear2SiluMul and LinearAdd") @@ -312,7 +312,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 -class _IPEXLlamaDecoderLayerRef(nn.Module): +class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False): super().__init__() for k, v in module.__dict__.items(): @@ -322,8 +322,8 @@ def __init__(self, module, config, distributed=False): continue setattr(self.__class__, k, getattr(module.__class__, k)) self.distributed = distributed - self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed) - self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed) + self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) def forward( self, @@ -354,7 +354,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, present_key_value, self_attn_weights = self.self_attn( + hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, From 1f98d6d773cce4157e00a941b594db0be97696fd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 11:07:16 -0400 Subject: [PATCH 05/23] skip tests --- tests/ipex/test_modeling.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 7eb34ef47c..2948d383f0 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -219,6 +219,23 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @unittest.skip("CPU IPEXModel does not support assisted decoding for now.") + def test_assisted_decoding(self, model_arch): + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + ipex_output = ipex_model.generate(**tokens, do_sample=False) + ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) + transformers_output = transformers_model.generate(**tokens, do_sample=False) + transformers_output_assisted = transformers_model.generate( + **tokens, do_sample=False, assistant_model=ipex_model + ) + self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) + self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) + @parameterized.expand( grid_parameters( { From d3ce377d8694a87329e491c43557535baaacc920 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 11:07:16 -0400 Subject: [PATCH 06/23] skip tests --- tests/ipex/test_modeling.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 7eb34ef47c..2948d383f0 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -219,6 +219,23 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + 
@parameterized.expand(SUPPORTED_ARCHITECTURES) + @unittest.skip("CPU IPEXModel does not support assisted decoding for now.") + def test_assisted_decoding(self, model_arch): + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + ipex_output = ipex_model.generate(**tokens, do_sample=False) + ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) + transformers_output = transformers_model.generate(**tokens, do_sample=False) + transformers_output_assisted = transformers_model.generate( + **tokens, do_sample=False, assistant_model=ipex_model + ) + self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) + self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) + @parameterized.expand( grid_parameters( { From b2b93bb12b71a9bd0e6e4f3ed96261d05a664b2d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 11:35:25 -0400 Subject: [PATCH 07/23] skip testing without pkv --- tests/ipex/test_modeling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 2948d383f0..043d1e761c 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -244,7 +244,7 @@ def test_assisted_decoding(self, model_arch): } ) ) - @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") + @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version >= 2.3.0 supports ipex model patching") def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): model_id = MODEL_NAMES[model_arch] set_seed(SEED) @@ -271,6 +271,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): self.assertIsInstance(outputs, torch.Tensor) self.assertEqual(outputs, transformers_outputs) + @unittest.skip("CPU IPEXModel only supports with past_key_values.") def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id) From 64dcde4f6d42e1c14a797580760147abcb4eee8f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 12:11:59 -0400 Subject: [PATCH 08/23] add tests skip --- tests/ipex/test_modeling.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 043d1e761c..5646d59db1 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -220,7 +220,7 @@ def test_pipeline(self, model_arch): self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) @parameterized.expand(SUPPORTED_ARCHITECTURES) - @unittest.skip("CPU IPEXModel does not support assisted decoding for now.") + @unittest.skipIf(is_ipex_version(">=", "2.3.0"), reason="CPU IPEXModel does not support assisted decoding when ipex version >= 2.3.0") def test_assisted_decoding(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -236,20 +236,12 @@ def test_assisted_decoding(self, model_arch): self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) - @parameterized.expand( - 
grid_parameters( - { - "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES, - "use_cache": [True, False], - } - ) - ) + @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES) @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version >= 2.3.0 supports ipex model patching") - def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): + def test_ipex_patching_beam_search(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache) - self.assertEqual(model.use_cache, use_cache) + model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token @@ -260,7 +252,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True), GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True), GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True), - GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6), + GenerationConfig(max_new_tokens=4, do_sample=True, top_p=1.0, top_k=5, penalty_alpha=0.6), GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0), ) for text in texts: @@ -271,7 +263,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): self.assertIsInstance(outputs, torch.Tensor) self.assertEqual(outputs, transformers_outputs) - @unittest.skip("CPU IPEXModel only supports with past_key_values.") + @unittest.skipIf(is_ipex_version(">=", "2.3.0"), reason="CPU IPEXModel only supports with past_key_values for ipex version >= 2.3.0") def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id) From 945f6b6a958ad560fe32bf33b43b3bd4ec113625 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 12:26:58 -0400 Subject: [PATCH 09/23] only llama2 with at least 64 head size support IAKV --- tests/ipex/test_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 5646d59db1..428e5d3da9 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -178,7 +178,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "mpt", "opt", ) - IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2",) GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 From c8922f3fd3c51fcea2e2e26b3234dc400ee39c38 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 12:55:57 -0400 Subject: [PATCH 10/23] cannot assert same outputs cause do_sample=True --- tests/ipex/test_modeling.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 428e5d3da9..975821eaf2 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -259,9 +259,7 @@ def test_ipex_patching_beam_search(self, model_arch): tokens = tokenizer(text, padding=True, return_tensors="pt") for generation_config in generation_configs: outputs = model.generate(**tokens, generation_config=generation_config) - transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) - self.assertEqual(outputs, 
transformers_outputs) @unittest.skipIf(is_ipex_version(">=", "2.3.0"), reason="CPU IPEXModel only supports with past_key_values for ipex version >= 2.3.0") def test_compare_with_and_without_past_key_values(self): From 2ddfa7a679d96f10d1f8e05ea57f791bed75129a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 27 May 2024 13:25:23 -0400 Subject: [PATCH 11/23] rm tiny-llama model testing cause it not work for IAKV --- tests/ipex/test_modeling.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 975821eaf2..fddcb548eb 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -171,7 +171,6 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "gpt2", "gpt_neo", "gpt_neox", - "llama", "llama2", "mistral", # "phi", @@ -242,7 +241,6 @@ def test_ipex_patching_beam_search(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token # Test with batch_size is 1 and 2. From f4e887d4089804c851260da3ab90545736c69ce0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 May 2024 05:31:32 -0400 Subject: [PATCH 12/23] fix code style --- tests/ipex/test_modeling.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index fddcb548eb..aebebda101 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -219,7 +219,10 @@ def test_pipeline(self, model_arch): self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) @parameterized.expand(SUPPORTED_ARCHITECTURES) - @unittest.skipIf(is_ipex_version(">=", "2.3.0"), reason="CPU IPEXModel does not support assisted decoding when ipex version >= 2.3.0") + @unittest.skipIf( + is_ipex_version(">=", "2.3.0"), + reason="CPU IPEXModel does not support assisted decoding when ipex version >= 2.3.0", + ) def test_assisted_decoding(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -259,7 +262,10 @@ def test_ipex_patching_beam_search(self, model_arch): outputs = model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) - @unittest.skipIf(is_ipex_version(">=", "2.3.0"), reason="CPU IPEXModel only supports with past_key_values for ipex version >= 2.3.0") + @unittest.skipIf( + is_ipex_version(">=", "2.3.0"), + reason="CPU IPEXModel only supports with past_key_values for ipex version >= 2.3.0", + ) def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id) From 74f132ea59a0070e832fbd7e50da35680925acb3 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Tue, 28 May 2024 13:17:01 -0400 Subject: [PATCH 13/23] refine docstring --- optimum/exporters/ipex/modeling_utils.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 047a7b8d10..e823966078 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -179,16 +179,22 @@ def forward( Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask 
(`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - residual (`torch.Tensor`): residual tensor to the layer of shape ` + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. + residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` """ bsz, seq_len, _ = hidden_states.size() @@ -296,6 +302,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` """ if hasattr(self, "linear_silu_mul"): mlp_gate = self.linear_silu_mul(hidden_states) @@ -339,15 +346,17 @@ def forward( Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states From e130345d4a033e07fa5ed36a54a488ef1332b42a Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Wed, 29 May 2024 23:15:06 -0700 Subject: [PATCH 14/23] fix duplicted code --- optimum/exporters/ipex/modeling_utils.py | 22 ++++------------------ optimum/intel/utils/modeling_utils.py | 9 +++++++++ 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e823966078..55ed0dae85 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import repeat_kv from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.modeling_utils import setattr_from_module # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 @@ -140,12 +141,7 @@ def __init__(self, module, config, distributed=False) -> None: "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" ) super().__init__() - for k, v in module.__dict__.items(): - setattr(self, k, v) - for k, v in module.__class__.__dict__.items(): - if k.startswith("__") or k.startswith("forward"): - continue - setattr(self.__class__, k, getattr(module.__class__, k)) + setattr_from_module(self, module) self.config = config self.distributed = distributed from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding @@ -281,12 +277,7 @@ def __init__(self, module, config, distributed=False) -> None: raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") super().__init__() - for k, v in module.__dict__.items(): - setattr(self, k, v) - for k, v in module.__class__.__dict__.items(): - if k.startswith("__") or k.startswith("forward"): - continue - setattr(self.__class__, k, getattr(module.__class__, k)) + setattr_from_module(self, module) self.config = config self.distributed = distributed from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd @@ -322,12 +313,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False): super().__init__() - for k, v in module.__dict__.items(): - setattr(self, k, v) - for k, v in module.__class__.__dict__.items(): - if k.startswith("__") or k.startswith("forward"): - continue - setattr(self.__class__, k, getattr(module.__class__, k)) + setattr_from_module(self, module) self.distributed = distributed self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed) self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index a2cd728354..b213ea3107 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -182,3 +182,12 @@ def recursive_to_device(value, device): elif isinstance(value, torch.Tensor): return value.to(device) return value + + +def setattr_from_module(new_module, module): + for k, v in module.__dict__.items(): + setattr(new_module, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(new_module.__class__, k, getattr(module.__class__, k)) 
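Note on the helper introduced above: setattr_from_module copies the wrapped module's instance attributes (its __dict__, which for an nn.Module carries the registered submodules, parameters and buffers) onto the wrapper, and re-binds the original class attributes except dunders and forward, so the _IPEXLlama* wrappers keep behaving like the original Transformers modules while only forward is replaced. A minimal sketch of the pattern, assuming an installed optimum-intel with this patch applied; the wrapper name here is hypothetical and only mirrors the _IPEXLlamaMLP structure:

    import torch.nn as nn
    from optimum.intel.utils.modeling_utils import setattr_from_module

    class _IPEXExampleWrapper(nn.Module):  # hypothetical wrapper, for illustration only
        def __init__(self, module, config, distributed=False):
            super().__init__()
            # copy submodules/parameters and non-forward class attributes from the original module
            setattr_from_module(self, module)
            self.config = config
            self.distributed = distributed

        def forward(self, hidden_states):
            # only forward is overridden; everything copied above is reachable as self.<name>
            return hidden_states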
From 14673dab2cc6a17caa8c40dcba51406eb91cbf6d Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Mon, 3 Jun 2024 10:09:49 -0400 Subject: [PATCH 15/23] refactor attention forward --- optimum/exporters/ipex/modeling_utils.py | 89 ++++++++++++++---------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 55ed0dae85..518292993a 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -159,39 +159,7 @@ def __init__(self, module, config, distributed=False) -> None: module.config.architectures[0], ) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - residual: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. - residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` - """ + def qkv_gemm(self, hidden_states): bsz, seq_len, _ = hidden_states.size() query = self.q_proj(hidden_states) @@ -202,8 +170,9 @@ def forward( key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len - # Use ipex op to rotary position embedding more efficient. + return query, key, value + + def rope(self, query, key, kv_seq_len, position_ids): key = self.ipex_rope( key, position_ids, @@ -222,7 +191,9 @@ def forward( self.head_dim, kv_seq_len, ) + return query, key + def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): if use_cache: # This ipex op pre-allocates buffers for past_key_values and use beam index history # which to decide which beam should be used to make attention scale dot more efficient. 
@@ -239,7 +210,6 @@ def forward( value_states = value.transpose(1, 2) query_states = query.transpose(1, 2) key_states = key.transpose(1, 2) - kv_seq_len = key_states.shape[-2] past_key_value = None # repeat k/v heads if n_kv_heads < n_heads @@ -256,8 +226,51 @@ def forward( attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size) + return attn_output, past_key_value, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+ residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` + """ + bsz, seq_len, _ = hidden_states.size() + kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len + + query, key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, kv_seq_len, position_ids) + + attn_output, past_key_value, attn_weights = self.sdpa( + query, key, value, past_key_value, attention_mask, use_cache + ) + attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) if hasattr(self, "mha_linear_add"): attn_output = self.mha_linear_add(attn_output, residual) From a2a969ec29f7402529e4df1006e7775ed8c82806 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Mon, 3 Jun 2024 10:20:27 -0400 Subject: [PATCH 16/23] add use_cache for rope --- optimum/exporters/ipex/modeling_utils.py | 41 ++++++++++++------------ 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 518292993a..ea387eb41f 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -172,25 +172,26 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, kv_seq_len, position_ids): - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) + def rope(self, query, key, kv_seq_len, position_ids, use_cache): + if use_cache: + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) return query, key def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): @@ -265,7 +266,7 @@ def forward( kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, position_ids) + query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) attn_output, past_key_value, attn_weights = self.sdpa( query, key, value, past_key_value, attention_mask, use_cache From 3abd79067950ad06e836c99b1eddee15eece4d0a Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Mon, 3 Jun 2024 10:35:22 -0400 Subject: [PATCH 17/23] use with and without cache --- optimum/exporters/ipex/modeling_utils.py | 110 ++++++++++++----------- 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index ea387eb41f..195c7c3829 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -172,60 +172,60 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, kv_seq_len, position_ids, use_cache): - if use_cache: - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) + def rope(self, query, key, kv_seq_len, position_ids): + 
key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) return query, key - def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache): - if use_cache: - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. - (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) - else: - value_states = value.transpose(1, 2) - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) + def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask): + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + return attn_output, past_key_value, attn_weights + + def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask): + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) - past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) return attn_output, past_key_value, attn_weights @@ -266,11 +266,15 @@ def forward( kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) - - attn_output, past_key_value, attn_weights = self.sdpa( - query, key, value, past_key_value, attention_mask, use_cache - ) + if use_cache: + query, key = self.rope(query, key, kv_seq_len, position_ids) + attn_output, past_key_value, attn_weights = self.sdpa_with_cache( + query, key, value, past_key_value, 
attention_mask + ) + else: + attn_output, past_key_value, attn_weights = self.sdpa_without_cache( + query, key, value, past_key_value, attention_mask + ) attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) if hasattr(self, "mha_linear_add"): From 82bd0c7e376a646e02a9db6c448e8ae6ae1376ac Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Mon, 3 Jun 2024 10:42:43 -0400 Subject: [PATCH 18/23] refine code --- optimum/exporters/ipex/modeling_utils.py | 52 +++++++++++------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 195c7c3829..2d885b1de6 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -172,25 +172,26 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, kv_seq_len, position_ids): - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) + def rope(self, query, key, kv_seq_len, position_ids, use_cache): + if use_cache: + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) return query, key def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask): @@ -266,15 +267,10 @@ def forward( kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len query, key, value = self.qkv_gemm(hidden_states) - if use_cache: - query, key = self.rope(query, key, kv_seq_len, position_ids) - attn_output, past_key_value, attn_weights = self.sdpa_with_cache( - query, key, value, past_key_value, attention_mask - ) - else: - attn_output, past_key_value, attn_weights = self.sdpa_without_cache( - query, key, value, past_key_value, attention_mask - ) + query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) + + sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache + attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask) attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) if hasattr(self, "mha_linear_add"): From de2cc43718bd4c0f74f307e18ad96d84ff555d34 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Tue, 4 Jun 2024 02:50:23 -0700 Subject: [PATCH 19/23] add reference link --- optimum/exporters/ipex/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 2d885b1de6..c4ed19f012 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -208,6 +208,7 @@ def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask): ) return attn_output, past_key_value, attn_weights + # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask): value_states = value.transpose(1, 2) query_states = query.transpose(1, 2) @@ -285,6 +286,7 @@ def forward( return attn_output, attn_weights, past_key_value 
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: if is_ipex_version("<", "2.3.0"): From 752aba6e854e13b375bb9005704ddffd46e183b2 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 6 Jun 2024 07:55:43 -0700 Subject: [PATCH 20/23] bug fix --- optimum/exporters/ipex/model_patcher.py | 4 +--- optimum/exporters/ipex/modeling_utils.py | 8 +++++--- tests/ipex/test_modeling.py | 1 + 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 51cdfe1b30..e0fad0d6c2 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -62,9 +62,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): - raise ImportError( - f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching" - ) + raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching") convert_functions(model, LlamaModel, "forward", _llama_model_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f7393289b5..9251f3dec5 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -19,11 +19,12 @@ from torch import nn from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.models.llama.modeling_llama import repeat_kv -from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version +from optimum.intel.utils.import_utils import is_ipex_version from optimum.intel.utils.modeling_utils import setattr_from_module + # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version _TRANSFORMERS_MIN_VERSION = "4.39.0" _TRANSFORMERS_MAX_VERSION = "4.41.2" @@ -34,7 +35,7 @@ def _llama_layer_norm_forward(self, hidden_states): return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) - + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -137,6 +138,7 @@ def _llama_model_forward( attentions=all_self_attns, ) + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, distributed=False) -> None: diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 40e0c5aca5..1ea653eedc 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -297,6 +297,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): tokens = tokenizer(text, padding=True, return_tensors="pt") for generation_config in generation_configs: outputs = model.generate(**tokens, generation_config=generation_config) + transformers_outputs = trasnformers_model.generate(**tokens, 
generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) self.assertTrue(torch.equal(outputs, transformers_outputs)) From 1ef8d567ec615892b395bd7fa42400551aba045a Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 6 Jun 2024 08:16:54 -0700 Subject: [PATCH 21/23] use reshape --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 9251f3dec5..e7ea1edf3e 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -278,7 +278,7 @@ def forward( sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask) - attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size) + attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) if hasattr(self, "mha_linear_add"): attn_output = self.mha_linear_add(attn_output, residual) From 5f5d2050283ec22667a5352b85d00bbb75c858f0 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 7 Jun 2024 00:48:53 +0800 Subject: [PATCH 22/23] Apply suggestions from code review Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> --- optimum/intel/utils/modeling_utils.py | 2 +- tests/ipex/test_modeling.py | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index b213ea3107..3541f4f933 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -184,7 +184,7 @@ def recursive_to_device(value, device): return value -def setattr_from_module(new_module, module): +def _setattr_from_module(new_module, module): for k, v in module.__dict__.items(): setattr(new_module, k, v) for k, v in module.__class__.__dict__.items(): diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 1ea653eedc..8664b99cee 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -242,10 +242,6 @@ def test_pipeline(self, model_arch): # High optimized model llama is not supported assisted decoding for now. 
@parameterized.expand(SUPPORTED_ARCHITECTURES) - @unittest.skipIf( - is_ipex_version(">=", "2.3.0"), - reason="CPU IPEXModel does not support assisted decoding when ipex version >= 2.3.0", - ) def test_assisted_decoding(self, model_arch): if model_arch == "llama2": return @@ -301,10 +297,6 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): self.assertIsInstance(outputs, torch.Tensor) self.assertTrue(torch.equal(outputs, transformers_outputs)) - @unittest.skipIf( - is_ipex_version(">=", "2.3.0"), - reason="CPU IPEXModel only supports with past_key_values for ipex version >= 2.3.0", - ) def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id) From 22860f21d6705f4cb0051f7ddb38d21139d7d0bd Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 6 Jun 2024 09:53:51 -0700 Subject: [PATCH 23/23] fix --- optimum/exporters/ipex/model_patcher.py | 12 ++++++++- optimum/exporters/ipex/modeling_utils.py | 34 +++++++++++++----------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index e0fad0d6c2..c6ed33006e 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -19,7 +19,7 @@ LlamaRMSNorm, ) -from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, @@ -29,6 +29,10 @@ ) +# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version +_TRANSFORMERS_MIN_VERSION = "4.39.0" +_TRANSFORMERS_MAX_VERSION = "4.41.2" + _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -63,6 +67,12 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching") + if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version( + ">", _TRANSFORMERS_MAX_VERSION + ): + raise ImportError( + f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified." 
+ ) convert_functions(model, LlamaModel, "forward", _llama_model_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e7ea1edf3e..5870a8c792 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -19,15 +19,12 @@ from torch import nn from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import repeat_kv +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from optimum.intel.utils.import_utils import is_ipex_version -from optimum.intel.utils.modeling_utils import setattr_from_module +from optimum.intel.utils.modeling_utils import _setattr_from_module -# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version -_TRANSFORMERS_MIN_VERSION = "4.39.0" -_TRANSFORMERS_MAX_VERSION = "4.41.2" _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" @@ -142,12 +139,12 @@ def _llama_model_forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 class _IPEXLlamaAttention(nn.Module): def __init__(self, module, config, distributed=False) -> None: - if is_ipex_version("<", "2.3.0"): + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): raise ImportError( - "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding" + f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding" ) super().__init__() - setattr_from_module(self, module) + _setattr_from_module(self, module) self.config = config self.distributed = distributed from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding @@ -200,7 +197,7 @@ def rope(self, query, key, kv_seq_len, position_ids, use_cache): ) return query, key - def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask): + def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids): # This ipex op pre-allocates buffers for past_key_values and use beam index history # which to decide which beam should be used to make attention scale dot more efficient. 
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( @@ -215,11 +212,14 @@ def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask): return attn_output, past_key_value, attn_weights # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask): + def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids): value_states = value.transpose(1, 2) query_states = query.transpose(1, 2) key_states = key.transpose(1, 2) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + past_key_value = None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -277,7 +277,9 @@ def forward( query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache - attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask) + attn_output, past_key_value, attn_weights = sdpa( + query, key, value, past_key_value, attention_mask, position_ids + ) attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) if hasattr(self, "mha_linear_add"): @@ -295,10 +297,12 @@ def forward( # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 class _IPEXLlamaMLP(nn.Module): def __init__(self, module, config, distributed=False) -> None: - if is_ipex_version("<", "2.3.0"): - raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): + raise ImportError( + f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd" + ) super().__init__() - setattr_from_module(self, module) + _setattr_from_module(self, module) self.config = config self.distributed = distributed from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd @@ -334,7 +338,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, ** class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config, distributed=False): super().__init__() - setattr_from_module(self, module) + _setattr_from_module(self, module) self.distributed = distributed self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed) self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
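
As a usage note for the refactor above: the patched attention now routes through sdpa_with_cache (backed by IPEX's IndirectAccessKVCacheAttention) when use_cache is set, and otherwise falls back to the eager path in sdpa_without_cache. The snippet below is a minimal, self-contained sketch of that eager fallback in plain PyTorch, runnable without IPEX or transformers installed; the function name eager_sdpa, the inlined repeat_kv helper, and the toy head counts and shapes are illustrative stand-ins for what the patched module derives from the Llama config, and the rotary position embedding step is omitted for brevity.

    import math

    import torch
    from torch import nn


    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Duplicate each key/value head n_rep times so grouped key/value heads
        # line up with the query heads (grouped-query attention).
        bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)


    def eager_sdpa(query, key, value, attention_mask, num_key_value_groups):
        # query/key/value come in as (batch, seq_len, num_heads, head_dim), the layout qkv_gemm produces.
        query_states = query.transpose(1, 2)
        key_states = repeat_kv(key.transpose(1, 2), num_key_value_groups)
        value_states = repeat_kv(value.transpose(1, 2), num_key_value_groups)

        head_dim = query_states.size(-1)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # Upcast the softmax to fp32 for numerical stability, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        return torch.matmul(attn_weights, value_states)


    # Toy shapes: 2 sequences of 8 tokens, 8 query heads sharing 2 key/value heads of size 16.
    bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 8, 8, 2, 16
    query = torch.randn(bsz, seq_len, num_heads, head_dim)
    key = torch.randn(bsz, seq_len, num_kv_heads, head_dim)
    value = torch.randn(bsz, seq_len, num_kv_heads, head_dim)
    causal_mask = torch.triu(torch.full((seq_len, seq_len), torch.finfo(torch.float32).min), diagonal=1)

    attn_output = eager_sdpa(query, key, value, causal_mask, num_heads // num_kv_heads)
    print(attn_output.transpose(1, 2).reshape(bsz, seq_len, num_heads * head_dim).shape)  # torch.Size([2, 8, 128])

The with-cache path keeps the same query/key/value layout but hands cache management to ipex_scale_dot_product, which pre-allocates the key/value buffers and tracks beam indices internally; that is why forward only needs to bind one of the two methods via `sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache` and call it with identical arguments.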