# Import torch and torchvision
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as transforms
# Import additional libraries
import argparse
import mplfonts
import numpy as np
import os
import random
import sys
from typing import Sequence
from dataset import CharacterDataset
from models import CharacterRecognitionCNN
from utils import evaluate_model, stratify_dataset, visualize_history, visualize_result
# Set the Noto Sans JP font used to render plots
mplfonts.use_font("Noto Sans JP")
# Define the default random seed
RANDOM_STATE_DEFAULT = 42
# Define default hyperparameters
EPOCH_DEFAULT = 100
LR_DEFAULT = 1e-3
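# Example invocation (hypothetical paths and values, adjust to your dataset):
#   python train.py --dir ~/data/images --csv ~/data/labels.csv --export model.pt \
#       --layers 32 64 128 -x 64 -y 64 --zoom --zoom_min 0.6 --zoom_max 1.0 \
#       --zoom_ratio 1.3 --rotation 15 --visualize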
def main():
    # Parse the input arguments
    parser = argparse.ArgumentParser()
    # Paths to the dataset image directory, the label CSV, and the model export file
    parser.add_argument("--dir", required=True)
    parser.add_argument("--csv", required=True)
    parser.add_argument("--export", required=True)
    # Define model architecture
    parser.add_argument("--layers", nargs="+", type=int, required=True)
    # Define image size
    parser.add_argument("-x", required=True, type=int)
    parser.add_argument("-y", required=True, type=int)
    # Training hyperparameters
    parser.add_argument("--epoch", type=int, default=EPOCH_DEFAULT)
    parser.add_argument("--lr", type=float, default=LR_DEFAULT)
    # Hardware acceleration device
    parser.add_argument("--device", default="auto", choices=["cpu", "cuda", "mps", "auto"])
    # Augmentation pipeline options
    parser.add_argument("--zoom", action="store_true", default=False)
    g = parser.add_argument_group("Zoom options")
    # Zoom-specific arguments, required only when --zoom is passed on the command line
    g.add_argument("--zoom_min", type=float, required="--zoom" in sys.argv)
    g.add_argument("--zoom_max", type=float, required="--zoom" in sys.argv)
    g.add_argument("--zoom_ratio", type=float, required="--zoom" in sys.argv)
    # Maximum rotation in degrees, 0.0 disables the rotation augmentation
    parser.add_argument("--rotation", type=float, default=0.0)
    # Flag to visualize training history and evaluation results, default `False`
    parser.add_argument("--visualize", action="store_true", default=False)
    # Random seed
    parser.add_argument("--seed", type=int, default=RANDOM_STATE_DEFAULT)
    # Parse arguments
    args = parser.parse_args()
    # Seed all random number generators for reproducibility
    random_state = args.seed
    torch.manual_seed(random_state)
    np.random.seed(random_state)
    random.seed(random_state)
    # If --zoom is set and any required arguments are missing, raise an error
    if args.zoom and None in [args.zoom_min, args.zoom_max, args.zoom_ratio]:
        parser.error("Augmentation zoom parameters not set")
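    # Note: the required="--zoom" in sys.argv trick misses abbreviated flags
    # (argparse accepts unambiguous prefixes such as --zo), so the explicit
    # check above acts as a backstop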
    # Automatically select the hardware acceleration device based on availability
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() \
            else "mps" if torch.backends.mps.is_available() \
            else "cpu"
    else:
        device = args.device
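    # e.g. on an Apple Silicon machine without CUDA this resolves to "mps"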
    # Get paths
    dir_path = os.path.expanduser(args.dir)
    csv_path = os.path.expanduser(args.csv)
    # Perform a 60-20-20 stratified data split on the dataset
    stratified = stratify_dataset(csv_path, dir_path, random_state)
    train_set = stratified["split"]["train"]
    val_set = stratified["split"]["val"]
    test_set = stratified["split"]["test"]
    itos = stratified["map"]["itos"]
    img_shape = (args.y, args.x)
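    # torchvision transforms take sizes as (height, width), hence (args.y, args.x)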
    # Determine the augmentation pipeline to use
    pipeline = [
        transforms.ColorJitter((0.8, 1.2), (0.8, 1.2), (0.8, 1.2), (-0.3, 0.3)),
        transforms.RandomChannelPermutation(),
        transforms.RandomInvert()
    ]
    # Random resized crop
    if args.zoom:
        pipeline = pipeline + [
            transforms.RandomResizedCrop(img_shape, (args.zoom_min, args.zoom_max),
                                         ratio=(1.0, args.zoom_ratio), antialias=True)
        ]
    # Rotation
    if args.rotation:
        pipeline = pipeline + [
            transforms.RandomRotation(args.rotation)
        ]
    # Perspective and additive Gaussian noise
    pipeline = pipeline + [
        transforms.RandomPerspective(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Lambda(lambda x: x + torch.randn_like(x)),
    ]
print(f"Augmentation pipeline:")
for p in pipeline:
print("-", p)
# Check if --layer is a sequence
if not isinstance(args.layers, Sequence):
parser.error("Sequence of int expected for --layers argument")
    # Define the model
    num_cls = len(itos)
    model = CharacterRecognitionCNN(3, num_cls, img_shape, args.layers, device, transforms.Compose(pipeline))
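    # (3 is presumably the input channel count for RGB images; the composed
    # pipeline serves as the train-time augmentation transform, see models.py)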
    # Get dataloaders
    batch_size = 8
    train_dataloader = DataLoader(train_set, batch_size, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size, shuffle=True)
    # Train the model using the provided hyperparameters
    model.train_loop(train_dataloader, val_dataloader, args.epoch, args.lr)
    # Evaluate the model on the test set
    result = evaluate_model(model, test_set)
    # Export the model
    export_path = os.path.abspath(args.export)
    torch.save(model.state_dict(), export_path)
    print(f"Model exported as {export_path}")
    # Visualize the result
    if args.visualize:
        visualize_history(model)
        visualize_result(result, itos, "Misclassified Instances")
if __name__ == "__main__":
    main()