Add OpenVINO support for GLM-4 (huggingface#776)
* add support for glm4

* add test

* Update optimum/exporters/openvino/model_configs.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Update optimum/exporters/openvino/model_configs.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
eaidova and echarlaix authored Jun 21, 2024
1 parent b2d3af8 commit c19723e
Showing 6 changed files with 76 additions and 35 deletions.
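
The whole change hinges on one heuristic that recurs across the files below: GLM-4 checkpoints reuse the "chatglm" model type but, as the checks in this diff suggest, expose a rope_ratio field in their config and keep the standard batch-first KV-cache layout, whereas ChatGLM2/3 put the sequence length first. A minimal sketch of that dispatch, illustrative only and not part of the commit:

    def glm_pkv_shape(config, batch_size, seq_len):
        # GLM-4 configs carry `rope_ratio`; older ChatGLM2/3 configs do not.
        if hasattr(config, "rope_ratio"):
            # standard transformers layout: [batch, num_kv_heads, seq, head_dim]
            return (batch_size, config.multi_query_group_num, seq_len, config.kv_channels)
        # ChatGLM2/3 layout: [seq, batch, num_kv_heads, head_dim]
        return (seq_len, batch_size, config.multi_query_group_num, config.kv_channels)
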
49 changes: 32 additions & 17 deletions optimum/exporters/openvino/model_configs.py
@@ -167,24 +167,27 @@ def __init__(
)
self.multi_query_group_num = normalized_config.multi_query_group_num
self.head_dim = normalized_config.kv_channels
self.standart_cache_layout = hasattr(normalized_config, "rope_ratio")

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
past_value_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
if not self.standart_cache_layout:
pkv_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
else:
pkv_shape = (
self.batch_size,
self.multi_query_group_num,
self.sequence_length,
self.head_dim,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]
@@ -229,7 +232,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0]
seq_len_dim = 0 if not hasattr(self._normalized_config, "rope_ratio") else -2
past_present_length = (
dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[seq_len_dim]
)

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
@@ -260,9 +266,18 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length + present_lenght"
name = "present"

is_v4 = hasattr(self._normalized_config, "rope_ratio")
for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.key"] = (
{1: "batch_size", 0: decoder_sequence_name}
if not is_v4
else {0: "batch_size", 2: decoder_sequence_name}
)
inputs_or_outputs[f"{name}.{i}.value"] = (
{1: "batch_size", 0: decoder_sequence_name}
if not is_v4
else {0: "batch_size", 2: decoder_sequence_name}
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
27 changes: 21 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
@@ -190,7 +190,7 @@ def _chatglm_transformer_forward(
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if self.pre_seq_len is not None:
if getattr(self, "pre_seq_len", None) is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
@@ -285,6 +285,17 @@ def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer,
return context_layer


def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask.to(torch.float32)
)
context_layer = context_layer.transpose(1, 2).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer


class ChatGLMModelPatcher(DecoderModelPatcher):
def __init__(
self,
@@ -293,21 +304,25 @@ def __init__(
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

self.original_chatglm_transformer_forward = model.transformer.forward
self.is_v4 = hasattr(self._model.config, "rope_ratio")

def __enter__(self):
super().__enter__()
self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)

if not self.is_v4:
self._model.transformer._orig_forward = self._model.transformer.forward
self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
for block in self._model.transformer.encoder.layers:
block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
block.self_attention.core_attention.forward = types.MethodType(
_chatglm2_core_attention_forward, block.self_attention.core_attention
_chatglm2_core_attention_forward if not self.is_v4 else _glm4_core_attention_forward,
block.self_attention.core_attention,
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.transformer.forward = self.original_chatglm_transformer_forward
if hasattr(self._model.transformer, "_orig_forward"):
self._model.transformer.forward = self._model.transformer._orig_forward
for block in self._model.transformer.encoder.layers:
block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward
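
The patcher above follows the standard monkey-patching pattern used throughout these export patchers: keep the original bound method under _orig_forward, swap in the export-friendly implementation while tracing, and restore it on __exit__. A rough standalone illustration with a toy module (the names here are hypothetical, not from the commit):

    import types
    import torch

    class ToyAttention(torch.nn.Module):
        def forward(self, x):
            return x

    def patched_forward(self, x):
        # stand-in for _glm4_core_attention_forward / _chatglm2_core_attention_forward
        return x * 2

    module = ToyAttention()
    module._orig_forward = module.forward                       # keep the original to restore later
    module.forward = types.MethodType(patched_forward, module)  # patch for export
    print(module(torch.ones(1)))                                 # tensor([2.])
    module.forward = module._orig_forward                        # restore after export
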

2 changes: 1 addition & 1 deletion optimum/exporters/openvino/stateful.py
@@ -213,7 +213,7 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):

# By default, batch is the 0-th but chatglm uses 1-st dimension as batch
# TODO: Deduce from a model via ordinal reshape (?) and topology
batch_dim = 1 if config.model_type == "chatglm" else 0
batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0

fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1
26 changes: 17 additions & 9 deletions optimum/intel/openvino/modeling_decoder.py
@@ -328,9 +328,9 @@ def _reshape(
shapes[inputs][0] = -1
input_name = inputs.get_any_name()
if input_name.startswith("past_key_values"):
if (
len(inputs.partial_shape) == 3 and input_name.endswith("value")
) or self.config.model_type == "chatglm":
if (len(inputs.partial_shape) == 3 and input_name.endswith("value")) or (
self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio")
):
shapes[inputs][1] = -1
else:
shapes[inputs][2] = -1
@@ -421,7 +421,7 @@ def prepare_inputs(
model_inputs = self.model.input(input_name)
dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
shape = model_inputs.get_partial_shape()
if self.config.model_type == "chatglm":
if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
shape[0] = 0
shape[1] = batch_size
else:
@@ -571,9 +571,11 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
):
past_key_values = tuple(
tuple(
past_state[indicies]
if not self.config.model_type == "chatglm"
else past_state[:, indicies, ...]
(
past_state[indicies]
if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
else past_state[:, indicies, ...]
)
for past_state in layer_past
)
for layer_past in past_key_values
@@ -605,7 +607,13 @@ def _deduplicate_inputs(self, model_inputs: Dict):
upd_batch_size = indicies.shape[0]
if self.config.model_type == "bloom":
upd_batch_size *= self.config.num_attention_heads
shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
shape[
(
0
if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
else 1
)
] = upd_batch_size
upd_model_inputs[input_name] = Tensor(dtype, shape)
upd_model_inputs["input_ids"] = unique_input_ids
if "beam_idx" in model_inputs:
@@ -673,7 +681,7 @@ def _get_past_length(self, past_key_values=None):
):
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm":
if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
seq_length_dim = 0
elif self.config.model_type == "qwen":
seq_length_dim = 1
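
Several of the branches above (in _reshape, _expand_outputs_for_generation and _deduplicate_inputs) come down to indexing the cached key/value tensors along whichever axis holds the batch. A small illustration of the two indexing forms used in this file, with made-up shapes:

    import torch

    beam_idx = torch.tensor([0, 0, 1])

    # GLM-4 / default layout: batch is dim 0 -> index directly
    pkv_default = torch.randn(2, 4, 6, 8)       # [batch, heads, seq, head_dim]
    print(pkv_default[beam_idx].shape)          # torch.Size([3, 4, 6, 8])

    # ChatGLM2/3 layout: batch is dim 1 -> index the second axis
    pkv_chatglm = torch.randn(6, 2, 4, 8)       # [seq, batch, heads, head_dim]
    print(pkv_chatglm[:, beam_idx, ...].shape)  # torch.Size([6, 3, 4, 8])
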
6 changes: 4 additions & 2 deletions tests/openvino/test_modeling.py
@@ -643,6 +643,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"xverse",
"internlm",
"jais",
"glm4",
)

if is_transformers_version(">=", "4.40.0"):
Expand Down Expand Up @@ -675,6 +676,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"internlm",
"codegen2",
"arctic",
"glm4",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -716,7 +718,7 @@ def test_compare_to_transformers(self, model_arch):

set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
if model_arch in ["qwen", "arctic"]:
if model_arch in ["qwen", "arctic", "glm4"]:
transformers_model.to(torch.float32)

with torch.no_grad():
@@ -729,7 +731,7 @@ def test_compare_to_transformers(self, model_arch):
if model_arch == "qwen":
return

if model_arch not in ["chatglm", "persimmon"]:
if model_arch not in ["chatglm", "glm4", "persimmon"]:
tokenizer.pad_token_id = tokenizer.eos_token_id

if model_arch == "persimmon":
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -129,6 +129,7 @@
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
"xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM",
"xverse": "katuni4ka/tiny-random-xverse",
"glm4": "katuni4ka/tiny-random-glm4",
}
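
With the export config, patcher and tiny test checkpoint registered above, the new architecture can be exercised end to end much like the integration test does. A hedged usage sketch (the model id comes from this diff; the prompt and generation arguments are illustrative):

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "katuni4ka/tiny-random-glm4"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

    inputs = tokenizer("This is a sample input", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
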


