
Commit 6d22f96

Merge branch 'main' into ak/compression_options

2 parents 320e94e + ae36dda

File tree

4 files changed: +52 -36 lines changed

.github/workflows/test_openvino.yml

Lines changed: 6 additions & 0 deletions

@@ -36,3 +36,9 @@ jobs:
       - name: Test with Pytest
         run: |
           pytest tests/openvino/ --ignore test_modeling_basic
+      - name: Test openvino-nightly import
+        run: |
+          pip uninstall -y openvino
+          pip install openvino-nightly
+          python -c "from optimum.intel import OVModelForCausalLM; OVModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-gpt2', export=True, compile=False)"
+
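The new step is a smoke test for nightly OpenVINO builds: it swaps the stable openvino wheel for openvino-nightly and checks that a model can still be imported and exported. The same check as a standalone script (a minimal sketch, assuming optimum-intel and openvino-nightly are installed):

    from optimum.intel import OVModelForCausalLM

    # export=True converts the checkpoint to OpenVINO IR on the fly;
    # compile=False skips device compilation, so only the import and
    # export paths are exercised.
    model = OVModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2", export=True, compile=False
    )
    print(type(model).__name__)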

optimum/exporters/openvino/__main__.py

Lines changed: 41 additions & 1 deletion

@@ -18,7 +18,7 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import __main__ as optimum_main
@@ -140,6 +140,41 @@ def main_export(
     original_task = task
     task = TasksManager.map_from_synonym(task)
 
+    # Patch the modules to export GPTQ models w/o GPU
+    do_gptq_patching = False
+    try:
+        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
+        config_dict = config.to_dict()
+        quantization_config = config_dict.get("quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+    except Exception:
+        pass
+
+    if do_gptq_patching:
+        import torch
+
+        torch.set_default_dtype(torch.float32)
+        orig_cuda_check = torch.cuda.is_available
+        torch.cuda.is_available = lambda: True
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
 
     # get the shapes to be used to generate dummy inputs
@@ -326,3 +361,8 @@ def main_export(
         compression_ratio=compression_ratio,
         model_kwargs=model_kwargs,
     )
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model
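This hunk moves the GPTQ workaround out of modeling_decoder.py (see the next file) into main_export, so every export path gets it. The mechanism is temporary monkey-patching: torch.cuda.is_available is overridden so auto-gptq's CUDA checks pass on CPU-only machines, and GPTQQuantizer.post_init_model is replaced with a CPU-safe version that only applies the exllama max-input-length fixup when needed; both originals are restored after export. A reduced sketch of the pattern (the try/finally is an illustrative hardening; the commit itself restores the originals unconditionally after main_export returns):

    import torch

    # Keep references to the originals so they can be restored.
    orig_cuda_check = torch.cuda.is_available
    torch.cuda.is_available = lambda: True  # pretend a GPU is present

    try:
        ...  # run the GPTQ-aware export here
    finally:
        # Undo the patch even if the export raises.
        torch.cuda.is_available = orig_cuda_check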

optimum/intel/openvino/modeling_decoder.py

Lines changed: 1 addition & 34 deletions

@@ -229,34 +229,6 @@ def _from_transformers(
         if use_cache:
             task = task + "-with-past"
 
-        # Patch the modules to export GPTQ models w/o GPU
-        do_gptq_patching = False
-        config_dict = config.to_dict()
-        quantization_config = config_dict.get("quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
-            orig_cuda_check = torch.cuda.is_available
-            torch.cuda.is_available = lambda: True
-
-            from optimum.gptq import GPTQQuantizer
-
-            orig_post_init_model = GPTQQuantizer.post_init_model
-
-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
-
-                class StoreAttr(object):
-                    pass
-
-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
-
-            GPTQQuantizer.post_init_model = post_init_model
-
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -271,11 +243,6 @@ class StoreAttr(object):
             compression_option="i8" if load_in_8bit else None,
         )
 
-        # Unpatch modules after GPTQ export
-        if do_gptq_patching:
-            torch.cuda.is_available = orig_cuda_check
-            GPTQQuantizer.post_init_model = orig_post_init_model
-
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
@@ -504,7 +471,7 @@ def _from_pretrained(
         elif model_type == "gpt-bigcode":
            init_cls = OVGPTBigCodeForCausalLM
         else:
-            init_cls = OVModelForCausalLM
+            init_cls = cls
 
         return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
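Besides dropping the GPTQ patching that now lives in main_export, the last hunk changes the _from_pretrained fallback from the hard-coded OVModelForCausalLM to cls, so loading through a subclass preserves that subclass. A hypothetical illustration (the subclass is invented for the example and not part of the commit):

    from optimum.intel import OVModelForCausalLM

    class MyOVModel(OVModelForCausalLM):
        pass  # illustrative subclass, not part of the commit

    model = MyOVModel.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2", export=True, compile=False
    )
    # With init_cls = cls, the subclass is returned for model types that
    # have no special-cased class (unlike, e.g., gpt-bigcode); previously
    # this branch always fell back to the base OVModelForCausalLM.
    assert isinstance(model, MyOVModel)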

optimum/intel/utils/import_utils.py

Lines changed: 4 additions & 1 deletion

@@ -71,7 +71,10 @@
 try:
     _openvino_version = importlib_metadata.version("openvino")
 except importlib_metadata.PackageNotFoundError:
-    _openvino_available = False
+    try:
+        _openvino_version = importlib_metadata.version("openvino-nightly")
+    except importlib_metadata.PackageNotFoundError:
+        _openvino_available = False
 
 
 _nncf_available = importlib.util.find_spec("nncf") is not None
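The availability probe now falls back to the openvino-nightly distribution, which ships the same openvino module under a different package name, so the _openvino_available flag stays correct with nightly wheels. The pattern in isolation (a minimal sketch, assuming Python 3.8+ where importlib.metadata is in the standard library; the real file gets importlib_metadata from its own imports):

    from importlib import metadata as importlib_metadata

    _openvino_available = True
    try:
        _openvino_version = importlib_metadata.version("openvino")
    except importlib_metadata.PackageNotFoundError:
        try:
            # Nightly wheels register a different distribution name.
            _openvino_version = importlib_metadata.version("openvino-nightly")
        except importlib_metadata.PackageNotFoundError:
            _openvino_available = False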
