Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QAT : TRT 8 compatible workflow #804

Draft
wants to merge 33 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e66f210
made qat layer scriptable
SrivastavaKshitij Aug 24, 2022
38342b2
WIP - refactor qat library
SrivastavaKshitij Aug 25, 2022
9922d1c
fixed mapping function
SrivastavaKshitij Aug 25, 2022
44d0750
WIP
SrivastavaKshitij Sep 2, 2022
73e523b
WIP
SrivastavaKshitij Sep 7, 2022
5b1d6b9
added converter for quantconv
SrivastavaKshitij Sep 8, 2022
f4b52ee
working quant conv converter
SrivastavaKshitij Sep 9, 2022
39244d7
changing workflow
SrivastavaKshitij Sep 9, 2022
39b8f9f
removed redundant data
SrivastavaKshitij Sep 10, 2022
fa6712c
working maxpool layer
SrivastavaKshitij Sep 11, 2022
3d260ac
added working quantmaxpool2d laayer
SrivastavaKshitij Sep 11, 2022
7419d70
added support for AdaptiveAvgPool2d
SrivastavaKshitij Sep 12, 2022
c9c3f0f
fixed trt and non trt mode
SrivastavaKshitij Sep 12, 2022
4687282
added converter for quant adaptive avgpool 2d
SrivastavaKshitij Sep 12, 2022
c4e13ef
fixed adaptiveavgpool2d converter import
SrivastavaKshitij Sep 12, 2022
37f89fb
fixed pytorch fake quant ops
SrivastavaKshitij Sep 12, 2022
6e63c4f
added generic converter
SrivastavaKshitij Sep 12, 2022
179a9d9
fixed nn.conv2d
SrivastavaKshitij Sep 13, 2022
f31f67c
removed graphviz import
SrivastavaKshitij Sep 13, 2022
60cb7e6
fixed documentation
SrivastavaKshitij Sep 14, 2022
c0521f1
fixed build script
SrivastavaKshitij Sep 15, 2022
3f91fae
fixed parameter type for torch.fake_quantize per channel
SrivastavaKshitij Sep 15, 2022
1068c64
Fixed accuracy metrics
SrivastavaKshitij Sep 15, 2022
fe5a74d
added new patch file
SrivastavaKshitij Sep 15, 2022
5e3e65b
fixed import warning
SrivastavaKshitij Sep 19, 2022
0ef2df0
fixed precision of zero point
SrivastavaKshitij Sep 22, 2022
85cbc66
fixed logging issue for conv2d while scripting
SrivastavaKshitij Sep 23, 2022
5f3de32
stripped off tensor quantizer depedency, hopefully
SrivastavaKshitij Sep 23, 2022
e0742c1
added min reproducible file
SrivastavaKshitij Sep 23, 2022
3a3c4bd
fixed converters
SrivastavaKshitij Sep 24, 2022
e0c7e62
fixed loading
SrivastavaKshitij Sep 26, 2022
f994bad
fixed quant axis initial value
SrivastavaKshitij Sep 27, 2022
8317ffa
WIP
SrivastavaKshitij Oct 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ To install torch2trt with experimental community contributed features under ``to
```bash
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt/scripts
bash build_contrib.sh
bash build_qat.sh
```

