Commit 8bd7bbc

add regularizer
1 parent 126c30d

5 files changed with 39 additions and 21 deletions.

examples/quantization_aware_training/cifar10/basecase/main.py

Lines changed: 11 additions & 18 deletions
@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math
 
 import torch
 import torch.nn as nn
@@ -27,7 +28,7 @@
     raise NotImplementedError("This example should run on a GPU device.")  # make sure we run on a GPU
 
 
-config = "qconfig_lsq.yaml"  # QAT config file: quantization method (dorefa/lsq), weight and activation bit-widths, etc.
+config = "qconfig_lsq_dampen.yaml"  # QAT config file: quantization method (dorefa/lsq), weight and activation bit-widths, etc.
 workers = 4
 epochs = 200
 start_epoch = 0
@@ -38,8 +39,7 @@
 print_freq = 100
 pretrained = ""
 qconfig = parse_qconfig(config)
-is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-regularizer_lambda = 1e-4
+regularizer_schedule = "cosine" if qconfig.REGULARIZER.TYPE == "dampen" else "keep"
 
 model = resnet20(num_classes=10)  # use resnet20 as the base model
 if pretrained:  # optionally load model parameters saved in `pretrained`
@@ -109,21 +109,8 @@
     optimizer, milestones=[100, 150], last_epoch=start_epoch - 1
 )
 
-# the PACT algorithm adds L2-regularization on alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p ** 2).sum()
-    return loss
-
-def get_regularizer_loss(model, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
-    else:
-        return torch.tensor(0.).cuda()
 
-def train(train_loader, model, criterion, optimizer, epoch):
+def train(train_loader, model, criterion, optimizer, epoch, schedule_value=1.0):
     batch_time = AverageMeter("Time", ":6.3f")
     data_time = AverageMeter("Data", ":6.3f")
     losses = AverageMeter("Loss", ":.4e")
@@ -151,7 +138,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, scale=regularizer_lambda)
+        regular_loss = model.get_regularizer_loss() * schedule_value
         loss = ce_loss + regular_loss
 
         # measure accuracy and record loss
@@ -311,12 +298,18 @@ def accuracy(output, target, topk=(1,)):
 best_acc1 = 0
 for epoch in range(start_epoch, epochs):
     # train for one epoch
+    if regularizer_schedule == "cosine":
+        coeff = (epoch - start_epoch + 1) / (epochs - start_epoch)
+        schedule_value = 1 - 0.5 * (1.0 + math.cos(math.pi * coeff))
+    else:
+        schedule_value = 1.0
     train(
         trainloader,
         model,
         criterion,
         optimizer,
         epoch,
         schedule_value=schedule_value,
     )
 
     # evaluate on validation set
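For reference, the cosine schedule introduced above ramps the regularizer weight from near 0 at the first epoch to exactly 1 at the last one. A minimal standalone sketch of the values it produces (the dampen_schedule helper name is ours; the constants match the example):

import math

epochs, start_epoch = 200, 0

def dampen_schedule(epoch):
    # same formula as the diff: a half-cosine ramp from ~0 up to 1
    coeff = (epoch - start_epoch + 1) / (epochs - start_epoch)
    return 1 - 0.5 * (1.0 + math.cos(math.pi * coeff))

for e in (0, 49, 99, 149, 199):
    print(e, round(dampen_schedule(e), 3))
# prints: 0 0.0 | 49 0.146 | 99 0.5 | 149 0.854 | 199 1.0

So the dampen regularizer is effectively off early in training and reaches full strength only near the end, while other regularizer types keep a constant weight of 1.0.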

examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml

Lines changed: 4 additions & 0 deletions
@@ -9,3 +9,7 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
+  LAMBDA: 0.0001
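With this block in place, the regularizer settings come back through parse_qconfig like any other config field. A quick sketch, assuming parse_qconfig is exported from sparsebit.quantization as the example script's usage suggests, and run from the basecase directory:

from sparsebit.quantization import parse_qconfig  # assumed export path

qconfig = parse_qconfig("qconfig_pact.yaml")
print(qconfig.REGULARIZER.ENABLE)  # True
print(qconfig.REGULARIZER.TYPE)    # pact
print(qconfig.REGULARIZER.LAMBDA)  # 0.0001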

sparsebit/quantization/quant_config.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,11 @@
 _C.A.OBSERVER.LAYOUT = "NCHW"  # NCHW / NLC
 _C.A.SPECIFIC = []
 
+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+_C.REGULARIZER.LAMBDA = 0.0
+
 
 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
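Because the new keys default to ENABLE: False, qconfig files written before this commit parse unchanged and simply build no regularizer. A minimal sketch of that defaulting behavior, assuming CN is yacs's CfgNode (which the usage here matches):

from yacs.config import CfgNode as CN

_C = CN()
_C.REGULARIZER = CN()
_C.REGULARIZER.ENABLE = False
_C.REGULARIZER.TYPE = ""
_C.REGULARIZER.LAMBDA = 0.0

cfg = _C.clone()               # stands in for a config that never mentions REGULARIZER
print(cfg.REGULARIZER.ENABLE)  # False -> QuantModel will set self.regularizer = None

# overrides merge in the usual yacs way
cfg.merge_from_list(["REGULARIZER.ENABLE", True, "REGULARIZER.TYPE", "dampen"])
print(cfg.REGULARIZER.TYPE)    # dampen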

sparsebit/quantization/quant_model.py

Lines changed: 18 additions & 2 deletions
@@ -20,6 +20,7 @@
 from sparsebit.quantization.quantizers import Quantizer
 from sparsebit.quantization.tools import QuantizationErrorProfiler
 from sparsebit.quantization.converters import simplify, fuse_operations
+from sparsebit.quantization.regularizers import build_regularizer
 
 
 __all__ = ["QuantModel"]
@@ -34,6 +35,7 @@ def __init__(self, model: nn.Module, config):
         self._run_simplifiers()
         self._convert2quantmodule()
         self._build_quantizer()
+        self._build_regularizer()
         self._run_fuse_operations()
 
     def _convert2quantmodule(self):
@@ -119,11 +121,17 @@ def _sub_build(src, module_name):
             update_config(_config, "A", _sub_build(self.cfg.A, node.target))
             identity_module.build_quantizer(_config)
 
+    def _build_regularizer(self):
+        if self.cfg.REGULARIZER.ENABLE:
+            self.regularizer = build_regularizer(self.cfg)
+        else:
+            self.regularizer = None
+
     def _run_simplifiers(self):
         self.model = simplify(self.model)
 
     def _run_fuse_operations(self):
-        if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
+        if self.cfg.SCHEDULE.BN_TUNING:  # first disable fuse bn
             update_config(self.cfg.SCHEDULE, "FUSE_BN", False)
         self.model = fuse_operations(self.model, self.cfg.SCHEDULE)
         self.model.graph.print_tabular()
@@ -144,7 +152,9 @@ def batchnorm_tuning(self):
         yield
         self.model.eval()
         update_config(self.cfg.SCHEDULE, "FUSE_BN", True)
-        self.model = fuse_operations(self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"])
+        self.model = fuse_operations(
+            self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"]
+        )
         self.set_quant(w_quant=False, a_quant=False)
 
     def prepare_calibration(self):
@@ -210,6 +220,12 @@ def set_quant(self, w_quant=False, a_quant=False):
             if isinstance(m, QuantOpr):
                 m.set_quant(w_quant, a_quant)
 
+    def get_regularizer_loss(self):
+        if self.regularizer is None:
+            return torch.tensor(0.).to(self.device)
+        else:
+            return self.regularizer(self.model)
+
     def export_onnx(
         self,
         dummy_data,
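Taken together, the training loop no longer needs to know which regularizer (if any) is active: it just adds model.get_regularizer_loss() to the task loss, as the main.py diff above does. A hedged sketch of the call pattern (import paths are assumed; the tiny stand-in network is illustrative, where the example itself uses resnet20):

import torch
import torch.nn as nn

from sparsebit.quantization import QuantModel, parse_qconfig  # assumed exports

qconfig = parse_qconfig("qconfig_pact.yaml")  # REGULARIZER.ENABLE: True, TYPE: pact

# tiny stand-in network producing 10-class logits
net = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
qmodel = QuantModel(net, qconfig)

images = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))

ce_loss = nn.CrossEntropyLoss()(qmodel(images), target)
reg_loss = qmodel.get_regularizer_loss()  # zero tensor when no regularizer is built
(ce_loss + reg_loss).backward()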

sparsebit/quantization/regularizers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ def register_regularizer(regularizer):
 
 
 from .base import Regularizer
-from . import dampen
+from . import dampen, pact
 
 
 def build_regularizer(config):
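The new pact module itself is not shown in this diff, but its job is exactly what the helper deleted from main.py did: an L2 penalty on the PACT clipping parameters (alpha), now scaled by REGULARIZER.LAMBDA instead of a hard-coded regularizer_lambda. A sketch of what it plausibly contains, assuming the registry follows the decorator-plus-TYPE pattern sparsebit uses elsewhere (class name, TYPE attribute, and base-class signature are all guesses):

from sparsebit.quantization.regularizers import Regularizer, register_regularizer


@register_regularizer
class PactRegularizer(Regularizer):  # name and base signature are assumptions
    TYPE = "pact"  # registry key matching the config's REGULARIZER.TYPE

    def __init__(self, config):
        super().__init__(config)
        self._lambda = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        # same computation as the removed get_pact_regularizer_loss:
        # sum of squared PACT clipping parameters, located by name
        loss = 0.0
        for name, param in model.named_parameters():
            if "alpha" in name:
                loss = loss + (param ** 2).sum()
        return loss * self._lambda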
