File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change 71
71
type = int ,
72
72
help = "Batch size for training." ,
73
73
)
74
+ @click .option (
75
+ "--num-workers" ,
76
+ default = 0 ,
77
+ show_default = True ,
78
+ type = int ,
79
+ help = "How many subprocesses to load data, used in the torch DataLoader." ,
80
+ )
74
81
@click .option (
75
82
"--output-dir" ,
76
83
default = "./trained-models" ,
@@ -120,6 +127,7 @@ def train(
120
127
augmentation_path : pathlib .Path | None ,
121
128
lr : float ,
122
129
batch_size : int ,
130
+ num_workers : int ,
123
131
output_dir : pathlib .Path ,
124
132
epochs : int ,
125
133
tensorboard : bool ,
@@ -139,14 +147,18 @@ def train(
139
147
transform = train_augmentation ,
140
148
config = config ,
141
149
)
142
- train_dataloader = DataLoader (train_torch_dataset , batch_size = batch_size , shuffle = True )
150
+ train_dataloader = DataLoader (
151
+ train_torch_dataset , batch_size = batch_size , num_workers = num_workers , shuffle = True
152
+ )
143
153
144
154
if val_annotations :
145
155
val_torch_dataset = LicensePlateDataset (
146
156
annotations_file = val_annotations ,
147
157
config = config ,
148
158
)
149
- val_dataloader = DataLoader (val_torch_dataset , batch_size = batch_size , shuffle = False )
159
+ val_dataloader = DataLoader (
160
+ val_torch_dataset , batch_size = batch_size , num_workers = num_workers , shuffle = False
161
+ )
150
162
else :
151
163
val_dataloader = None
152
164
You can’t perform that action at this time.
0 commit comments