
[OpenVINO] support ai-sage/GigaChat3-10B-A1.8B-bf16#1626

Open
Mohamed-Ashraf273 wants to merge 24 commits into huggingface:main from Mohamed-Ashraf273:support_gigachat3

Conversation


@Mohamed-Ashraf273 Mohamed-Ashraf273 commented Feb 28, 2026

What does this PR do?

Conversion cmd-line for ai-sage/GigaChat3-10B-A1.8B-bf16:

optimum-cli export openvino -m ai-sage/GigaChat3-10B-A1.8B-bf16 ./output_dir --task text-generation-with-past

Inference of ai-sage/GigaChat3-10B-A1.8B-bf16 using OpenVINO backend:

from transformers import AutoTokenizer
from optimum.intel.openvino import OVModelForCausalLM

model_dir = "output_dir"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = OVModelForCausalLM.from_pretrained(model_dir)

# Prepare input
prompt = "What is the capital of France?"
inputs = tokenizer(prompt, return_tensors="pt")

# Run inference
output_ids = model.generate(**inputs, max_new_tokens=10)
output_text = tokenizer.decode(output_ids[0])

print(output_text)

Solving Issue: #1608

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@Mohamed-Ashraf273 Mohamed-Ashraf273 changed the title from "modify patcher" to "[OpenVINO] support gigachat3" on Feb 28, 2026
@Mohamed-Ashraf273 Mohamed-Ashraf273 marked this pull request as ready for review March 3, 2026 13:19
@Mohamed-Ashraf273
Contributor Author

Hi @popovaan ,
Can you take a look?
Thanks!

@Mohamed-Ashraf273 Mohamed-Ashraf273 changed the title from "[OpenVINO] support gigachat3" to "[OpenVINO] support ai-sage/GigaChat3-10B-A1.8B-bf16" on Mar 3, 2026
@popovaan
Collaborator

popovaan commented Mar 3, 2026

Thanks for the PR! Please add tests for this model. For now, use a locally generated tiny model. I'm currently investigating whether we're allowed to invite GSoC contributors to the optimum-intel-internal-testing group so that you can publish the model there. If not, I’ll publish it myself and share the link.

@Mohamed-Ashraf273
Contributor Author

Thanks for the PR! Please add tests for this model. For now, use a locally generated tiny model. I'm currently investigating whether we're allowed to invite GSoC contributors to the optimum-intel-internal-testing group so that you can publish the model there. If not, I’ll publish it myself and share the link.

Got it, thanks!
I’ll add the tests with a locally generated tiny model

@Mohamed-Ashraf273
Contributor Author

Thanks for the PR! Please add tests for this model. For now, use a locally generated tiny model. I'm currently investigating whether we're allowed to invite GSoC contributors to the optimum-intel-internal-testing group so that you can publish the model there. If not, I’ll publish it myself and share the link.

Hi @popovaan, @rkazants,
I've added a tiny model along with the tests. Could you please take a look?
Thanks!

Collaborator

@rkazants rkazants left a comment

Please also add export tests, the same test set that you added for the previous model.
Also update the documentation.

Contributor

Copilot AI left a comment

Pull request overview

This PR aims to add OpenVINO export/inference support coverage for the ai-sage/GigaChat3-10B-A1.8B-bf16 family by extending OpenVINO test fixtures and adjusting DeepSeek patching logic used during export.

Changes:

  • Add a gigachat3 tiny-random model fixture and include it in OpenVINO decoder integration coverage.
  • Update decoder tests for gigachat3 (expected SDPA count, relaxed logits tolerance, and skip conditions for incompatible Transformers versions).
  • Refactor DeepSeek attention patching to use a versioned factory function and extend MoE patching to handle MLP blocks exposing experts but not moe_infer.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
tests/openvino/utils_tests.py Adds the gigachat3 test model mapping; adjusts which models are treated as remote-code in tests.
tests/openvino/test_decoder.py Adds gigachat3 to tested architectures and config expectations; tweaks tolerance/skip logic; adds debug output.
optimum/exporters/openvino/model_patcher.py Updates DeepSeek patcher to use a unified attention forward factory and broadens MoE patching behavior.


@Mohamed-Ashraf273
Contributor Author

Mohamed-Ashraf273 commented Mar 4, 2026

please also add export tests. The same test set that you have added for the previous model. Update documentation.

@rkazants
Thanks for your feedback!
I've added export tests and updated documentation for the newly added model!

@Mohamed-Ashraf273
Contributor Author

Mohamed-Ashraf273 commented Mar 4, 2026

Thanks for the PR! Please add tests for this model. For now, use a locally generated tiny model. I'm currently investigating whether we're allowed to invite GSoC contributors to the optimum-intel-internal-testing group so that you can publish the model there. If not, I’ll publish it myself and share the link.

Hi @popovaan,

I’ve finished adding the tests and temporarily published tiny-random-gigachat3 on my Hugging Face profile (mohamedashraf273/tiny-random-gigachat3) until it can be moved to optimum-intel-internal-testing.

Would it be possible to invite me to the group so I can publish it there, or would you prefer to handle the publishing?

Please let me know if any changes are needed.
Thanks!

@savvadesogle

Hi. Can I help test the model?
You are so great, thank you so much!♥️🔥😊

@Mohamed-Ashraf273
Contributor Author

Hi. Can I help test the model?
You are so great, thank you so much!♥️🔥😊

Hi!
That would be great, thank you so much!
Please feel free to test it and let me know if you encounter any issues or unexpected behavior.

@savvadesogle

savvadesogle commented Mar 6, 2026

Thank you, @Mohamed-Ashraf273 ❤️

Working with CPU

[screenshots: CPU inference running successfully]

Not working with GPU yet...

It loads endlessly on the GPU, but it works on the CPU in the OpenArc tool (OVGenAI engine).
As in the screenshot below, RAM gradually grows; after 10-20 minutes nothing happens.

[screenshot: RAM usage growing while loading the model to GPU]

@Mohamed-Ashraf273
Contributor Author

Mohamed-Ashraf273 commented Mar 6, 2026

Hi @savvadesogle,

I ran a demo test with a tiny GigaChat3 model on GPU and it worked correctly. I was able to successfully:

  • Export the model (on CPU)
  • Load/compile it on GPU
  • Run a forward pass
  • Run generate()
  • Run batched generation

All steps completed without issues and the GPU execution finished successfully.

From your description, it sounds like the GPU loading/compilation for the full model may simply require more time and RAM. The tiny model finishes quickly, but the real GigaChat3 model is significantly larger, so it would be expected that:

  • GPU loading/compilation takes longer than 20–30 minutes, and
  • RAM usage may keep increasing during the process before it stabilizes.

Since the same pipeline works correctly with the tiny model on GPU, the real model should also work, but it may just need more time and memory for the GPU compilation step.

For reference, here is the script I used for the GPU test:

import torch
from transformers import AutoTokenizer
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov

# ── 0. Check available devices ────────────────────────────────────────────────
core = ov.Core()
print("Available devices:", core.available_devices)
assert "GPU" in " ".join(core.available_devices), "No Intel GPU found!"

MODEL_DIR = "./tiny-random-gigachat3"

# ── 1. Export (CPU export, then load on GPU) ──────────────────────────────────
print("\n[1] Exporting tiny-random-gigachat3 to OpenVINO...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
ov_model = OVModelForCausalLM.from_pretrained(
    MODEL_DIR,
    export=True,
    device="GPU",        # compile directly on GPU after export
)
print("    Export + GPU compile: OK")

# ── 2. Basic forward pass ─────────────────────────────────────────────────────
print("\n[2] Running forward pass on GPU...")
prompt = "What is the capital of France?"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = ov_model(**inputs)
logits = outputs.logits
print(f"    Logits shape : {logits.shape}")
print(f"    Logits dtype : {logits.dtype}")
print(f"    Logits sample: {logits[0, -1, :5].tolist()}")
assert logits.shape[0] == 1, "Batch size mismatch"
print("    Forward pass : OK")

# ── 3. Generation ─────────────────────────────────────────────────────────────
print("\n[3] Running generate() on GPU...")
ov_model.generation_config.eos_token_id = None   # avoid early stop on tiny model
output_ids = ov_model.generate(**inputs, max_new_tokens=10, do_sample=False)
decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"    Output : {decoded!r}")
assert output_ids.shape[1] > inputs["input_ids"].shape[1], "No tokens generated"
print("    Generate     : OK")

# ── 4. Batch generation ───────────────────────────────────────────────────────
print("\n[4] Running batched generate() on GPU...")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
prompts = ["Hello world", "The sky is"]
batch = tokenizer(prompts, return_tensors="pt", padding=True)
output_ids = ov_model.generate(**batch, max_new_tokens=5, do_sample=False)
for i, ids in enumerate(output_ids):
    print(f"    Batch[{i}]: {tokenizer.decode(ids, skip_special_tokens=True)!r}")
print("    Batched generate: OK")

print("\n✅ All GPU tests passed!")

Output:

(env) mohamed-ashraf@mohamed-ashraf-LOQ-15IRX9:~/Desktop/projects/GSoC26/optimum-intel$ python test_gpu.py 2>&1 | grep -v TracerWarning | grep -v "site-packages" | grep -v "^$"
Multiple distributions found for package optimum. Picked distribution: optimum-intel
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
  or not self.key_cache[layer_idx].numel()  # the layer has no cache
  if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
  torch.tensor(0.0, device=mask.device, dtype=dtype),
  torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
  not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
  is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
Available devices: ['CPU', 'GPU.0', 'GPU.1']
[1] Exporting tiny-random-gigachat3 to OpenVINO...
    Export + GPU compile: OK
[2] Running forward pass on GPU...
    Logits shape : torch.Size([1, 31, 32000])
    Logits dtype : torch.float32
    Logits sample: [-0.0056610107421875, -0.0082855224609375, -0.04266357421875, 0.06475830078125, -0.002960205078125]
    Forward pass : OK
[3] Running generate() on GPU...
    Output : 'What is the capital of France?qualurtheremon{}) chiqish\\])\\ир flat المو'
    Generate     : OK
[4] Running batched generate() on GPU...
    Batch[0]: 'Hello world_en дlinux aylandi[tahr'
    Batch[1]: 'The sky isшт г may extentimg'
    Batched generate: OK
✅ All GPU tests passed!

@savvadesogle

savvadesogle commented Mar 6, 2026

  • GPU loading/compilation takes longer than 20–30 minutes, and

I didn't expect the process to take so long; I'll have to wait and see. The conversion itself happens very quickly, up to 3 minutes to a regular int4 without any additional parameters. I'll definitely give it a try. I have 128 GB of RAM, so that should be enough.

Other models load much faster on the GPU. I'll try waiting longer.
Thank you

@Mohamed-Ashraf273
Contributor Author

Hi @popovaan, @rkazants, @IlyasMoutawwakil

I’ve fixed the remaining issues. Could you please take a look when you have time?
Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Mohamed-Ashraf273
Contributor Author

Hi @popovaan, @rkazants, @IlyasMoutawwakil,

All tests are now passing. I’d really appreciate it if you could take a final look.

Thanks!

Comment on lines +733 to +744
if orig_transformers_version is not None:
    import json as _json
    from pathlib import Path as _Path

    gen_cfg_path = _Path(output) / "generation_config.json"
    if gen_cfg_path.exists():
        with open(gen_cfg_path, "r", encoding="utf-8") as _f:
            _cfg = _json.load(_f)
        if _cfg.get("transformers_version") != orig_transformers_version:
            _cfg["transformers_version"] = orig_transformers_version
            with open(gen_cfg_path, "w", encoding="utf-8") as _f:
                _json.dump(_cfg, _f, indent=2)
Collaborator

Can we avoid this change? This modifies the common code for all models, which is undesirable.

Contributor Author

Thanks for your feedback!

I reverted the changes and set gen_config.do_sample = False specifically for deepseek in test_decoder.py.

Could you please take another look and let me know if anything else should be adjusted?

Thanks!

@Mohamed-Ashraf273
Contributor Author

Hi @popovaan, @rkazants, @IlyasMoutawwakil,

I’d really appreciate it if you could take a look.

Thanks!

expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
expert_mask = expert_mask.permute(2, 0, 1)

for expert_idx in range(len(self.experts)):
Member

that's kinda inefficient, especially during decoding. do we replace this with some optimized MoE operator in openvino later ?

Contributor Author

Correct, this is inefficient for decoding. The current implementation intentionally runs all experts to avoid the data-dependent control flow in the original MoE (skipping experts with no tokens), which breaks torch.jit.trace required for OpenVINO export. So this change mainly serves as a temporary tracing workaround to produce a static, exportable graph.

Yes, the plan is to replace this with a custom OpenVINO MoE operator (similar to convert_recurrent_attention_cell used in Qwen3Next) so we can restore sparse execution and avoid loading all expert weights during decoding. This PR just unblocks model export for now, and the optimized operator is intended as a follow-up improvement.
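To illustrate the workaround described above, here is a minimal sketch of the trace-friendly dense-execution pattern (function and argument names are hypothetical, not the actual patcher code): every expert runs on every token, and a one-hot routing mask zeroes out the outputs of unselected experts, so the graph contains no data-dependent branching.

```python
import torch

def dense_moe_forward(hidden, experts, topk_weights, topk_indices):
    """Trace-friendly MoE forward: runs all experts unconditionally and
    masks out unselected outputs, avoiding the data-dependent control
    flow (skipping empty experts) that breaks torch.jit.trace."""
    num_experts = len(experts)
    # (tokens, k) -> one-hot (tokens, k, E) -> (E, tokens, k)
    expert_mask = torch.nn.functional.one_hot(
        topk_indices, num_classes=num_experts
    ).permute(2, 0, 1)
    out = torch.zeros_like(hidden)
    for expert_idx in range(num_experts):
        expert_out = experts[expert_idx](hidden)  # runs for every token
        # per-token weight for this expert; zero when the token was not routed here
        w = (expert_mask[expert_idx].to(hidden.dtype) * topk_weights).sum(-1, keepdim=True)
        out = out + expert_out * w
    return out
```

The result matches sparse routing numerically, at the cost of executing every expert on every token.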

Member

Then maybe at least minimize and standardize the experts' forward. In some models we use the batching trick to make experts exportable with a minimal graph layout (the loop results in very long graphs), see http://github.com/huggingface/optimum-intel/blob/439b6319368c1667f3119ef508812ef167b0fef5/optimum/exporters/openvino/model_patcher.py#L7562 @rkazants

Contributor Author

Thanks for the suggestion!
I refactored the DeepSeek MoE patch to follow the same batching pattern used in the AFMoE implementation. The expert projections are now pre-stacked in the patcher, and the forward pass uses vectorized bmm operations instead of looping over experts, which helps keep the exported graph compact.
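Roughly, the batching pattern looks like the following sketch (names and the SwiGLU expert layout are illustrative assumptions, not the exact patched code): the gate/up/down projections of all experts are pre-stacked into single tensors, so a few `bmm` calls replace the per-expert Python loop.

```python
import torch

def batched_moe_forward(hidden, w_gate, w_up, w_down, topk_weights, topk_indices):
    """Vectorized MoE forward with pre-stacked expert weights:
    w_gate/w_up: (E, d_model, d_ff), w_down: (E, d_ff, d_model).
    All experts run in three bmm calls, keeping the exported graph compact."""
    num_experts = w_gate.shape[0]
    # broadcast every token to every expert: (E, tokens, d_model)
    x = hidden.unsqueeze(0).expand(num_experts, -1, -1)
    gate = torch.nn.functional.silu(torch.bmm(x, w_gate))
    up = torch.bmm(x, w_up)
    expert_out = torch.bmm(gate * up, w_down)  # (E, tokens, d_model)
    # routing weight per (expert, token); zero where the expert was not selected
    mask = torch.nn.functional.one_hot(
        topk_indices, num_classes=num_experts
    ).permute(2, 0, 1)  # (E, tokens, k)
    w = (mask.to(hidden.dtype) * topk_weights.unsqueeze(0)).sum(-1)  # (E, tokens)
    return (expert_out * w.unsqueeze(-1)).sum(0)  # (tokens, d_model)
```

Compared with the loop, this trades some redundant compute for a short, static graph, which is the property that matters for export.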

Member

@IlyasMoutawwakil IlyasMoutawwakil left a comment

lgtm! approved to merge if all tests are passing, however we still don't solve the real perf issue in exported MoEs which is having an efficient implementation of either:

  • torch.grouped_mm operator
  • entire MoE operator

@Mohamed-Ashraf273
Contributor Author

Hi @popovaan , @rkazants
Could you please take a look?
Please let me know if anything else is needed.

@popovaan
Collaborator

popovaan commented Mar 17, 2026

Could you please locally run OpenVINO GenAI WhoWhatBenchmark tool to check the accuracy of the full model (not the tiny one) and share the results?
https://github.com/openvinotoolkit/openvino.genai/tree/master/tools/who_what_benchmark

Here is the instruction: https://github.com/openvinotoolkit/openvino.genai/blob/master/tools/who_what_benchmark/README.md#compare-text-generation-models-llms
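For reference, a typical two-step invocation might look like the sketch below; the `wwb` entry point and flag names are taken from the linked README and should be double-checked against the current version before running.

```shell
# Step 1 (assumed flags, per the linked README): generate ground-truth
# answers from the original PyTorch model
wwb --base-model ai-sage/GigaChat3-10B-A1.8B-bf16 --gt-data gt.csv --model-type text --hf

# Step 2: score the exported OpenVINO model against the ground truth
wwb --target-model ./output_dir --gt-data gt.csv --model-type text
```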

@savvadesogle

the full model

Hello Anastasiia @popovaan

Sorry, I won't be able to run the full model (BF16 converted to full-precision OpenVINO) on the GPU.

[screenshot: available GPU memory]

I only have 16 GB of memory, and I'm afraid it won't all fit on a single GPU. And there are known issues with HETERO (openvinotoolkit/openvino#33012 (comment)) when using two or more GPUs.
I can try to run it on CPU.

Or is it enough to test converted models in int8 and int4?

@Mohamed-Ashraf273
Contributor Author

Could you please locally run OpenVINO GenAI WhoWhatBenchmark tool to check the accuracy of the full model (not the tiny one) and share the results? https://github.com/openvinotoolkit/openvino.genai/tree/master/tools/who_what_benchmark

Here is the instruction: https://github.com/openvinotoolkit/openvino.genai/blob/master/tools/who_what_benchmark/README.md#compare-text-generation-models-llms

Thanks for the suggestion! I'm running the WhoWhatBenchmark tool locally now to check the full model's accuracy. Will share results once it's done.
