
Commit

Apply comments
l-bat committed Mar 11, 2024
1 parent 067c6d5 commit 2dc4087
Showing 4 changed files with 24 additions and 19 deletions.
6 changes: 3 additions & 3 deletions docs/source/optimization_ov.mdx
@@ -71,10 +71,10 @@ model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)

## Hybrid quantization

-Traditional optimization methods like post-training 8-bit quantization do not work for Stable Diffusion models because accuracy drops significantly. On the other hand, weight compression does not improve performance when applied to Stable Diffusion models, as the size of activations is comparable to weights.
+Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of activations is comparable to weights.
The UNet model takes up most of the overall execution time of the pipeline. Thus, optimizing just one model brings substantial benefits in terms of inference speed while keeping acceptable accuracy without fine-tuning. Quantizing the rest of the diffusion pipeline does not significantly improve inference performance but could potentially lead to substantial degradation of accuracy.
-Therefore, the proposal is to apply quantization in hybrid mode for the UNet model and weight-only quantization for other pipeline components. The hybrid mode involves the quantization of weights in MatMul and Embedding layers, and activations of other layers, facilitating accuracy preservation post-optimization while reducing the model size.
-For optimizing the Stable Diffusion pipeline, utilize the `quantization_config` to define optimization parameters. To enable hybrid quantization, specify the quantization dataset in the `quantization_config`; otherwise, weight-only quantization in specified precisions will be applied to UNet.
+Therefore, the proposal is to apply quantization in *hybrid mode* for the UNet model and weight-only quantization for the rest of the pipeline components. The hybrid mode involves the quantization of weights in MatMul and Embedding layers, and activations of other layers, facilitating accuracy preservation post-optimization while reducing the model size.
+Use the `quantization_config` to define optimization parameters for the Stable Diffusion pipeline. To enable hybrid quantization, specify the quantization dataset in the `quantization_config`. Otherwise, weight-only quantization at the specified precision (8 or 4 bits) is applied to the UNet model.

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig
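
# How the truncated example likely continues (a hedged sketch: the model ID
# and exact arguments below are illustrative, not part of the original doc).
model_id = "runwayml/stable-diffusion-v1-5"
quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=200)
pipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True, quantization_config=quantization_config)
```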
6 changes: 3 additions & 3 deletions optimum/intel/openvino/configuration.py
@@ -179,8 +179,8 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin):
            using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
        dataset (`str or List[str]`, *optional*):
            The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
-            in a list of string or just use the the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLLMs
-            or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for SD models.
+            in a list of strings or just use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs
+            or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
        ratio (`float`, defaults to 1.0):
            The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM
            and the rest to INT8_ASYM).
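As a quick illustration of how `dataset` and `ratio` combine, a minimal sketch (the parameter values are chosen for illustration only, not taken from the commit):

```python
from optimum.intel import OVWeightQuantizationConfig

# 4-bit data-aware compression for an LLM: roughly 90% of eligible layers
# are quantized to the 4-bit baseline precision, the remaining 10% fall
# back to INT8_ASYM, and "wikitext2" supplies the calibration samples.
config = OVWeightQuantizationConfig(bits=4, ratio=0.9, dataset="wikitext2")
```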
@@ -243,7 +243,7 @@ def post_init(self):
            if self.dataset not in llm_datasets + stable_diffusion_datasets:
                raise ValueError(
                    f"""You have entered a string value for dataset. You can only choose between
-                    {llm_datasets} for LLLMs or {stable_diffusion_datasets} for SD models, but we found {self.dataset}"""
+                    {llm_datasets} for LLMs or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
                )

        if self.bits not in [4, 8]:
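These checks fail fast on bad input; a small sketch of the expected behavior, assuming `post_init` runs at construction time (the dataset name and bit width below are deliberately invalid):

```python
from optimum.intel import OVWeightQuantizationConfig

for bad_kwargs in (
    {"bits": 8, "dataset": "imagenet"},  # unknown dataset string -> ValueError
    {"bits": 2},                         # only 4 and 8 bits are supported -> ValueError
):
    try:
        OVWeightQuantizationConfig(**bad_kwargs)
    except ValueError as err:
        print(err)
```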
26 changes: 15 additions & 11 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -282,16 +282,17 @@ def _from_pretrained(

        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

-        dataset = None
        unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
        if quantization_config is not None and quantization_config.dataset is not None:
-            dataset = quantization_config.dataset
            # load the UNet model uncompressed to apply hybrid quantization further
            unet = cls.load_model(unet_path)
            # Apply weights compression to other `components` without dataset
-            quantization_config.dataset = None
+            q_config_params = quantization_config.__dict__
+            wc_params = {param: value for param, value in q_config_params.items() if param != "dataset"}
+            wc_quantization_config = OVWeightQuantizationConfig.from_dict(wc_params)
        else:
-            unet = cls.load_model(unet_path, quantization_config)
+            wc_quantization_config = quantization_config
+            unet = cls.load_model(unet_path, wc_quantization_config)

        components = {
            "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
@@ -301,12 +302,12 @@
        }

        for key, value in components.items():
-            components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
+            components[key] = cls.load_model(value, wc_quantization_config) if value.is_file() else None

        if model_save_dir is None:
            model_save_dir = new_model_save_dir

-        if dataset is not None:
+        if quantization_config is not None and quantization_config.dataset is not None:
            sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

            supported_pipelines = (
@@ -318,12 +319,11 @@
                raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")

            nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
-            unet_inputs = sd_model._prepare_unet_inputs(dataset, nsamples)
+            unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples)

            from .quantization import _hybrid_quantization

-            unet = _hybrid_quantization(sd_model.unet.model, quantization_config, dataset=unet_inputs)
-            quantization_config.dataset = dataset
+            unet = _hybrid_quantization(sd_model.unet.model, wc_quantization_config, dataset=unet_inputs)

        return cls(
            unet=unet,
@@ -338,13 +338,17 @@
    def _prepare_unet_inputs(
        self,
        dataset: Union[str, List[Any]],
        num_samples: int,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
        seed: Optional[int] = 42,
        **kwargs,
    ) -> Dict[str, Any]:
        self.compile()

+        size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor
+        height = height or min(size, 512)
+        width = width or min(size, 512)
+
        if isinstance(dataset, str):
            dataset = deepcopy(dataset)
        available_datasets = PREDEFINED_SD_DATASETS.keys()
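The three added lines above derive the default calibration resolution from the UNet configuration instead of hard-coding 512. A minimal sketch of that arithmetic, with illustrative (assumed) values for `sample_size` and `vae_scale_factor`:

```python
# Illustrative helper mirroring the added logic; the values below are
# assumptions for the sketch, not taken from the commit.
def default_resolution(sample_size: int = 64, vae_scale_factor: int = 8) -> int:
    size = sample_size * vae_scale_factor
    return min(size, 512)  # never calibrate above 512 pixels by default

assert default_resolution(64, 8) == 512   # typical SD 1.x UNet -> 512
assert default_resolution(32, 8) == 256   # smaller UNet keeps its native 256
assert default_resolution(128, 8) == 512  # SDXL-sized UNet is clamped to 512
```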
5 changes: 3 additions & 2 deletions tests/openvino/test_quantization.py
@@ -159,7 +159,7 @@ class OVWeightCompressionTest(unittest.TestCase):
    )

    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 62, 86),)
-    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 150),)
+    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 148),)
    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = (
        (OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 14, 50),
    )
@@ -236,6 +236,7 @@ class OVWeightCompressionTest(unittest.TestCase):

    SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = (
        (OVStableDiffusionPipeline, "stable-diffusion", 72, 195),
+        (OVStableDiffusionXLPipeline, "stable-diffusion-xl", 84, 331),
        (OVLatentConsistencyModelPipeline, "latent-consistency", 50, 135),
    )

@@ -372,7 +373,7 @@ def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_f

            model.save_pretrained(tmp_dir)

-    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:])
    def test_ovmodel_hybrid_quantization_with_custom_dataset(
        self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8
    ):
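To reproduce these cases locally, something along these lines should work (the test path is taken from the diff; the `-k` filter is an assumption about pytest usage, not part of the commit): `pytest tests/openvino/test_quantization.py -k hybrid_quantization`.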
