fix bauchan-13b
eaidova committed Apr 22, 2024
1 parent 514f054 commit c674492
Showing 4 changed files with 105 additions and 20 deletions.
21 changes: 11 additions & 10 deletions optimum/exporters/openvino/convert.py
@@ -347,6 +347,7 @@ def ts_patched_forward(*args, **kwargs):

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())
@@ -376,7 +377,6 @@ def ts_patched_forward(*args, **kwargs):
ov_config=ov_config,
)

sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
if not ordered_dummy_inputs:
ordered_dummy_inputs = dummy_inputs
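
The signature lookup added in the first hunk (right after check_dummy_inputs_are_allowed, inside the `with patcher:` block) feeds this comprehension: inspect.signature returns parameters in declaration order, so ordered_dummy_inputs ends up ordered to match the model's forward. A standalone sketch of that reordering, using a hypothetical forward signature rather than a real model:

import inspect


def forward(input_ids, attention_mask=None, position_ids=None, past_key_values=None):
    """Stand-in for a model's forward; only the parameter order matters here."""


# Dummy inputs can arrive in any order, e.g. from a dict built elsewhere.
dummy_inputs = {"attention_mask": [1, 1, 1], "position_ids": [0, 1, 2], "input_ids": [5, 6, 7]}

sig = inspect.signature(forward)
# Keep only the inputs the signature declares, in the signature's order.
ordered_dummy_inputs = {p: dummy_inputs[p] for p in sig.parameters if p in dummy_inputs}
print(list(ordered_dummy_inputs))  # ['input_ids', 'attention_mask', 'position_ids']
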
@@ -388,15 +388,16 @@ def ts_patched_forward(*args, **kwargs):
out_tensor.get_tensor().set_names({output_names[idx]})

for idx, inp_tensor in enumerate(ov_model.inputs):
input_name = ordered_input_names[idx]
inp_tensor.get_tensor().set_names({input_name})
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs[input_name]
for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
if idx < len(ordered_input_names):
input_name = ordered_input_names[idx]
inp_tensor.get_tensor().set_names({input_name})
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs.get(input_name, [])
for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()

if stateful:
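
The reworked loop above is the functional part of this file's change: the converted OpenVINO model can expose more inputs than ordered_input_names covers, and some names may be missing from the inputs mapping, so both lookups are now guarded instead of indexing blindly. A toy sketch of the same defensive pattern, with made-up input names standing in for the real model:

# Toy stand-ins: the converted model exposes one more input than the export config ordered.
ov_model_inputs = ["input_ids", "attention_mask", "extra_state_input"]  # hypothetical names
ordered_input_names = ["input_ids", "attention_mask"]
dynamic_dims = {"input_ids": [0, 1], "attention_mask": [0, 1]}  # dims to mark dynamic (-1)

for idx, _ in enumerate(ov_model_inputs):
    if idx < len(ordered_input_names):     # skip inputs the config does not describe
        name = ordered_input_names[idx]
        dims = dynamic_dims.get(name, [])  # a missing entry simply keeps the static shape
        print(f"{name}: dims {dims} -> -1")
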
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/model_configs.py
@@ -34,11 +34,11 @@
BaichuanModelPatcher,
ChatGLMModelPatcher,
GemmaModelPatcher,
InternLMPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
QwenModelPatcher,
MPTModelPatcher,
InternLMPatcher,
QwenModelPatcher,
)


94 changes: 89 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging as log
import math
import types
@@ -328,9 +329,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
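
The causal-mask assignment in the hunk above only changed formatting (the subscripted target is now wrapped across lines), but the logic it preserves is easy to check in isolation: positions where the incoming 4D mask equals 0.0 are overwritten with the dtype minimum so softmax drives their attention weight to zero. A small numeric sketch with hypothetical shapes:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
# 4D mask: 1.0 = attend, 0.0 = padded/masked (toy shape [1, 1, 1, 4]).
attention_mask = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]])

mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype
causal_mask = torch.zeros(1, 1, 1, 8, dtype=dtype)
shape = attention_mask.shape
causal_mask[: shape[0], : shape[1], : shape[2], : shape[3]] = mask_slice
print(causal_mask[0, 0, 0])  # first four entries: 0, 0, min, min; the rest stay 0
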
@@ -601,6 +602,46 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.config.fp16 = self.original_fp16


def _baichuan13b_atten_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

if past_key_value is not None:
# reuse k, v, self_attention
if attention_mask is not None:
attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)
attn_weights = None
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


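_baichuan13b_atten_forward above routes the Baichuan-13B attention through torch.nn.functional.scaled_dot_product_attention, concatenating cached keys/values and slicing the attention mask to the current key length. An SDPA call with an additive attn_mask is numerically the same as the explicit softmax(Q·Kᵀ/√d + mask)·V formulation it replaces; a quick self-contained check on random toy tensors (not Baichuan weights):

import math

import torch
import torch.nn.functional as F

bsz, heads, q_len, kv_len, head_dim = 1, 2, 3, 5, 4
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)
attn_mask = torch.zeros(bsz, heads, q_len, kv_len)
attn_mask[..., -1] = float("-inf")  # additively mask the last key position

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
manual = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(head_dim) + attn_mask, dim=-1) @ v
print(torch.allclose(sdpa_out, manual, atol=1e-6))  # True
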
class BaichuanModelPatcher(DecoderModelPatcher):
def __init__(
self,
@@ -613,6 +654,50 @@ def __init__(
if hasattr(self._model.lm_head, "first_flag"):
self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))

def __enter__(self):
super().__enter__()
# override signature to have position_ids
if "position_ids" not in inspect.signature(self._model.forward).parameters:
self._model._orig_forward = self._model.forward

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
position_ids: Optional[torch.LongTensor] = None,
):
return self._orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=past_key_values is not None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=self.config.return_dict,
)

self._model.forward = types.MethodType(forward, self._model)
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model, "_orig_forward"):
self._model.forward = self._model._orig_forward

for layer in self._model.model.layers:
layer.self_attn.forward = layer.self_attn._orig_forward


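BaichuanModelPatcher.__enter__ above does two things: it wraps the model-level forward so the traced signature accepts position_ids (which the wrapper simply drops), and it rebinds every layer's self_attn.forward to _baichuan13b_atten_forward, with __exit__ restoring the originals. The rebinding relies on types.MethodType; a minimal sketch of that patch-and-restore pattern on a toy object:

import types


class ToyAttention:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    return x * 10


attn = ToyAttention()
attn._orig_forward = attn.forward                       # keep the bound original
attn.forward = types.MethodType(patched_forward, attn)  # bind the replacement to the instance
print(attn.forward(2))  # 20 -> patched path

attn.forward = attn._orig_forward                       # restore on exit
print(attn.forward(2))  # 3 -> original path
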
def _mpt_attention_forward(
self,
@@ -679,8 +764,7 @@ def _internlm_attention_forward(
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

bsz, q_len, _ = hidden_states.size()

6 changes: 3 additions & 3 deletions optimum/intel/openvino/quantization.py
@@ -484,9 +484,9 @@ def _quantize_torchmodel(
subset_size=quantization_config.num_samples,
ignored_scope=quantization_config.get_ignored_scope_instance(),
model_type=nncf.ModelType(quantization_config.model_type),
preset=nncf.QuantizationPreset.PERFORMANCE
if quantization_config.sym
else nncf.QuantizationPreset.MIXED,
preset=(
nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED
),
fast_bias_correction=quantization_config.fast_bias_correction,
advanced_parameters=nncf.AdvancedQuantizationParameters(
overflow_fix=OverflowFix(quantization_config.overflow_fix)
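
The quantization.py hunk is formatting only (the conditional preset argument is now wrapped in parentheses), but the mapping it expresses is the interesting part: a symmetric quantization config selects nncf.QuantizationPreset.PERFORMANCE, anything else falls back to MIXED. A hedged, much-reduced sketch of that choice — the real _quantize_torchmodel call passes the full argument set shown above; model and calibration_data are assumed to be provided by the caller:

import nncf  # assumes the NNCF post-training quantization API is available


def pick_preset(sym: bool) -> "nncf.QuantizationPreset":
    # Symmetric quantization -> PERFORMANCE preset, otherwise the MIXED preset.
    return nncf.QuantizationPreset.PERFORMANCE if sym else nncf.QuantizationPreset.MIXED


def quantize_model(model, calibration_data, sym: bool = False):
    return nncf.quantize(
        model,
        nncf.Dataset(calibration_data),
        preset=pick_preset(sym),
        model_type=nncf.ModelType.TRANSFORMER,
    )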
