add support for granite and granitemoe models (#1099)
* add support for granite and granitemoe models

* add tests and docs

* add models to test cases
eaidova authored Jan 8, 2025
1 parent bb1c68a commit 7d7de7c
Showing 5 changed files with 103 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/openvino/models.mdx
@@ -58,6 +58,8 @@ Here is the list of the supported architectures :
- GPT-NeoX-Japanese
- Gemma
- Gemma2
- Granite
- GraniteMoE
- Hubert
- IBert
- InternLM
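Once an architecture appears in this list it can be exported and run through the standard optimum-intel API. A minimal sketch for Granite — the checkpoint id below is illustrative:

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "ibm-granite/granite-3.0-2b-instruct"  # illustrative checkpoint id

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```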
28 changes: 28 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -72,6 +72,7 @@
GptNeoModelPatcher,
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
GraniteMoEModelPatcher,
IBertModelPatcher,
InputEmbeddingPatcher,
InternLM2Patcher,
@@ -2554,3 +2555,30 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
)
class GLMOpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.46.0"


@register_in_tasks_manager(
"granite",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class GraniteOpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.45.0"


@register_in_tasks_manager(
"granitemoe", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
class GraniteMoEOpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.45.0"

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
return GraniteMoEModelPatcher(self, model, model_kwargs=model_kwargs)
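For GraniteMoE the config also swaps in GraniteMoEModelPatcher (defined in model_patcher.py below), so the MoE routing is rewritten into a trace-friendly form before conversion. A rough sketch of how these pieces fit together — the exact call sequence lives inside optimum's export pipeline, and the checkpoint id is illustrative:

```python
from transformers import AutoModelForCausalLM

from optimum.exporters.openvino.model_configs import GraniteMoEOpenVINOConfig

# illustrative granitemoe checkpoint
model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.0-1b-a400m-instruct")

export_config = GraniteMoEOpenVINOConfig(model.config, task="text-generation-with-past")
patcher = export_config.patch_model_for_export(model)

with patcher:
    # inside this context the router / expert forwards are the patched,
    # trace-friendly versions; optimum performs the actual model conversion here
    ...
```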
69 changes: 69 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -3581,3 +3581,72 @@ def __exit__(self, exc_type, exc_value, traceback):
for block in self._model.blocks:
block.forward = block._orig_forward
block.attn.forward = block.attn._orig_forward


# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
def _granite_moe_topk_gating_forward(self, hidden_states):
# compute the top_k routing decision
logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]

# compute number of input given to each expert
zeros = torch.zeros(
[top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
) # [num_tokens, num_experts]
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
expert_size = gates.long().sum(0) # [num_experts,]
    # difference from the original: expert_size = expert_size.tolist() is removed because it traces incorrectly

# sort and group input tokens according to expert assignment
top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
_, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]

# gather the gate values for grouped input tokens
top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]

return index_sorted_experts, batch_index, batch_gates, expert_size, logits
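
To make the routing shapes concrete, here is a small self-contained reproduction of the same computation; num_tokens, num_experts, top_k and the random logits are made up for illustration:

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 4, 3, 2
logits = torch.randn(num_tokens, num_experts)                    # router scores per token

top_k_logits, top_k_indices = logits.topk(top_k, dim=1)          # [num_tokens, top_k]
top_k_gates = torch.softmax(top_k_logits, dim=1)                 # [num_tokens, top_k]

# one-hot mask of the selected experts, summed over tokens
gates = torch.zeros(num_tokens, num_experts).scatter(1, top_k_indices, 1)
expert_size = gates.long().sum(0)                                # tokens routed to each expert

# every token contributes top_k assignments, so the counts sum to num_tokens * top_k
assert expert_size.sum().item() == num_tokens * top_k
```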


# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
output_list = []
    # differences from the original:
    # 1) after the gating patch expert_size is a tensor rather than a list of ints, so the original inputs.split(expert_size) cannot be used
    # 2) expert input slices are taken one by one via index_start:next_index instead of being precomputed once before the loop
index_start = torch.tensor(0, dtype=torch.int64)
for i in range(self.num_experts):
next_index = index_start + expert_size[i]
output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
index_start = next_index
results = torch.cat(output_list, dim=0)
return results
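
The following standalone check (shapes and weights are made up) shows that the running-index slicing gives the same result as the original inputs.split path, which needs a list of Python ints and is what broke tracing:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, in_features, out_features = 3, 4, 5
weight = torch.randn(num_experts, out_features, in_features)
expert_size = torch.tensor([2, 0, 3])                       # tokens per expert, kept as a tensor
inputs = torch.randn(int(expert_size.sum()), in_features)

# patched path: slice with a running tensor index, no .tolist() needed
index_start = torch.tensor(0, dtype=torch.int64)
outputs = []
for i in range(num_experts):
    next_index = index_start + expert_size[i]
    outputs.append(F.linear(inputs[index_start:next_index], weight[i]))
    index_start = next_index
patched = torch.cat(outputs, dim=0)

# original path: split requires Python ints
reference = torch.cat(
    [F.linear(chunk, weight[i]) for i, chunk in enumerate(inputs.split(expert_size.tolist()))],
    dim=0,
)
assert torch.allclose(patched, reference)
```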


class GraniteMoEModelPatcher(LlamaModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
block_sparse_moe = layer.block_sparse_moe
block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
block_sparse_moe.router.forward = types.MethodType(
_granite_moe_topk_gating_forward, block_sparse_moe.router
)
block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
block_sparse_moe.input_linear.forward = types.MethodType(
_granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
)
block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
block_sparse_moe.output_linear.forward = types.MethodType(
_granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
block_sparse_moe = layer.block_sparse_moe
block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
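
The patcher relies on a plain swap-and-restore monkey patch: types.MethodType binds the replacement function to the existing submodule so that self resolves to that module, and the stashed _orig_forward is put back on exit. A minimal standalone illustration of the same pattern (the toy module and replacement forward are made up):

```python
import types

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


def _patched_forward(self, x):
    # bound to the instance via types.MethodType, so `self` is the Toy module
    return x + 2


toy = Toy()
toy._orig_forward = toy.forward                        # stash the original bound method
toy.forward = types.MethodType(_patched_forward, toy)  # swap in the replacement
assert toy(torch.tensor(0.0)) == 2.0                   # patched behaviour

toy.forward = toy._orig_forward                        # restore, as done in __exit__
assert toy(torch.tensor(0.0)) == 1.0
```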
2 changes: 2 additions & 0 deletions tests/openvino/test_modeling.py
@@ -916,6 +916,8 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"mistral-nemo",
"minicpm3",
"glm",
"granite",
"granite-moe",
)

# gptq and awq install disabled for windows test environment
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -72,6 +72,8 @@
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "katuni4ka/tiny-random-granite",
"granite-moe": "katuni4ka/tiny-random-granite-moe",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-ibert",
"internlm": "katuni4ka/tiny-random-internlm",
