-
Notifications
You must be signed in to change notification settings - Fork 0
/
classifier.py
78 lines (60 loc) · 2.7 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import ast
from pathlib import Path

import torch
import torchvision as torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch import __version__
from torch.autograd import Variable
from torchvision.models import resnet18, ResNet18_Weights, alexnet, AlexNet_Weights, vgg16, VGG16_Weights
# Pretrained ImageNet classifiers, each built with its IMAGENET1K_V1 weights.
resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
alexnet = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
vgg16 = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# Map the model names accepted by classifier() to the loaded networks.
# NOTE: this rebinds `models`, shadowing the `torchvision.models` module imported
# above; from this point on `models` refers to this dict.
models = {'resnet': resnet18, 'alexnet': alexnet, 'vgg': vgg16}

# Obtain ImageNet labels. The file is expected to hold a Python dict literal
# mapping class index -> human-readable label (classifier() looks it up by the
# argmax class index). ast.literal_eval only accepts literals, so no code runs.
# Look in the current working directory first (the original behavior), then fall
# back to the directory containing this module so the import also works when the
# script is launched from elsewhere.
_LABELS_FILENAME = 'imagenet1000_clsid_to_human.txt'
_labels_path = Path(_LABELS_FILENAME)
if not _labels_path.is_file():
    _labels_path = Path(__file__).resolve().parent / _LABELS_FILENAME
with open(_labels_path) as imagenet_classes_file:
    imagenet_classes_dict = ast.literal_eval(imagenet_classes_file.read())
def classifier(img_path, model_name):
    """Classify an image with one of the pretrained ImageNet models.

    Args:
        img_path: Path to the image file to classify.
        model_name: Key into the module-level ``models`` dict
            ('resnet', 'alexnet' or 'vgg').

    Returns:
        The entry of ``imagenet_classes_dict`` for the highest-scoring class,
        i.e. the human-readable ImageNet label.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    # Standard ImageNet preprocessing: resize the shorter side to 256, take the
    # central 224x224 crop, convert to a float tensor, and normalize with the
    # ImageNet channel means/stds the pretrained weights were trained with.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The context manager closes the file handle PIL keeps open for lazy loading.
    # convert('RGB') guarantees 3 channels: grayscale/palette (1 channel) or
    # RGBA/CMYK (4 channels) images would otherwise make Normalize fail.
    with Image.open(img_path) as img_pil:
        img_tensor = preprocess(img_pil.convert('RGB'))

    # Add the leading batch dimension: (C, H, W) -> (1, C, H, W).
    img_tensor = img_tensor.unsqueeze(0)

    try:
        model = models[model_name]
    except KeyError:
        raise KeyError(
            f"unknown model_name {model_name!r}; expected one of {sorted(models)}"
        ) from None

    # Put the model in evaluation mode instead of the default training mode
    # (affects dropout / batch-norm layers). eval() returns the model itself.
    model.eval()

    # Inference only: no_grad() keeps autograd from recording the forward pass.
    # This replaces the pre-0.4 Variable(volatile=True) path (changed 04/26/2018
    # for the PyTorch upgrade) and the later requires_grad_(False) on the input,
    # which did not stop graph construction since the model's parameters still
    # require grad.
    with torch.no_grad():
        output = model(img_tensor)

    # Index of the highest-scoring class, as a plain Python int.
    pred_idx = output.argmax().item()
    return imagenet_classes_dict[pred_idx]