Skip to content

Commit

Permalink
fix/feat(trainer): drop_last = False and num_workers = 2
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Feb 22, 2024
1 parent 1929a0a commit dbb6ed9
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions mohou/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,16 +455,31 @@ def train(
config: TrainConfig = TrainConfig(),
device: Optional[torch.device] = None,
is_stoppable: Optional[Callable[[TrainCache], bool]] = None,
num_workers: int = 2,
) -> None:
r"""
higher-level train function that auto create dataloader from the dataset
"""

dataset_train, dataset_validate = split_with_ratio(dataset, config.valid_data_ratio)

train_loader = DataLoader(dataset=dataset_train, batch_size=config.batch_size, shuffle=True)
if len(dataset_train) < config.batch_size:
message = "dataset size is smaller than batch_size. drop_last is set to False"
logger.warn(change_color_to_yellow(message))

# drop last is necessary for batch normalization
train_loader = DataLoader(
dataset=dataset_train,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
validate_loader = DataLoader(
dataset=dataset_validate, batch_size=config.batch_size, shuffle=True
dataset=dataset_validate,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers,
)
train_lower(
project_path,
Expand Down

0 comments on commit dbb6ed9

Please sign in to comment.