
Commit

fix rebase bug
faaany committed Jun 7, 2024
1 parent 0a56b19 commit 548d83f
Showing 2 changed files with 222 additions and 75 deletions.
285 changes: 216 additions & 69 deletions optimum/exporters/ipex/modeling_utils.py
@@ -21,11 +21,14 @@
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version
from optimum.intel.utils.modeling_utils import _setattr_from_module


_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"


def matmul_add_add(attn_output, weight, bias=None, residual=None):
seq_len, bs, _ = attn_output.size()
if residual is None:
@@ -192,68 +195,124 @@ def __init__(self, module, config, distributed=False) -> None:
_setattr_from_module(self, module)
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
self.module_device = next(module.parameters()).device.type
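# Dispatch on the module's device: the XPU branch below relies on fused torch_ipex kernels over
# parameters ported into a packed layout, while the CPU branch keeps the ipex.llm fusion modules.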
if self.module_device == "xpu":
from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU

self.ipex_rope = _IPEXRopeXPU(
module.config.max_position_embeddings,
module.config.hidden_size // module.config.num_attention_heads,
module.config.rope_theta,
module.config.architectures[0],
)
self.port_parameters(module)
torch.xpu.empty_cache()
else:
from intel_extension_for_pytorch.llm.modules import (
IndirectAccessKVCacheAttention,
LinearAdd,
RotaryEmbedding,
)

if not self.distributed:
self.mha_linear_add = LinearAdd(self.o_proj)
del self.__dict__["_modules"]["o_proj"]
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
text_max_length=module.config.max_position_embeddings
)
self.ipex_rope = RotaryEmbedding(
module.config.max_position_embeddings,
module.config.hidden_size // module.config.num_attention_heads,
module.config.rope_theta,
module.config.architectures[0],
)
if not self.distributed:
self.mha_linear_add = LinearAdd(self.o_proj)
del self.__dict__["_modules"]["o_proj"]
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
text_max_length=module.config.max_position_embeddings
)
self.ipex_rope = RotaryEmbedding(
module.config.max_position_embeddings,
module.config.hidden_size // module.config.num_attention_heads,
module.config.rope_theta,
module.config.architectures[0],
)

def qkv_gemm(self, hidden_states):
bsz, seq_len, _ = hidden_states.size()

query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
if self.module_device == "xpu":
query_shape = (bsz, seq_len, self.num_heads * self.head_dim)
kv_shape = (bsz, seq_len, self.num_key_value_heads * self.head_dim)
dtype = hidden_states.dtype
device = hidden_states.device
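# On XPU, Q/K/V come from a single fused GEMM over the packed qkv_proj_weight, writing the
# results into the pre-allocated query/key/value buffers: mm_qkv_out covers the plain MHA case
# (num_key_value_heads == num_heads) and mm_qkv_group_out the grouped-query case.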
if self.num_key_value_heads == self.num_heads:
query = torch.empty(query_shape, dtype=dtype, device=device)
key = torch.empty(query_shape, dtype=dtype, device=device)
value = torch.empty(query_shape, dtype=dtype, device=device)
torch.ops.torch_ipex.mm_qkv_out(
hidden_states,
self.qkv_proj_weight,
self.qkv_proj_bias,
query,
key,
value,
)
else:
query = torch.empty(query_shape, dtype=dtype, device=device)
key = torch.empty(kv_shape, dtype=dtype, device=device)
value = torch.empty(kv_shape, dtype=dtype, device=device)
torch.ops.torch_ipex.mm_qkv_group_out(
hidden_states, self.qkv_proj_weight, self.qkv_proj_bias, query, key, value
)
else:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)

query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)

return query, key, value

def rope(self, query, key, kv_seq_len, position_ids, use_cache):
if use_cache:
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
def rope(self, query, key, kv_seq_len, position_ids, use_cache, **kwargs):
if self.module_device == "xpu":
sin = kwargs.pop("sin", None)
cos = kwargs.pop("cos", None)
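# The XPU rope kernel consumes precomputed sin/cos passed through kwargs and appears to rotate
# query and key in place, so both tensors are returned unchanged below.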
self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key)
else:
if use_cache:
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
return query, key

def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
# This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
# to decide which beam should be used, making the scaled dot-product attention more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
None,
attention_mask,
)

if self.module_device == "xpu":
scale = 1.0 / math.sqrt(self.head_dim)
is_causal = False
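# On XPU, attention goes through the fused IpexSDP kernel and the full key/value tensors are
# returned as the cache; on CPU, IndirectAccessKVCacheAttention manages the cache itself.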
attn_output = torch.xpu.IpexSDP(
query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False
)
attn_weights = None
past_key_value = (key, value)
else:
# This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
# to decide which beam should be used, making the scaled dot-product attention more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
None,
attention_mask,
)
return attn_output, past_key_value, attn_weights

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
@@ -319,25 +378,91 @@ def forward(
kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len

query, key, value = self.qkv_gemm(hidden_states)
query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache)
query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache, **kwargs)

if self.module_device == "xpu":
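# The XPU branch keeps the cache as (batch, heads, seq, head_dim): new tokens are appended
# along the sequence dimension, then heads are moved ahead of sequence for the fused SDP kernel.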
if past_key_value is not None:
key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1)
value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
attn_output, past_key_value, attn_weights = sdpa(
query, key, value, past_key_value, attention_mask, position_ids
)
attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size)

if hasattr(self, "mha_linear_add"):
attn_output = self.mha_linear_add(attn_output, residual)
if self.module_device == "xpu":
attn_output = matmul_add_add(attn_output, self.o_proj_weight, self.o_proj_bias, residual).view(
[bsz, seq_len, self.hidden_size]
)
else:
attn_output = self.o_proj(attn_output)
attn_output = residual + attn_output
if hasattr(self, "mha_linear_add"):
attn_output = self.mha_linear_add(attn_output, residual)
else:
attn_output = self.o_proj(attn_output)
attn_output = residual + attn_output

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

def port_parameters(self, module):
self.qkv_proj_bias = None
self.qkv_proj_weight = None
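# Pack the separate q/k/v projection weights (and biases) into a single buffer in the layout
# expected by the fused XPU QKV GEMM, handling both the plain MHA case and the grouped-query
# case; the original projection modules are re-pointed at slices of the packed data.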
if self.num_heads == self.num_key_value_heads:
q_proj = module.q_proj.weight.transpose(0, 1)
k_proj = module.k_proj.weight.transpose(0, 1)
v_proj = module.v_proj.weight.transpose(0, 1)
self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
if module.q_proj.bias is not None:
self.qkv_proj_bias = (
torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
.contiguous()
.view([3, -1])
)
module.q_proj.bias.data = self.qkv_proj_bias[0]
module.k_proj.bias.data = self.qkv_proj_bias[1]
module.v_proj.bias.data = self.qkv_proj_bias[2]
else:
q_proj = module.q_proj.weight.view(
self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size
)
k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
[self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size]
)
module.q_proj.weight.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape(
[self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size]
)
module.k_proj.weight.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape(
[self.num_key_value_heads * self.head_dim, self.hidden_size]
)
module.v_proj.weight.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape(
[self.num_key_value_heads * self.head_dim, self.hidden_size]
)
self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous()
if module.q_proj.bias is not None:
q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
[self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim]
)
module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, :].reshape(-1)
module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, :].reshape(-1)
module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, :].reshape(-1)
self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
self.o_proj_bias = module.o_proj.bias


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
class _IPEXLlamaMLP(nn.Module):
@@ -350,34 +475,56 @@ def __init__(self, module, config, distributed=False) -> None:
_setattr_from_module(self, module)
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
self.module_device = next(module.parameters()).device.type
if self.module_device == "xpu":
self.port_parameter(module)
torch.xpu.empty_cache()
else:
from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

if not self.distributed:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
del self.__dict__["_modules"]["gate_proj"]
del self.__dict__["_modules"]["up_proj"]
if not self.distributed:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
del self.__dict__["_modules"]["gate_proj"]
del self.__dict__["_modules"]["up_proj"]

def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
"""
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
if self.module_device == "xpu":
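# Fused XPU MLP: as the op names suggest, mm_silu computes silu(hidden_states @ gate_proj),
# mm_resmul multiplies it elementwise with hidden_states @ up_proj, and matmul_add_add applies
# down_proj plus the residual connection.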
up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual)
else:
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.down_proj(mlp_gate)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(mlp_gate)
hidden_states = self.down_proj(
self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
hidden_states = residual + hidden_states

return hidden_states

def port_parameter(self, module):
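# Keep transposed (in_features, out_features) copies of the MLP weights for the torch_ipex
# GEMM ops above, and re-point the original modules at transposed views of the same storage.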
self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous()
module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1)
self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous()
module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1)
self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous()
module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1)
self.up_proj_bias = module.up_proj.bias
self.gate_proj_bias = module.gate_proj.bias
self.down_proj_bias = module.down_proj.bias


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayer(nn.Module):
12 changes: 6 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
@@ -364,9 +364,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
@@ -1637,9 +1637,9 @@ def _dbrx_update_causal_mask_legacy(
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
