Skip to content

Commit b2f72ea

Browse files
committed
Add option to pass num_workers for torch DataLoader
1 parent 280a995 commit b2f72ea

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

fast_plate_ocr/cli/train.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@
7171
type=int,
7272
help="Batch size for training.",
7373
)
74+
@click.option(
75+
"--num-workers",
76+
default=0,
77+
show_default=True,
78+
type=int,
79+
help="How many subprocesses to load data, used in the torch DataLoader.",
80+
)
7481
@click.option(
7582
"--output-dir",
7683
default="./trained-models",
@@ -120,6 +127,7 @@ def train(
120127
augmentation_path: pathlib.Path | None,
121128
lr: float,
122129
batch_size: int,
130+
num_workers: int,
123131
output_dir: pathlib.Path,
124132
epochs: int,
125133
tensorboard: bool,
@@ -139,14 +147,18 @@ def train(
139147
transform=train_augmentation,
140148
config=config,
141149
)
142-
train_dataloader = DataLoader(train_torch_dataset, batch_size=batch_size, shuffle=True)
150+
train_dataloader = DataLoader(
151+
train_torch_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
152+
)
143153

144154
if val_annotations:
145155
val_torch_dataset = LicensePlateDataset(
146156
annotations_file=val_annotations,
147157
config=config,
148158
)
149-
val_dataloader = DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False)
159+
val_dataloader = DataLoader(
160+
val_torch_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
161+
)
150162
else:
151163
val_dataloader = None
152164

0 commit comments

Comments
 (0)