refactor CPU llama inference code #728

Merged: 31 commits, Jun 7, 2024
Changes from 6 commits

Commits (31)
5351f4a
ipex 2.3 released
jiqing-feng May 23, 2024
d1d0ca0
refactor IPEXLlamaAttention
faaany May 25, 2024
bd5706c
Merge branch 'huggingface:main' into ipex-cpu
faaany May 26, 2024
e61382b
Merge branch 'main' of https://github.com/faaany/optimum-intel into i…
faaany May 26, 2024
48b205e
change to Ref
faaany May 26, 2024
404486a
Merge branch 'ipex-cpu' of https://github.com/faaany/optimum-intel in…
faaany May 26, 2024
4ea8a47
remove Ref
faaany May 27, 2024
1f98d6d
skip tests
jiqing-feng May 27, 2024
d3ce377
skip tests
jiqing-feng May 27, 2024
b2b93bb
skip testing without pkv
jiqing-feng May 27, 2024
ec0f641
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
64dcde4
add tests skip
jiqing-feng May 27, 2024
945f6b6
only llama2 with at least 64 head size support IAKV
jiqing-feng May 27, 2024
0733625
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
c8922f3
cannot assert same outputs cause do_sample=True
jiqing-feng May 27, 2024
0ff1d7b
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
2ddfa7a
rm tiny-llama model testing cause it not work for IAKV
jiqing-feng May 27, 2024
f4e887d
fix code style
jiqing-feng May 28, 2024
923e233
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 28, 2024
74f132e
refine docstring
faaany May 28, 2024
e130345
fix duplicted code
faaany May 30, 2024
14673da
refactor attention forward
faaany Jun 3, 2024
a2a969e
add use_cache for rope
faaany Jun 3, 2024
3abd790
use with and without cache
faaany Jun 3, 2024
82bd0c7
refine code
faaany Jun 3, 2024
de2cc43
add reference link
faaany Jun 4, 2024
1385f97
Merge branch 'main' into ipex-cpu
faaany Jun 6, 2024
752aba6
bug fix
faaany Jun 6, 2024
1ef8d56
use reshape
faaany Jun 6, 2024
5f5d205
Apply suggestions from code review
faaany Jun 6, 2024
22860f2
fix
faaany Jun 6, 2024
20 changes: 2 additions & 18 deletions optimum/exporters/ipex/model_patcher.py
@@ -13,7 +13,6 @@
# limitations under the License.

from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
Expand All @@ -24,7 +23,6 @@

from .modeling_utils import (
_IPEXLlamaDecoderLayerRef,
_llama_attn_forward,
_llama_layer_norm_forward,
_llama_model_forward,
)
@@ -62,25 +60,11 @@ def patch_op(m, target_m, new_op_name, new_op):


def _patch_llama_model(model):
if is_ipex_version("<", "2.5.0"):
raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")

from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding

ipex_rope = RotaryEmbedding(
model.config.max_position_embeddings,
model.config.hidden_size // model.config.num_attention_heads,
model.config.rope_theta,
model.config.architectures[0],
)
ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
if is_ipex_version("<", "2.3.0"):
raise ImportError("Only ipex version >= 2.3.0 supports llama model patching")

convert_functions(model, LlamaModel, "forward", _llama_model_forward)
convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
return model

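Note: `convert_functions` and `convert_class`, which the new `_patch_llama_model` relies on, are defined elsewhere in `model_patcher.py` and are not part of this hunk. The sketch below is only an illustration of what such helpers typically do, under the assumption that they walk the module tree recursively; it is not the repository's implementation.

```python
# Illustrative sketch only (assumed behaviour, not the actual optimum-intel code):
# rebind a forward function on every submodule of a given class, and wrap every
# submodule of a given class in a replacement module.
import torch

def convert_functions(m: torch.nn.Module, target_cls, new_fn_name: str, new_fn) -> None:
    # Bind `new_fn` as `new_fn_name` (e.g. "forward") on each submodule of type `target_cls`.
    for sub_m in m.children():
        if isinstance(sub_m, target_cls):
            setattr(sub_m, new_fn_name, new_fn.__get__(sub_m, sub_m.__class__))
        convert_functions(sub_m, target_cls, new_fn_name, new_fn)

def convert_class(m: torch.nn.Module, target_cls, new_cls, config) -> None:
    # Replace each child of type `target_cls` with `new_cls(child, config)`.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_cls):
            setattr(m, name, new_cls(sub_m, config))
        else:
            convert_class(sub_m, target_cls, new_cls, config)
```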
295 changes: 185 additions & 110 deletions optimum/exporters/ipex/modeling_utils.py
@@ -29,90 +29,6 @@ def _llama_layer_norm_forward(self, hidden_states):
return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
def _llama_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)

kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len

query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# Use ipex op to rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
# This ipex op pre-allocates buffers for past_key_values and use beam index history
# which to decide which beam should be used to make attention scale dot more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
None,
attention_mask,
)
else:
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
kv_seq_len = key_states.shape[-2]

past_key_value = None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
def _llama_model_forward(
self,
@@ -216,14 +132,188 @@ def _llama_model_forward(
)


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
if is_ipex_version("<", "2.5.0"):
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
class _IPEXLlamaAttentionRef(nn.Module):
def __init__(self, module, config, distributed=False) -> None:
if is_ipex_version("<", "2.3.0"):
raise ImportError(
"Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
)
super().__init__()
for k, v in module.__dict__.items():
setattr(self, k, v)
for k, v in module.__class__.__dict__.items():
if k.startswith("__") or k.startswith("forward"):
continue
setattr(self.__class__, k, getattr(module.__class__, k))
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding

if not self.distributed:
self.mha_linear_add = LinearAdd(self.o_proj)
del self.__dict__["_modules"]["o_proj"]
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
text_max_length=module.config.max_position_embeddings
)
self.ipex_rope = RotaryEmbedding(
module.config.max_position_embeddings,
module.config.hidden_size // module.config.num_attention_heads,
module.config.rope_theta,
module.config.architectures[0],
)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
"""
bsz, seq_len, _ = hidden_states.size()

query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)

query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)

kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len
# Use the IPEX rotary embedding op to apply rotary position embeddings more efficiently.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
# This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
# to decide which beam to gather from, making the attention scaled dot product more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
None,
attention_mask,
)
else:
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
kv_seq_len = key_states.shape[-2]

past_key_value = None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size)

if hasattr(self, "mha_linear_add"):
attn_output = self.mha_linear_add(attn_output, residual)
else:
attn_output = self.o_proj(attn_output)
attn_output = residual + attn_output

if not output_attentions:
attn_weights = None

return attn_output, past_key_value, attn_weights


class _IPEXLlamaMLPRef(nn.Module):
def __init__(self, module, config, distributed=False) -> None:
if is_ipex_version("<", "2.3.0"):
raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")

super().__init__()
for k, v in module.__dict__.items():
setattr(self, k, v)
for k, v in module.__class__.__dict__.items():
if k.startswith("__") or k.startswith("forward"):
continue
setattr(self.__class__, k, getattr(module.__class__, k))
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

if not self.distributed:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
del self.__dict__["_modules"]["gate_proj"]
del self.__dict__["_modules"]["up_proj"]

def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
"""
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.down_proj(mlp_gate)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
hidden_states = residual + hidden_states

return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
super().__init__()
for k, v in module.__dict__.items():
setattr(self, k, v)
@@ -232,14 +322,8 @@ def __init__(self, module, config, distributed=False):
continue
setattr(self.__class__, k, getattr(module.__class__, k))
self.distributed = distributed
if not self.distributed:
self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
del self.__dict__["_modules"]["self_attn"].o_proj
del self.__dict__["_modules"]["mlp"].down_proj
self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
del self.__dict__["_modules"]["mlp"].gate_proj
del self.__dict__["_modules"]["mlp"].up_proj
self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed)
self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed)

def forward(
self,
@@ -270,31 +354,22 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, present_key_value, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=None,
residual=residual,
**kwargs,
)
if not self.distributed:
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

mlp_gate = self.linear_silu_mul(hidden_states)

if not self.distributed:
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states
hidden_states = self.mlp(hidden_states, residual, **kwargs)

outputs = (hidden_states,)

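The refactor above moves the IPEX-specific logic out of free functions and into `_IPEXLlamaAttentionRef` and `_IPEXLlamaMLPRef`, with the decoder layer now delegating residual handling to them (fused `LinearAdd` / `Linear2SiluMul` in the non-distributed case, plain projection plus add otherwise). For reference, the snippet below is a minimal sketch of how the two IPEX ops used by the new attention class are constructed and called, mirroring the signatures shown in the diff; it assumes ipex >= 2.3.0 is installed and uses a default `LlamaConfig` purely as an example.

```python
# Sketch of the IPEX op wiring used by _IPEXLlamaAttentionRef (assumes ipex >= 2.3.0).
from transformers import LlamaConfig
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding

config = LlamaConfig()  # example config; the class above uses the wrapped module's config
head_dim = config.hidden_size // config.num_attention_heads

ipex_rope = RotaryEmbedding(
    config.max_position_embeddings,   # max sequence length
    head_dim,
    config.rope_theta,
    "LlamaForCausalLM",               # config.architectures[0] in the class above
)
ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings)

# Inside forward() they are then called exactly as in the class above:
#   key   = ipex_rope(key,   position_ids, num_key_value_heads, head_dim, head_dim // 2, head_dim, kv_seq_len)
#   query = ipex_rope(query, position_ids, num_heads,           head_dim, head_dim // 2, head_dim, kv_seq_len)
#   attn_output, attn_weights, past_key_value = ipex_scale_dot_product(
#       query, key, value, math.sqrt(head_dim), past_key_value, None, attention_mask
#   )
```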
2 changes: 1 addition & 1 deletion optimum/intel/ipex/modeling_base.py
@@ -63,7 +63,7 @@


def _is_patched_with_ipex(model, task):
if is_ipex_version("<", "2.5.0"):
if is_ipex_version("<", "2.3.0"):
return False

if isinstance(model, torch.jit.ScriptModule):
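Lowering the gate from 2.5.0 to 2.3.0 means the patched CPU Llama path is reached through the usual optimum-intel entry point once ipex 2.3 is installed. A minimal usage sketch follows; it assumes optimum-intel with the IPEX extra and ipex >= 2.3.0 are available, and the checkpoint name is only an example.

```python
# Minimal usage sketch (assumptions: optimum-intel[ipex] installed, ipex >= 2.3.0,
# and an example Llama checkpoint; not taken from this PR).
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)

inputs = tokenizer("Paris is the capital of", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```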