The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
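As a quick sanity check on those numbers, the sketch below works out the raw pixel storage, assuming each colour channel is stored as one `uint8` byte (labels and file-format overhead are ignored):

```python
HEIGHT, WIDTH, CHANNELS = 32, 32, 3  # 32x32 colour (RGB) images
NUM_IMAGES = 50000 + 10000           # training + test images
NUM_CLASSES = 10

values_per_image = HEIGHT * WIDTH * CHANNELS  # pixel values in one image
images_per_class = NUM_IMAGES // NUM_CLASSES  # classes are balanced
total_bytes = NUM_IMAGES * values_per_image   # assumes one byte per value

print(values_per_image, images_per_class, total_bytes / 2**20)  # 3072 6000 175.78125
```

So the raw pixels of the whole dataset take roughly 176 MiB, small enough to hold in memory on most machines.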
Below are six random images with their respective labels:
[Figure: six random CIFAR-10 images with their labels]
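Drawing such a sample only needs six distinct random indices and a lookup from integer label to class name. A minimal sketch, assuming the dataset yields `(image, label)` pairs with labels in the order of torchvision's `CIFAR10.classes`; the helper name `random_samples_with_labels` is hypothetical:

```python
import random

# Label order used by torchvision.datasets.CIFAR10 (its `classes` attribute)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


def random_samples_with_labels(dataset, k=6, seed=0):
    """Pick k distinct random indices and return (index, class name) pairs.

    `dataset` is anything with __len__ and __getitem__ returning
    (image, label) pairs, e.g. a torchvision CIFAR10 dataset.
    """
    rng = random.Random(seed)  # seeded so the same sample can be reproduced
    indices = rng.sample(range(len(dataset)), k)  # k distinct indices
    return [(i, CIFAR10_CLASSES[dataset[i][1]]) for i in indices]


# Usage with a stand-in dataset of 60 (image, label) pairs
fake_dataset = [(None, i % 10) for i in range(60)]
print(random_samples_with_labels(fake_dataset))
```

The images at the returned indices can then be passed to any plotting library for display.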
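`torchvision` is a Python package that provides a ready-made CIFAR10 dataset class (`torchvision.datasets.CIFAR10`) and image transforms (`torchvision.transforms`). Its datasets plug directly into `torch.utils.data.DataLoader`, which handles batching and shuffling.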
Below is an example of how to load the CIFAR10 dataset using `torchvision`:
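```python
import torch
import torchvision
import torchvision.transforms as transforms

# Convert PIL images to tensors so the DataLoader can batch them
transform = transforms.ToTensor()

# Load the CIFAR10 training split (downloaded to ./train_data on first run)
train_dataset = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
```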
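With `batch_size=128`, one pass over the 50000 training images takes 391 iterations. Because `DataLoader` keeps the final partial batch by default (`drop_last=False`), that last batch holds only 80 images. The arithmetic, as a quick sketch:

```python
import math

NUM_TRAIN_IMAGES = 50000  # CIFAR10 training split size
BATCH_SIZE = 128          # batch_size passed to the DataLoader

# With drop_last=False the final, smaller batch is still yielded
batches_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
last_batch_size = NUM_TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE

print(batches_per_epoch, last_batch_size)  # 391 80
```

Passing `drop_last=True` instead would give 390 full batches per epoch, which can matter for layers such as batch normalization that behave differently on small batches.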
Requirements:
- Python >= 3.6
- PyTorch >= 1.4
- Other libraries are listed in `requirenments.txt`
I used a pretrained `resnet18` for model training. You can use any other pretrained model that suits your problem, for example:
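```python
import torchvision.models as models

# pretrained=True downloads ImageNet weights. This is the older torchvision
# API; newer releases deprecate it in favour of the `weights` argument.
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
```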
There are two ways to train the PyTorch model:
- Notebook: download it and experiment with it.
- Python scripts:

```bash
# Start training with:
python main.py

# You can manually pass the attributes for the training:
python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'

# Start inference with:
python3.6 prediction.py --model_path './cifar_model.pth'
```
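The command-line options above map naturally onto Python's standard `argparse` module. The sketch below is a hypothetical parser for those flags, not the actual contents of `main.py`; the flag names come from the example commands, while the default values are assumptions:

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser for the training flags used in the commands.

    Flag names match the example commands; the defaults are assumptions,
    not values taken from main.py.
    """
    parser = argparse.ArgumentParser(description='Train a CIFAR10 classifier')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--epoch', type=int, default=10, help='number of training epochs')
    parser.add_argument('--model_path', type=str, default='./model.pth',
                        help='where the trained weights are saved')
    return parser.parse_args(argv)


# The shell strips the quotes, so the script receives these tokens
args = parse_args(['--lr=0.01', '--epoch', '20', '--model_path', './cifar_model.pth'])
print(args.lr, args.epoch, args.model_path)  # 0.01 20 ./cifar_model.pth
```

`argparse` accepts both `--lr=0.01` and `--epoch 20` forms, which is why the example command can mix them.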