Commit: Add tests

mvafin committed Dec 5, 2024
1 parent 64f64b0 commit 86d9328
Showing 3 changed files with 62 additions and 8 deletions.
7 changes: 5 additions & 2 deletions optimum/exporters/openvino/convert.py
@@ -447,8 +447,11 @@ def ts_patched_forward(*args, **kwargs):
if patch_16bit_model:
    from openvino.frontend.pytorch.patch_model import unpatch_model

    unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
    model.to(torch.float32)
    unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
    for m in model.modules():
        if (any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters())
                or any(b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers())):
            m.float()

return export_pytorch_via_onnx(
    model,
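For reference, the new loop above casts back to fp32 only those submodules that still hold fp16/bf16 parameters or buffers, instead of calling model.to(torch.float32) on the whole model. A minimal standalone sketch of the same idea (the helper name upcast_half_modules and the toy model are illustrative, not part of this commit):

import torch
import torch.nn as nn


def upcast_half_modules(model: nn.Module) -> None:
    half = (torch.float16, torch.bfloat16)
    for m in model.modules():
        # nn.Module.float() only converts floating-point tensors, so packed
        # integer buffers of quantized layers are left untouched.
        if any(p.dtype in half for p in m.parameters()) or any(b.dtype in half for b in m.buffers()):
            m.float()


# Toy usage: one half-precision layer next to an fp32 one.
toy = nn.Sequential(nn.Linear(4, 4).half(), nn.Linear(4, 4))
upcast_half_modules(toy)
assert all(p.dtype == torch.float32 for p in toy.parameters())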
60 changes: 55 additions & 5 deletions tests/openvino/test_modeling.py
@@ -872,13 +872,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"gpt_neo",
"gpt_neox",
"llama",
# "llama_gptq",
"marian",
"minicpm",
"mistral",
"mixtral",
"mixtral_awq",
"mpt",
"opt",
"opt_gptq",
"pegasus",
"qwen",
"phi",
@@ -949,9 +950,6 @@ def test_compare_to_transformers(self, model_arch):
if is_openvino_version("<", "2024.1"):
    not_stateful.extend(["llama", "gemma", "gpt_bigcode"])

if "gptq" in model_arch:
    self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")

set_seed(SEED)

model_kwargs = {}
@@ -978,6 +976,46 @@ def test_compare_to_transformers(self, model_arch):
if is_stateful:
    self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

if "awq" in model_arch or "gptq" in model_arch:
orig_cuda_is_available = torch.cuda.is_available
torch.cuda.is_available = lambda: True
# infer in FP32
model_kwargs["torch_dtype"] = torch.float32

if "awq" in model_arch:
# patch GEMM module to allow inference without CUDA GPU
from awq.modules.linear.gemm import WQLinearMMFunction
from awq.utils.packing_utils import dequantize_gemm

def new_forward(
ctx,
x,
qweight,
qzeros,
scales,
w_bit=4,
group_size=128,
bias=None,
out_features=0,
):
ctx.out_features = out_features

out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16)

out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
out = torch.matmul(x, out)

out = out + bias if bias is not None else out
out = out.reshape(out_shape)

if len(out.shape) == 2:
out = out.unsqueeze(0)
return out

orig_gemm_forward = WQLinearMMFunction.forward
WQLinearMMFunction.forward = new_forward

set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
if model_arch in ["qwen", "arctic", "glm4"]:
@@ -988,10 +1026,14 @@ def test_compare_to_transformers(self, model_arch):

# Compare tensor outputs
atol = 1e-3 if model_arch == "minicpm" else 1e-4
# quantized models have higher tolerance
if "awq" in model_arch:
atol = 1e-2
elif "gptq" in model_arch:
atol = 0.6
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))

# Qwen tokenizer does not support padding

if model_arch in ["qwen"]:
return

@@ -1026,11 +1068,19 @@ def test_compare_to_transformers(self, model_arch):

additional_inputs = {"past_key_values": DynamicCache()}
transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
print(f"ov_outputs: {ov_outputs}")
print(f"transformers_outputs: {transformers_outputs}")
self.assertTrue(
    torch.allclose(ov_outputs, transformers_outputs),
    f"OV output {ov_outputs}\nTransformers output {transformers_outputs}",
)

if "awq" in model_arch:
    WQLinearMMFunction.forward = orig_gemm_forward

if "awq" in model_arch or "gptq" in model_arch:
    torch.cuda.is_available = orig_cuda_is_available

del transformers_model
del ov_model
gc.collect()
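One note on the save/patch/restore pattern in the test above: torch.cuda.is_available and WQLinearMMFunction.forward are restored by plain assignments near the end of the test, so a failing assertion in between would leave the patches in place for later tests. A context manager keeps the restore unconditional; this is only a sketch of the idea, not code from this commit:

import contextlib

import torch


@contextlib.contextmanager
def pretend_cuda_available():
    # Report CUDA as available so quantized checkpoints load on a CPU-only
    # machine, and restore the original function even if the body raises.
    orig = torch.cuda.is_available
    torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        torch.cuda.is_available = orig


# Usage inside a test:
# with pretend_cuda_available():
#     transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)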
3 changes: 2 additions & 1 deletion tests/openvino/utils_tests.py
@@ -77,12 +77,12 @@
"longt5": "hf-internal-testing/tiny-random-longt5",
"llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
"llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
"llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"llava": "katuni4ka/tiny-random-llava",
"llava_next": "katuni4ka/tiny-random-llava-next",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"opt125m": "facebook/opt-125m",
"opt_gptq": "katuni4ka/opt-125m-gptq",
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"minicpm": "katuni4ka/tiny-random-minicpm",
@@ -91,6 +91,7 @@
"mistral": "echarlaix/tiny-random-mistral",
"mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
"mixtral": "TitanML/tiny-mixtral",
"mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
