
Commit 92fe39f

Removing attention mask patching (#791)
* Removing attention mask patching
* ruff
1 parent eeb1df0 commit 92fe39f


2 files changed: +2 lines added, -110 lines removed


.github/workflows/check_code_quality.yml

Lines changed: 1 addition & 1 deletion
@@ -51,4 +51,4 @@ jobs:
       - name: Check style with ruff
         run: |
           source venv/bin/activate
-          ruff .
+          ruff check .
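Note (context, not stated in the diff itself): recent ruff releases deprecate invoking the linter as a bare `ruff .` in favor of the explicit `ruff check .` subcommand, which is what this workflow change tracks.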

optimum/intel/utils/modeling_utils.py

Lines changed: 1 addition & 109 deletions
@@ -14,7 +14,7 @@
 
 import re
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import torch
 from huggingface_hub import HfApi, HfFolder
@@ -23,114 +23,6 @@
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
 
 
-# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size,
-    device: torch.device,
-    past_key_values_length: int,
-    dtype: torch.dtype = torch.bool,
-) -> torch.BoolTensor:
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    batch_size, target_length = input_ids_shape
-    mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device)
-    seq_ids = torch.arange(target_length, device=device)
-
-    mask[:, past_key_values_length:] = (
-        (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min
-        if torch.is_floating_point(mask)
-        else seq_ids[:, None] < seq_ids[None, :]
-    )
-
-    return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
-
-
-# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask
-def _prepare_attn_mask(
-    attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
-) -> torch.BoolTensor:
-    from transformers.models.bloom.modeling_bloom import _expand_mask
-
-    # create causal mask
-    # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-    combined_attention_mask = None
-    device = attention_mask.device
-    _, src_length = input_shape
-
-    combined_attention_mask = _make_causal_mask(
-        input_shape, device=device, past_key_values_length=past_key_values_length
-    )
-    # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-    expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
-    combined_attention_mask = (
-        expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
-    )
-
-    return combined_attention_mask
-
-
-# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask
-def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
-    from transformers.models.llama.modeling_llama import _expand_mask
-
-    # create causal mask
-    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-    combined_attention_mask = None
-
-    combined_attention_mask = _make_causal_mask(
-        input_shape,
-        device=inputs_embeds.device,
-        past_key_values_length=past_key_values_length,
-        dtype=inputs_embeds.dtype,
-    )
-
-    if attention_mask is not None:
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-            inputs_embeds.device
-        )
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-        )
-
-    return combined_attention_mask
-
-
-# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask
-def _prepare_decoder_sliding_window_attention_mask(
-    attention_mask: torch.Tensor,
-    input_shape: Tuple[int, int],
-    inputs_embeds: torch.Tensor,
-    past_key_values_length: int,
-    sliding_window: int,
-):
-    from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask
-
-    # create causal mask
-    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-    combined_attention_mask = None
-
-    combined_attention_mask = _make_sliding_window_causal_mask(
-        input_shape,
-        device=inputs_embeds.device,
-        dtype=inputs_embeds.dtype,
-        past_key_values_length=past_key_values_length,
-        sliding_window=sliding_window,
-    )
-
-    if attention_mask is not None:
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-            inputs_embeds.device
-        )
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-        )
-
-    return combined_attention_mask
-
-
 def get_model_device(model: torch.nn.Module) -> torch.device:
     """
     Determines the device on which a PyTorch model is currently residing.
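For context (not part of the commit), here is a minimal Python sketch of the behavior the deleted _make_causal_mask helper implemented: a boolean mask marking, for each query position, the future positions it must not attend to. The values and shapes below are illustrative only.

    import torch

    # Illustrative only: reproduce the boolean causal mask the removed helper
    # returned, for batch_size=1, target_length=4, past_key_values_length=0.
    target_length, past_key_values_length = 4, 0
    seq_ids = torch.arange(target_length)
    mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=torch.bool)
    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]  # True = masked (future) position

    print(mask)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])

    # The removed helper then broadcast this to shape
    # [batch_size, 1, target_length, target_length + past_key_values_length].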
