17
17
from tempfile import TemporaryDirectory
18
18
19
19
from parameterized import parameterized
20
- from transformers import AutoModelForCausalLM
20
+ from transformers import AutoModelForCausalLM , AutoTokenizer
21
21
from utils_tests import (
22
22
_ARCHITECTURES_TO_EXPECTED_INT8 ,
23
23
MODEL_NAMES ,
@@ -253,10 +253,12 @@ def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expec
253
253
254
254
def test_exporters_cli_int4_with_local_model_and_default_config (self ):
255
255
with TemporaryDirectory () as tmpdir :
256
- pt_model = AutoModelForCausalLM .from_pretrained (MODEL_NAMES ["bloom" ])
257
- # overload for matching with default configuration
258
- pt_model .config ._name_or_path = "bigscience/bloomz-7b1"
256
+ model_id = "bigscience/bloomz-560m"
257
+ tokenizer = AutoTokenizer .from_pretrained (model_id )
258
+ pt_model = AutoModelForCausalLM .from_pretrained (model_id )
259
+ tokenizer .save_pretrained (tmpdir )
259
260
pt_model .save_pretrained (tmpdir )
261
+
260
262
subprocess .run (
261
263
f"optimum-cli export openvino --model { tmpdir } --task text-generation-with-past --weight-format int4 { tmpdir } " ,
262
264
shell = True ,
@@ -267,16 +269,23 @@ def test_exporters_cli_int4_with_local_model_and_default_config(self):
267
269
rt_info = model .model .get_rt_info ()
268
270
self .assertTrue ("nncf" in rt_info )
269
271
self .assertTrue ("weight_compression" in rt_info ["nncf" ])
270
- default_config = _DEFAULT_4BIT_CONFIGS ["bigscience/bloomz-7b1" ]
271
272
model_weight_compression_config = rt_info ["nncf" ]["weight_compression" ]
272
- sym = default_config .pop ("sym" , False )
273
+
274
+ default_config = _DEFAULT_4BIT_CONFIGS [model_id ]
273
275
bits = default_config .pop ("bits" , None )
274
276
self .assertEqual (bits , 4 )
275
277
276
- mode = f'int{ bits } _{ "sym" if sym else "asym" } '
277
- default_config ["mode" ] = mode
278
+ sym = default_config .pop ("sym" , False )
279
+ default_config ["mode" ] = f'int{ bits } _{ "sym" if sym else "asym" } '
280
+
281
+ quant_method = default_config .pop ("quant_method" , None )
282
+ default_config ["awq" ] = quant_method == "awq"
283
+ default_config ["gptq" ] = quant_method == "gptq"
284
+
285
+ default_config .pop ("dataset" , None )
286
+
278
287
for key , value in default_config .items ():
279
- self .assertTrue (key in model_weight_compression_config )
288
+ self .assertIn (key , model_weight_compression_config )
280
289
self .assertEqual (
281
290
model_weight_compression_config [key ].value ,
282
291
str (value ),
0 commit comments