
Commit

Added Quantization modules for dynamic and static quantization. Additionally, added abstraction for quantizing input data
LovePelmeni committed Apr 1, 2024
1 parent 450f092 commit 4f6dc72
Showing 3 changed files with 117 additions and 39 deletions.
Empty file added src/quantization/base.py
16 changes: 16 additions & 0 deletions src/quantization/quan_utils.py
@@ -0,0 +1,16 @@
import torch

def get_observer_by_name(observer_name: str):
    """Resolves an observer class from its short name."""
    if observer_name.lower() == "percentile":
        return torch.ao.quantization.observer.PercentileObserver

    if observer_name.lower() == "histogram":
        return torch.ao.quantization.observer.HistogramObserver

    if observer_name.lower() == "min_max":
        return torch.ao.quantization.observer.MinMaxObserver

    if observer_name.lower() == "moving_min_max":
        return torch.ao.quantization.observer.MovingAverageMinMaxObserver

    raise NotImplementedError("unsupported observer: '%s'" % observer_name)
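
For reference, a minimal usage sketch of the helper above (illustrative, not part of the commit); it assumes the requested observer class exists in the installed torch build and that the repository root is on the Python path:

# Hypothetical usage of get_observer_by_name; "min_max" resolves to MinMaxObserver.
import torch
from src.quantization.quan_utils import get_observer_by_name

observer_cls = get_observer_by_name("min_max")
observer = observer_cls(dtype=torch.quint8)       # instantiate with the target dtype
observer(torch.randn(4, 3, 32, 32))               # record running min/max from a sample batch
scale, zero_point = observer.calculate_qparams()  # derive quantization parameters
print(scale, zero_point)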
140 changes: 101 additions & 39 deletions src/quantization/quantizers.py
@@ -5,8 +5,14 @@
from torch.utils import data
import torch
import typing
import logging
from src.quantization import base

quan_logger = logging.getLogger(__name__)
file_handler = logging.FileHandler(filename='quantization_logs.log')
quan_logger.addHandler(file_handler)


class InputTensorQuantizer(base.BaseInputQuantizer):
"""
Inference quantizer. Application string
container string branch application string validator.
@@ -18,49 +24,98 @@ class InputQuantizer(object):
    q_type - quantization dtype to use (e.g. torch.qint8)
    pretrained_observer - observer used to compute the quantization scale and zero point
    """
    def __init__(self,
                 q_type: torch.dtype,
                 pretrained_observer: torch.ao.quantization.observer.ObserverBase
                 ):
        self.quantization_type = q_type
        self.observer = pretrained_observer

    def compute_statistics(self, input_image: torch.Tensor):
        # record the tensor's statistics, then derive scale / zero point
        self.observer(input_image)
        scale, zero_point = self.observer.calculate_qparams()
        return scale, zero_point

    def __call__(self, input_img: torch.Tensor):
        scale, zero_point = self.compute_statistics(input_image=input_img)
        return torch.quantize_per_tensor(
            input_img,
            float(scale),
            int(zero_point),
            self.quantization_type,
        )
class InputChannelQuantizer(base.BaseInputQuantizer):
    """
    Module for quantizing input
    images per channel.

    Parameters:
    -----------
    quan_type - dtype to use for input quantization, typically torch.qint8 or torch.quint8.
    input_observer - pretrained observer for computing per-channel quantization statistics.
    """
    def __init__(self,
                 quan_type: torch.dtype,
                 input_observer: torch.ao.quantization.observer.ObserverBase
                 ):
        self.quan_type = quan_type
        self.input_observer = input_observer

    def compute_statistics(self, input_image: torch.Tensor):
        # feed the channel through the observer, then derive its qparams
        self.input_observer(input_image)
        return self.input_observer.calculate_qparams()

    def quantize(self, input_image: torch.Tensor):
        scales = []
        zero_points = []
        # channels are assumed to be the last dimension (H, W, C layout)
        for ch in range(input_image.shape[-1]):
            ch_scale, ch_zero_point = self.compute_statistics(
                input_image[:, :, ch])
            scales.append(float(ch_scale))
            zero_points.append(int(ch_zero_point))

        return torch.quantize_per_channel(
            input_image,
            scales=torch.tensor(scales),
            zero_points=torch.tensor(zero_points),
            axis=2,
            dtype=self.quan_type,
        )


class StaticNetworkQuantizer(object):
    """
    Base module for performing static quantization
    of the network.
    """
    def __init__(self, q_activation_type, q_weight_type):
        self.q_activation_type = q_activation_type
        self.q_weight_type = q_weight_type
        self.calibrator = NetworkCalibrator()

    def quantize(self,
                 input_model: nn.Module,
                 calibration_dataset: data.Dataset,
                 calib_batch_size: int
                 ):
        try:
            calibration_loader = self.calibrator.configure_calibration_loader(
                calibration_dataset=calibration_dataset,
                calibration_batch_size=calib_batch_size,
                loader_workers=2
            )
            # perform calibration
            stat_network = self.calibrator.calibrate(
                input_model,
                loader=calibration_loader,
                weight_q_type=self.q_weight_type,
                activation_q_type=self.q_activation_type
            )
            quantized_model = torch.quantization.convert(stat_network)
            return quantized_model

        except Exception as err:
            quan_logger.error(err)
            return None

class NetworkCalibrator(object):

@@ -73,7 +128,7 @@ class NetworkCalibrator(object):
    def configure_calibration_loader(self,
                                     calibration_dataset: data.Dataset,
                                     calibration_batch_size: int,
                                     loader_workers: int = 0
                                     ):
"""
Configures data loader for
@@ -89,8 +144,11 @@ def configure_calibration_loader(self,
    def calibrate(self,
                  network: nn.Module,
                  loader: data.DataLoader,
                  activation_q_type: torch.dtype,
                  weight_q_type: torch.dtype,
                  # observers default to MinMaxObserver, matching the previous configuration
                  weight_observer: torch.ao.quantization.observer.ObserverBase = torch.ao.quantization.observer.MinMaxObserver,
                  activation_observer: torch.ao.quantization.observer.ObserverBase = torch.ao.quantization.observer.MinMaxObserver,
                  ) -> typing.Union[nn.Module, None]:
"""
Calibrates given network
for finding optimal quantization
@@ -102,17 +160,21 @@ def calibrate(self,
        the training set.
        """
        network.eval()
        try:
            # Specify the quantization configuration
            qconfig = torch.ao.quantization.QConfig(
                activation=activation_observer.with_args(dtype=activation_q_type),
                weight=weight_observer.with_args(dtype=weight_q_type)
            )
            # Apply the quantization configuration to the model
            network.qconfig = qconfig
            stat_network = torch.ao.quantization.prepare(network)

            # performing calibration
            for images, _ in loader:
                stat_network.forward(images).cpu()

            return stat_network

        except Exception as err:
            quan_logger.debug(err)
            return None
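
For context, a rough end-to-end sketch (illustrative, not part of this commit) of how the quantizers above might be driven. The torchvision model and FakeData dataset are hypothetical stand-ins for a real network and calibration set, and the dtypes shown are just common choices:

# Illustrative sketch only; torchvision pieces are assumptions, not part of the commit.
import torch
from torchvision import datasets, models, transforms

from src.quantization.quan_utils import get_observer_by_name
from src.quantization.quantizers import InputTensorQuantizer, StaticNetworkQuantizer

# 1) Per-tensor input quantization with a pretrained observer.
observer = get_observer_by_name("min_max")(dtype=torch.quint8)
input_quantizer = InputTensorQuantizer(q_type=torch.quint8, pretrained_observer=observer)
q_input = input_quantizer(torch.rand(1, 3, 224, 224))   # returns a quantized tensor

# 2) Static post-training quantization with calibration.
quantizer = StaticNetworkQuantizer(
    q_activation_type=torch.quint8,   # activations commonly use quint8
    q_weight_type=torch.qint8,        # weights commonly use qint8
)
quantized_model = quantizer.quantize(
    input_model=models.resnet18(weights=None),                                  # hypothetical network
    calibration_dataset=datasets.FakeData(size=64, transform=transforms.ToTensor()),
    calib_batch_size=8,
)
print(quantized_model)   # None is returned if calibration or conversion failed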
