diff --git a/examples/transit_train.py b/examples/transit_train.py index 028eb4af..79054045 100644 --- a/examples/transit_train.py +++ b/examples/transit_train.py @@ -9,7 +9,7 @@ def main(): validation_light_curve_dataset = get_transit_validation_dataset() model = Hadryss.new() train_hyperparameter_configuration = TrainHyperparameterConfiguration.new( - batch_size=100, cycles=100, train_steps_per_cycle=100, validation_steps_per_cycle=10) + batch_size=100, cycles=20, train_steps_per_cycle=100, validation_steps_per_cycle=10) train_session(train_datasets=[train_light_curve_dataset], validation_datasets=[validation_light_curve_dataset], model=model, hyperparameter_configuration=train_hyperparameter_configuration)