Add support for granite and granitemoe models #1099

Merged (3 commits) on Jan 8, 2025
2 changes: 2 additions & 0 deletions docs/source/openvino/models.mdx
@@ -58,6 +58,8 @@ Here is the list of the supported architectures :
- GPT-NeoX-Japanese
- Gemma
- Gemma2
- Granite
- GraniteMoE
- Hubert
- IBert
- InternLM
28 changes: 28 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -72,6 +72,7 @@
    GptNeoModelPatcher,
    GptNeoxJapaneseModelPatcher,
    GptNeoxModelPatcher,
    GraniteMoEModelPatcher,
    IBertModelPatcher,
    InputEmbeddingPatcher,
    InternLM2Patcher,
@@ -2554,3 +2555,30 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
)
class GLMOpenVINOConfig(LlamaOpenVINOConfig):
    MIN_TRANSFORMERS_VERSION = "4.46.0"


@register_in_tasks_manager(
    "granite",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text-classification",
    ],
    library_name="transformers",
)
class GraniteOpenVINOConfig(LlamaOpenVINOConfig):
    MIN_TRANSFORMERS_VERSION = "4.45.0"


@register_in_tasks_manager(
    "granitemoe", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
class GraniteMoEOpenVINOConfig(LlamaOpenVINOConfig):
    MIN_TRANSFORMERS_VERSION = "4.45.0"

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        return GraniteMoEModelPatcher(self, model, model_kwargs=model_kwargs)
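To see how these registrations are exercised end to end, a minimal sketch (assuming optimum-intel is installed with OpenVINO support; the checkpoint is the tiny test model this PR adds to the test registry below):

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "katuni4ka/tiny-random-granite"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the model through the "granite" entry registered above
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]))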
69 changes: 69 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -3601,3 +3601,72 @@ def __exit__(self, exc_type, exc_value, traceback):
        for block in self._model.blocks:
            block.forward = block._orig_forward
            block.attn.forward = block.attn._orig_forward


# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
def _granite_moe_topk_gating_forward(self, hidden_states):
    # compute the top_k routing decision
    logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
    top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
    top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]

    # compute the number of inputs given to each expert
    zeros = torch.zeros(
        [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
    )  # [num_tokens, num_experts]
    gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
    expert_size = gates.long().sum(0)  # [num_experts,]
    # difference from the original: expert_size = expert_size.tolist() is removed,
    # because converting to a Python list traces incorrectly

    # sort and group input tokens according to expert assignment
    top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
    _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
    batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]

    # gather the gate values for grouped input tokens
    top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
    batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]

    return index_sorted_experts, batch_index, batch_gates, expert_size, logits
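Keeping expert_size as a tensor matters because tracing records tensor ops symbolically, whereas a Python list bakes the token counts of the example input into the exported graph. A standalone toy illustration of the routing math above (hypothetical shapes: 4 tokens, 3 experts, top_k=2):

import torch

logits = torch.randn(4, 3)                        # [num_tokens, num_experts]
top_k_logits, top_k_indices = logits.topk(2, dim=1)
top_k_gates = torch.softmax(top_k_logits, dim=1)  # routing weights per token

# one-hot expert assignment, then per-expert token counts
gates = torch.zeros(4, 3).scatter(1, top_k_indices, 1)
expert_size = gates.long().sum(0)  # stays a tensor, so the op is traced
print(expert_size)                 # instead of frozen into the graph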


# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
    output_list = []
    # differences from the original:
    # 1) after the gating patch, expert_size is a tensor rather than a list of ints,
    #    so the original inputs.split(expert_size) can no longer be used
    # 2) expert input slices are taken one by one via index_start:next_index instead
    #    of being precomputed in a single split before the loop
Review comment on lines +3633 to +3638 (Collaborator): super helpful, thanks!

    index_start = torch.tensor(0, dtype=torch.int64)
    for i in range(self.num_experts):
        next_index = index_start + expert_size[i]
        output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
        index_start = next_index
    results = torch.cat(output_list, dim=0)
    return results
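On eager tensors the two splitting strategies are equivalent; the slice-based loop just avoids materializing Python ints. A small check on toy tensors (hypothetical sizes):

import torch

inputs = torch.randn(6, 8)
expert_size = torch.tensor([2, 1, 3])

# original approach: needs Python ints, freezing the sizes at trace time
splits = inputs.split(expert_size.tolist())

# patched approach: tensor-index slicing keeps the boundaries symbolic
chunks = []
index_start = torch.tensor(0, dtype=torch.int64)
for size in expert_size:
    next_index = index_start + size
    chunks.append(inputs[index_start:next_index])
    index_start = next_index

assert all(torch.equal(a, b) for a, b in zip(splits, chunks))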


class GraniteMoEModelPatcher(LlamaModelPatcher):
    def __enter__(self):
        super().__enter__()
        for layer in self._model.model.layers:
            block_sparse_moe = layer.block_sparse_moe
            block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
            block_sparse_moe.router.forward = types.MethodType(
                _granite_moe_topk_gating_forward, block_sparse_moe.router
            )
            block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
            block_sparse_moe.input_linear.forward = types.MethodType(
                _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
            )
            block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
            block_sparse_moe.output_linear.forward = types.MethodType(
                _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
            )

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        for layer in self._model.model.layers:
            block_sparse_moe = layer.block_sparse_moe
            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
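To make the patching lifecycle concrete, a hedged standalone sketch; the positional arguments mirror the patch_model_for_export hook above, but constructing the export config directly is an assumption about optimum internals rather than a documented entry point:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("katuni4ka/tiny-random-granite-moe")
# assumed constructor signature, for illustration only
export_config = GraniteMoEOpenVINOConfig(model.config, task="text-generation-with-past")

with GraniteMoEModelPatcher(export_config, model, model_kwargs={}):
    # inside the context, router and expert forwards are the trace-friendly
    # versions above; _orig_forward keeps a handle to each original
    router = model.model.layers[0].block_sparse_moe.router
    assert hasattr(router, "_orig_forward")
# __exit__ restores every original forward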
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -72,6 +72,8 @@
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "katuni4ka/tiny-random-granite",
"granite-moe": "katuni4ka/tiny-random-granite-moe",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-ibert",
"internlm": "katuni4ka/tiny-random-internlm",