diff --git a/examples/quantization_aware_training/cifar10/basecase/main.py b/examples/quantization_aware_training/cifar10/basecase/main.py
index a57b659..63339e9 100644
--- a/examples/quantization_aware_training/cifar10/basecase/main.py
+++ b/examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math
 
 import torch
 import torch.nn as nn
@@ -21,6 +22,7 @@
 
 from model import resnet20
 from sparsebit.quantization import QuantModel, parse_qconfig
+from sparsebit.quantization.regularizers import build_regularizer
 
 parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -147,8 +149,6 @@ def main():
 
     qconfig = parse_qconfig(args.config)
 
-    is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-
     qmodel = QuantModel(model, qconfig).cuda()  # convert the model to a quantized model to support the QAT operations below
 
     # set head and tail of model to 8bit
@@ -181,6 +181,11 @@ def main():
         optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
     )
 
+    if qconfig.REGULARIZER.ENABLE:
+        regularizer = build_regularizer(qconfig)
+    else:
+        regularizer = None
+
     best_acc1 = 0
     for epoch in range(args.start_epoch, args.epochs):
         # train for one epoch
@@ -190,7 +195,7 @@
             criterion,
             optimizer,
             epoch,
-            is_pact,
+            regularizer,
             args.regularizer_lambda,
             args.print_freq,
         )
@@ -225,18 +230,9 @@
     )
 
 
-# PACT: add an L2 regularization on the clipping parameter alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p**2).sum()
-    return loss
-
-
-def get_regularizer_loss(model, is_pact, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
+def get_regularizer_loss(model, regularizer, _lambda):
+    if regularizer is not None:
+        return regularizer(model) * _lambda
     else:
         return torch.tensor(0.0).cuda()
 
@@ -247,7 +243,7 @@ def train(
     criterion,
     optimizer,
     epoch,
-    is_pact,
+    regularizer,
     regularizer_lambda,
     print_freq,
 ):
@@ -278,7 +274,7 @@
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
+        regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
         loss = ce_loss + regular_loss
 
         # measure accuracy and record loss
diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
new file mode 100644
index 0000000..93c6127
--- /dev/null
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
@@ -0,0 +1,14 @@
+BACKEND: virtual
+W:
+  QSCHEME: per-channel-symmetric
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+A:
+  QSCHEME: per-tensor-affine
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: dampen
diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
index a191746..72b370e 100644
--- a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
@@ -9,3 +9,6 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
diff --git a/sparsebit/quantization/quant_config.py b/sparsebit/quantization/quant_config.py
index b46532d..61204a8 100644
--- a/sparsebit/quantization/quant_config.py
+++ b/sparsebit/quantization/quant_config.py
@@ -47,6 +47,10 @@
 _C.A.QADD.ENABLE_QUANT = False
 _C.A.SPECIFIC = []
 
+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+
 
 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
diff --git a/sparsebit/quantization/regularizers/__init__.py b/sparsebit/quantization/regularizers/__init__.py
new file mode 100644
index 0000000..2fda594
--- /dev/null
+++ b/sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,16 @@
+REGULARIZERS_MAP = {}
+
+
+def register_regularizer(regularizer):
+    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
+    return regularizer
+
+
+from .base import Regularizer
+# imported for their side effect: each module registers its Regularizer
+from . import dampen, pact
+
+
+def build_regularizer(config):
+    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
+    return regularizer
diff --git a/sparsebit/quantization/regularizers/base.py b/sparsebit/quantization/regularizers/base.py
new file mode 100644
index 0000000..45e3cb9
--- /dev/null
+++ b/sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
+class Regularizer(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self, model):
+        raise NotImplementedError
diff --git a/sparsebit/quantization/regularizers/dampen.py b/sparsebit/quantization/regularizers/dampen.py
new file mode 100644
index 0000000..c121d0f
--- /dev/null
+++ b/sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,49 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+from sparsebit.quantization.quantizers.quant_tensor import fake_qrange_factory
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Dampen"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def _get_loss(self, x, quantizer):
+        # Dampen loss: pull the float weights toward the values they will
+        # be quantized to, measured inside the representable range only.
+        x_q = quantizer(x)
+
+        scale, zero_point = quantizer._qparams_preprocess(x)
+
+        # Float range covered by the quantizer; detached so the range
+        # itself receives no gradient from this loss.
+        min_val, max_val = fake_qrange_factory[quantizer.backend](
+            scale, zero_point, quantizer.qdesc
+        )
+        min_val = min_val.detach()
+        max_val = max_val.detach()
+
+        # Clamp weights into [min_val, max_val], then penalize the squared
+        # distance to their fake-quantized values.
+        x_c = torch.min(torch.max(x, min_val), max_val)
+        loss = ((x_q - x_c) ** 2).sum()
+        return loss
+
+    def __call__(self, model):
+        # Sum the dampen loss over every module that has an enabled
+        # weight quantizer.
+        loss = 0.0
+        for n, m in model.named_modules():
+            if (
+                hasattr(m, "weight")
+                and hasattr(m, "weight_quantizer")
+                and m.weight_quantizer
+                and m.weight_quantizer.is_enable
+            ):
+                loss += self._get_loss(m.weight, m.weight_quantizer)
+        return loss
diff --git a/sparsebit/quantization/regularizers/pact.py b/sparsebit/quantization/regularizers/pact.py
new file mode 100644
index 0000000..89b7cba
--- /dev/null
+++ b/sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,21 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Pact"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def __call__(self, model):
+        # PACT: L2-regularize the learnable clipping parameters (alpha)
+        loss = 0.0
+        for n, p in model.named_parameters():
+            if "alpha" in n:
+                loss += (p**2).sum()
+        return loss
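Note for reviewers: below is a minimal sketch (not part of this patch) of how a custom regularizer could plug into the registry this change introduces. The `WeightL2Regularizer` class and its `"l2"` TYPE string are hypothetical; only `register_regularizer`, `build_regularizer`, the `Regularizer` base class, and the `REGULARIZER` config keys come from the diff above.

```python
# Hypothetical extension example: a plain L2 weight penalty selected via
# the qconfig. Assumes sparsebit with this patch applied is importable.
from sparsebit.quantization.regularizers import (
    Regularizer as BaseRegularizer,
    register_regularizer,
)


@register_regularizer
class WeightL2Regularizer(BaseRegularizer):
    TYPE = "l2"  # stored lower-cased in REGULARIZERS_MAP

    def __call__(self, model):
        # Scalar penalty: sum of squared entries of every weight tensor.
        loss = 0.0
        for name, param in model.named_parameters():
            if "weight" in name:
                loss += (param**2).sum()
        return loss
```

With `REGULARIZER: {ENABLE: True, TYPE: l2}` in the YAML, `build_regularizer(qconfig)` returns an instance, and `train()` adds `regularizer(model) * args.regularizer_lambda` to the cross-entropy loss, mirroring the dampen and pact implementations.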