-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model.py
37 lines (30 loc) · 1.09 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from model import MiniResNet
def test(model, batch_size=1000, *, loader=None, data_root='./data'):
    """Evaluate ``model`` on the MNIST test split and report top-1 accuracy.

    The model is moved to CUDA when available (otherwise CPU) and switched to
    eval mode. Both changes persist on the caller's model object.

    Args:
        model: A ``torch.nn.Module`` whose output is reduced with
            ``output.max(1)``. This presumes per-class scores of shape
            (batch, num_classes) -- confirm for the model in use.
        batch_size: Batch size for the default MNIST loader. Ignored when
            ``loader`` is given.
        loader: Optional iterable of ``(data, target)`` batches to evaluate
            instead of the MNIST test split (e.g. a custom ``DataLoader``).
        data_root: Directory passed to ``datasets.MNIST`` for the default
            loader. No download is requested, so the files must already
            be present there.

    Returns:
        The accuracy as a percentage (``100 * correct / total``).

    Raises:
        ValueError: If the loader yields no samples, since accuracy is then
            undefined.
    """
    if loader is None:
        # Build the default MNIST test loader only when the caller did not
        # supply one.
        # NOTE(review): (0.1307,), (0.3081,) are the commonly used MNIST
        # mean/std; presumably they match the training-time normalization --
        # confirm against the training script.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        test_dataset = datasets.MNIST(data_root, train=False, transform=transform)
        loader = DataLoader(test_dataset, batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()  # inference behaviour for layers such as dropout / batch-norm

    correct = 0
    total = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)  # index of the highest score per row
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    if total == 0:
        raise ValueError('test loader yielded no samples; cannot compute accuracy')

    accuracy = 100. * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy


# This module is named test_*.py and the function is named ``test``, so pytest
# would collect it as a test case and fail on the missing ``model`` fixture.
test.__test__ = False
if __name__ == '__main__':
    # Evaluate the checkpoint stored at model.pth (relative to the working dir).
    model = MiniResNet()
    # Deserialize onto the CPU so a checkpoint saved from a GPU also loads on a
    # CPU-only machine; test() moves the model to the best device afterwards.
    model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    test(model)