-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_submission_torch.py
43 lines (37 loc) · 1.62 KB
/
test_submission_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import csv
import pathlib
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from model import ResNetUNet
def _load_classes(data_dir):
    """Return the sorted names of the entries directly under *data_dir*.

    A name's position in this list is the class index the model outputs, so
    the ordering must match the one used when the weights were trained
    (presumably the same ``sorted(glob('*'))`` -- TODO confirm against the
    training script).
    """
    return sorted(item.name for item in pathlib.Path(data_dir).glob('*'))


def _read_eval_rows(csv_path):
    """Yield one 5-tuple of strings per non-blank row of the evaluation CSV.

    Each tuple is ``(image_id, image_path, image_height, image_width,
    image_channels)``. Surrounding whitespace is stripped from every field.

    Raises:
        ValueError: if a non-blank row does not have exactly 5 fields.
    """
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_path, newline='') as csv_file:
        reader = csv.reader(csv_file)
        for row in reader:
            fields = [field.strip() for field in row]
            if not any(fields):
                continue  # tolerate blank / trailing empty lines
            if len(fields) != 5:
                raise ValueError(
                    f'{csv_path}:{reader.line_num}: expected 5 comma-separated '
                    f'fields (id, path, height, width, channels), '
                    f'got {len(fields)}'
                )
            yield tuple(fields)


def _predict(model, data_transforms, image_path):
    """Return the predicted class index (a Python int) for one image file."""
    with open(image_path, 'rb') as f:
        # convert() forces PIL to read the pixels before the file closes.
        img = Image.open(f).convert('RGB')
    batch = data_transforms(img).unsqueeze(0)  # add a batch dimension of 1
    logits = model(batch)
    # softmax is monotonic, so the argmax of the logits is the same class.
    return int(logits.argmax(dim=1).item())


def main():
    """Classify every image listed in the CSV given as ``sys.argv[1]``.

    Each input row must be ``id,path,height,width,channels``. For each row
    one ``<image_id>,<class_name>`` line is written to
    ``eval_classified.csv`` in the current directory.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: python3 test_submission_torch.py <eval.csv>')
    eval_csv_path = sys.argv[1]

    classes = _load_classes('./data/tiny-imagenet-200/train/')

    model = ResNetUNet(len(classes))
    # Everything in this script runs on the CPU, so remap any tensors that
    # were saved from a GPU; otherwise loading fails on CPU-only machines.
    state_dict = torch.load("./weights/best/ces_weights.pt", map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # ToTensor scales pixels to [0, 1]. The extra division by sqrt(255)
    # (~15.97) looks unusual; it presumably mirrors the training pipeline,
    # so keep it in sync with the training code (TODO confirm).
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0, 0, 0), tuple(np.sqrt((255, 255, 255)))),
    ])

    # no_grad: inference only, so skip building autograd graphs.
    with open('eval_classified.csv', 'w') as eval_output_file, torch.no_grad():
        for row in _read_eval_rows(eval_csv_path):
            image_id, image_path, image_height, image_width, image_channels = row
            print(image_id, image_path, image_height, image_width, image_channels)
            predicted = _predict(model, data_transforms, image_path)
            eval_output_file.write('{},{}\n'.format(image_id, classes[predicted]))
# USAGE: python3 test_submission_torch.py eval.csv
# The argument is a CSV with rows "id,path,height,width,channels"; one
# "<image_id>,<class_name>" line per row is written to eval_classified.csv.
if __name__ == '__main__':
    main()