-
Notifications
You must be signed in to change notification settings - Fork 2
/
Predict.py
35 lines (32 loc) · 1.15 KB
/
Predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

import Mydataset
from cnnModel import CNN
from Mydataset import get_test_data_loader
# Directory containing the trained model weights (Windows path). Raw string so
# the backslashes are not parsed as (invalid) escape sequences.
PATH = r"E:\LearningStuff\DLcode\Pytorch\Mnist\Trained_models"
def predict(image):
    """Classify an image with the trained CNN and return the predicted class.

    Args:
        image: Either a torch tensor or a non-tensor image object exposing
            ``resize`` (the ``__main__`` block passes a PIL image).
            A non-tensor image is resized to 28x28, run through ``ToTensor``,
            normalized with mean 0.1307 / std 0.3081, and given a leading
            batch dimension. A tensor is fed to the network unchanged, so it
            presumably must already be a preprocessed batch such as one from
            ``get_test_data_loader`` (TODO confirm).

    Returns:
        A tensor of predicted class indices: ``argmax`` over dim 1 of the
        network output, one entry per batch element.
    """
    if not torch.is_tensor(image):
        # Convert the image to a normalized tensor batch of one.
        image = image.resize((28, 28))
        loader = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.1307, 0.3081),
        ])
        image = loader(image).unsqueeze(dim=0)

    # Build the network and load the trained weights.
    network = CNN()
    # map_location="cpu": the network and the input tensor are on the CPU here,
    # so load the weights there too. This also lets a checkpoint that was saved
    # on a GPU be loaded on a machine without CUDA.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(os.path.join(PATH, "Model1.pkl"), map_location="cpu")
    network.load_state_dict(state_dict)
    network.eval()

    # Predict. This is inference only, so skip building the autograd graph.
    with torch.no_grad():
        pred = network(image).argmax(dim=1)
    return pred
if __name__ == '__main__':
    # True: classify a local image file.
    # False: classify the first sample of the test loader and also print its
    # ground-truth label.
    load_from_file = True
    if load_from_file:
        # Close the file opened by Image.open. convert('L') returns a new
        # grayscale image, so the file is no longer needed after this.
        with Image.open('Images/new2.jpg') as img:
            image = img.convert('L')
        pred = predict(image)
    else:
        batch = get_test_data_loader(batch_size=1)
        image, label = next(iter(batch))
        pred = predict(image)
        print("the Prediction of ", str(label.numpy()), " is:", str(pred.numpy()))
    # .item() requires a single prediction; both branches classify exactly one image.
    print('The predition is:{}'.format(pred.item()))