This enables you to run the QAT example located [here](examples/contrib/quantization_aware_training).
Expand Down
44 changes: 9 additions & 35 deletions examples/contrib/quantization_aware_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,6 @@ This example is using QAT library open sourced by nvidia. [Github link](https://

2. Usually, nvidia quantization library doesn't provide control per layer for quantization. Custom wrapper under `utils/utilities.py` helps us in quantization selective layers in our model.

## Environment

**Filename** : pytorch_ngc_container_20.09

```
FROM nvcr.io/nvidia/pytorch:20.09-py3
RUN apt-get update && apt-get install -y software-properties-common && apt-get update
RUN add-apt-repository ppa:git-core/ppa && \
apt install -y git

RUN pip install termcolor graphviz

RUN git clone https://github.com/NVIDIA-AI-IOT/torch2trt.git /sw/torch2trt/ && \
cd /sw/torch2trt/scripts && \
bash build_contrib.sh

```

Docker build: `docker build -f pytorch_ngc_container_20.09 -t pytorch_ngc_container_20.09 .`

`docker_image=pytorch_ngc_container_20.09`

Docker run : `docker run -e NVIDIA_VISIBLE_DEVICES=0 --gpus 0 -it --shm-size=1g --ulimit memlock=-1 --rm -v $PWD:/workspace/work $docker_image`

**Important Notes** :

- Sparse checkout helps us in checking out a part of the github repo.
- Patch file can be found under `examples/quantization_aware_training/utils`

## Workflow

Expand All @@ -49,27 +21,29 @@ Workflow consists of three parts.

Here pretrained weights from imagenet are used.

`python train.py --m resnet34-tl / resnet18-tl --num_epochs 45 --test_trt --FP16 --INT8PTC`
`python train.py --m <model> --pretrain --num_epochs <num_epochs> --test_trt --FP16 --INT8PTC`

2. Train with quantization (weights are mapped using a custom function to make sure that each weight is loaded correctly)

`python train.py --m resnet34/ resnet18 --netqat --partial_ckpt --tl --load_ckpt /tmp/pytorch_exp/{} --num_epochs 25 --lr 1e-4 --lrdt 10`
`python train.py --m <model> --quantize --partial_ckpt --load_ckpt /tmp/pytorch_exp/{} --num_epochs <num_epochs> --lr 1e-4 --lrdt 10`

3. Infer with and without TRT

`python infer.py --m resnet34/resnet18 --load_ckpt /tmp/pytorch_exp_1/ckpt_{} --netqat --INT8QAT`
`python infer.py --m <model> --load_ckpt /tmp/pytorch_exp_1/ckpt_{} --quantize --INT8QAT`


## Accuracy Results

| Model | FP32 | FP16 | INT8 (QAT) | INT(PTC) |
| Model | FP32 | FP16 | INT8 (QAT) | INT8(PTC) |
|-------|------|------|------------|----------|
| Resnet18 | 83.08 | 83.12 | 83.12 | 83.06 |
| Resnet34 | 84.65 | 84.65 | 83.26 | 84.5 |
| Resnet18 | 83.78 | 83.77 | 83.78 | 83.78 |
| Resnet34 | 85.13 | 85.11 | 84.99 | 84.95 |
| Resnet50 | 87.56|87.54 |87.49 |87.38 |

Models were intially trained for 40 epochs and then fine tuned with QAT on for 10 epochs.

**Please note that the idea behind these experiments is to see if TRT conversion is working properly rather than achieving industry standard accuracy results**

## Future Work

- Add results for Resnet50, EfficientNet and Mobilenet
- Add results for EfficientNet and Mobilenet
10 changes: 6 additions & 4 deletions examples/contrib/quantization_aware_training/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ def __init__(self, data_dir='/tmp/cifar10', download=True, batch_size=128, pin_m
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

def train_loader(self,shuffle=True):
def train_loader(self,shuffle=True,batch_size=None):
if batch_size == None:
batch_size = self.batch_size
trainset = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=self.train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
return trainloader

def test_loader(self,shuffle=False):
def test_loader(self,shuffle=False,batch_size=1):
testset = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=self.test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory)
return testloader


Expand Down
38 changes: 28 additions & 10 deletions examples/contrib/quantization_aware_training/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import os,sys
from datasets.cifar10 import Cifar10Loaders
from utils.utilities import calculate_accuracy, timeGraph,printStats
from models.resnet import resnet18,resnet34
from models.resnet import resnet18,resnet34,resnet50
from parser import parse_args
from torch2trt import torch2trt
import tensorrt as trt
torch.set_printoptions(precision=5)
import torch2trt.contrib.qat.layers as quant_nn

def main():
args = parse_args()
Expand All @@ -22,37 +23,51 @@ def main():
if args.cuda:
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed(args.seed)


## Loading dataset

loaders = Cifar10Loaders()
train_loader = loaders.train_loader()
test_loader = loaders.test_loader()

## Loading model

if args.m == "resnet18":
if args.netqat:
model=resnet18(qat_mode=True,infer=True)
if args.quantize:
model=resnet18(qat_mode=True)
else:
model=resnet18()
elif args.m == "resnet34":
if args.netqat:
model=resnet34(qat_mode=True,infer=True)
if args.quantize:
model=resnet34(qat_mode=True)
else:
model=resnet34()
elif args.m == "resnet50":
if args.quantize:
model=resnet50(qat_mode=True)
else:
model=resnet50()
else:
raise NotImplementedError("{} model not found".format(args.m))


model = model.cuda().eval()
rand_in = torch.randn([128,3,32,32],dtype=torch.float32).cuda()
model = model.cuda().train()

## Single dummy run to instantiate quant metrics
out = model(rand_in)
for k,v in model.state_dict().items():
print(k,v.shape)
if args.load_ckpt:
checkpoint = torch.load(args.load_ckpt)
if not args.netqat:
if not args.quantize:
checkpoint = mapping_names_resnets(checkpoint)
model.load_state_dict(checkpoint['model_state_dict'],strict=True)
print("===>>> Checkpoint loaded successfully from {} ".format(args.load_ckpt))

model=model.eval()
test_accuracy = calculate_accuracy(model,test_loader)
print(" Test accuracy for Pytorch model: {0} ".format(test_accuracy))
rand_in = torch.randn([128,3,32,32],dtype=torch.float32).cuda()

#Converting the model to TRT
if args.FP16:
Expand All @@ -61,7 +76,9 @@ def main():
print(" TRT test accuracy at FP16: {0}".format(test_accuracy))

if args.INT8QAT:
trt_model_int8 = torch2trt(model,[rand_in],log_level=trt.Logger.INFO,fp16_mode=True,int8_mode=True,max_batch_size=128,qat_mode=True)
quant_nn.HelperFunction.export_trt = True
model = model.eval()
trt_model_int8 = torch2trt(model,[rand_in],log_level=trt.Logger.INFO,fp16_mode=True,int8_mode=True,max_batch_size=128,qat_mode=True,strict_type_constraints=False)
test_accuracy = calculate_accuracy(trt_model_int8,test_loader)
print(" TRT test accuracy at INT8 QAT: {0}".format(test_accuracy))

Expand All @@ -77,5 +94,6 @@ def main():
test_accuracy = calculate_accuracy(trt_model_calib_int8,test_loader)
print(" TRT test accuracy at INT8 PTC: {0}".format(test_accuracy))


if __name__ == "__main__":
main()
36 changes: 0 additions & 36 deletions examples/contrib/quantization_aware_training/models/models.py

This file was deleted.

Loading