From b751655bb94a2ef0e2388fd4c686ca68989db33c Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:49:23 -0400 Subject: [PATCH 01/11] Remove complicated lr_decrease_params --- dwi_ml/training/trainers.py | 33 +++---------------- dwi_ml/training/utils/trainer.py | 8 ----- scripts_python/l2t_train_model.py | 3 +- .../tests/test_all_steps_learn2track.py | 1 - scripts_python/tt_train_model.py | 3 +- 5 files changed, 6 insertions(+), 42 deletions(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index ecb2558a..c6ecb440 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -54,7 +54,6 @@ def __init__(self, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, batch_loader: DWIMLAbstractBatchLoader, learning_rates: Union[List, float] = None, - lr_decrease_params: Tuple[float, float] = None, weight_decay: float = 0.01, optimizer: str = 'Adam', max_epochs: int = 10, max_batches_per_epoch_training: int = 1000, @@ -87,12 +86,6 @@ def __init__(self, torch's default, 0.001). A list [0.01, 0.01, 0.001], for instance, would use these values for the first 3 epochs, and keep the final value for remaining epochs. - lr_decrease_params: Tuple[float, float] - Parameters [E, L] to set the learning rate an exponential decreasing - curve. The final curve will be init_lr * exp(-x / r). The rate of - decrease, r, is defined in order to ensure that the learning rate - curve will hit value L at epoch E. - learning_rates must be a single float value. weight_decay: float Add a weight decay penalty on the parameters. Default: 0.01. (torch's default). @@ -152,22 +145,9 @@ def __init__(self, self.comet_project = comet_project self.space = 'vox' self.origin = 'corner' - self.lr_decrease_params = lr_decrease_params self.clip_grad = clip_grad # Learning rate: - if lr_decrease_params is not None: - assert isinstance(learning_rates, float), \ - "To use lr_decrease_params, the learning_rate cannot be a " \ - "list of learning rates. Expecting a single float value, but " \ - "got {}".format(learning_rates) - self.initial_lr = learning_rates # Initial value - x, y = lr_decrease_params - assert x.is_integer(), \ - "First value of lr_decrease_params should be an epoch " \ - "(integer), but got {}".format(x) - self.lr_decrease_rate = -x / np.log(y / self.initial_lr) - if learning_rates is None: self.learning_rates = [0.001] elif isinstance(learning_rates, float): @@ -360,7 +340,6 @@ def params_for_checkpoint(self): # user to increase the patience when running again. 
params = { 'learning_rates': self.learning_rates, - 'lr_decrease_params': self.lr_decrease_params, 'weight_decay': self.weight_decay, 'max_epochs': self.max_epochs, 'max_batches_per_epoch_training': self.max_batches_per_epochs_train, @@ -689,16 +668,12 @@ def train_and_validate(self): .format(epoch, epoch + 1)) # Computing learning rate - if self.lr_decrease_params is not None: - # Exponential decrease - current_lr = self.initial_lr * np.exp(-epoch/self.lr_decrease_rate) - else: - # User-given values - current_lr = self.learning_rates[ - min(self.current_epoch, len(self.learning_rates) - 1)] + current_lr = self.learning_rates[ + min(self.current_epoch, len(self.learning_rates) - 1)] logger.info("Learning rate = {}".format(current_lr)) if self.comet_exp: - self.comet_exp.log_metric("learning_rate", current_lr, step=epoch) + self.comet_exp.log_metric("learning_rate", current_lr, + step=epoch) for g in self.optimizer.param_groups: g['lr'] = current_lr diff --git a/dwi_ml/training/utils/trainer.py b/dwi_ml/training/utils/trainer.py index e3c32e78..8b90a70c 100644 --- a/dwi_ml/training/utils/trainer.py +++ b/dwi_ml/training/utils/trainer.py @@ -18,14 +18,6 @@ def add_training_args(p: argparse.ArgumentParser, "Ex: '--learning_rate 0.001*3 0.0001' would set the lr to 0.001 " "for the first \n3 epochs, and 0.0001 for the remaining epochs.\n" "(torch's default = 0.001)") - training_group.add_argument( - '--lr_decrease_params', metavar='E L', nargs=2, type=float, - help="Parameters [E, L] to set the learning rate an exponential " - "decreasing curve. \nThe final curve will be " - "init_lr * exp(-x / r). The rate of \ndecrease, r, is defined in " - "order to ensure that the learning rate curve will hit \nvalue L " - "at epoch E.\n" - "learning_rate must be a single float value.") training_group.add_argument( '--weight_decay', type=float, default=0.01, metavar='v', help="Add a weight decay penalty on the parameters (regularization " diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index b6f9821d..50fa6b69 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -119,8 +119,7 @@ def init_from_args(args, sub_loggers_level): comet_project=args.comet_project, comet_workspace=args.comet_workspace, # TRAINING - learning_rates=lr, lr_decrease_params=args.lr_decrease_params, - weight_decay=args.weight_decay, + learning_rates=lr, weight_decay=args.weight_decay, optimizer=args.optimizer, max_epochs=args.max_epochs, max_batches_per_epoch_training=args.max_batches_per_epoch_training, max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, diff --git a/scripts_python/tests/test_all_steps_learn2track.py b/scripts_python/tests/test_all_steps_learn2track.py index 20f6346d..e586cb12 100644 --- a/scripts_python/tests/test_all_steps_learn2track.py +++ b/scripts_python/tests/test_all_steps_learn2track.py @@ -63,7 +63,6 @@ def test_training(script_runner, experiments_path): input_group_name, streamline_group_name, '--max_epochs', '1', '--step_size', '0.5', '--learning_rate', '0.002', - '--lr_decrease_params', '50', '0.001', '--batch_size_training', '5', '--batch_size_validation', '5', '--batch_size_units', 'nb_streamlines', diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py index 008c898f..6c639032 100755 --- a/scripts_python/tt_train_model.py +++ b/scripts_python/tt_train_model.py @@ -137,8 +137,7 @@ def init_from_args(args, sub_loggers_level): comet_project=args.comet_project, 
comet_workspace=args.comet_workspace, # TRAINING - learning_rates=lr, lr_decrease_params=args.lr_decrease_params, - weight_decay=args.weight_decay, + learning_rates=lr, weight_decay=args.weight_decay, optimizer=args.optimizer, max_epochs=args.max_epochs, max_batches_per_epoch_training=args.max_batches_per_epoch_training, max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, From 35dc5f4f247a7d936dcd11dd5fb01d846b84e2c1 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:50:34 -0400 Subject: [PATCH 02/11] Not supporting deprecated neighborhood params in checkpoint anymore. --- dwi_ml/models/main_models.py | 41 ------------------------------------ 1 file changed, 41 deletions(-) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 5e15b7dd..3fc8bf05 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -249,47 +249,6 @@ def __init__(self, neighborhood_type: str = None, super().__init__(**kw) - @classmethod - def _load_params(cls, model_dir): - params = super()._load_params(model_dir) - - # Will eventually be deprecated: - if 'neighborhood_radius' in params and \ - 'neighborhood_resolution' not in params: - logging.warning( - "Model trained with a deprecated neighborhood management. " - "Fixing.") - r = params['neighborhood_radius'] - if params['neighborhood_type'] == 'grid': - res = 1 - - if isinstance(r, list): - assert len(r) == 1 - rad = r[0] - assert int(rad) == rad, \ - "Failed. Cannot interpret float radius anymore." - rad = int(rad) - else: - rad = 1 - else: - if isinstance(r, list): - res = r[0] - rad = len(r) - assert np.all(np.diff(r) == res), \ - "Failed. Cannot use that type of neighborhood anymore. " \ - "Resolution must be the same between each layer of " \ - "neighborhood." - else: - res = r - rad = 1 - - logging.warning("Guessed values are: resolution {}, radius {}" - .format(res, rad)) - params['neighborhood_resolution'] = float(res) - params['neighborhood_radius'] = rad - - return params - def move_to(self, device): super().move_to(device) if self.neighborhood_vectors is not None: From 020120ee2cbaecc40eac461347b1c8fc2c647885 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:55:09 -0400 Subject: [PATCH 03/11] Not supporting deprecated batch loader params anymore --- dwi_ml/training/batch_loaders.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 2161b410..0173caaf 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -161,12 +161,6 @@ def params_for_checkpoint(self): @classmethod def init_from_checkpoint(cls, dataset, model, checkpoint_state, new_log_level): - # Adding noise_gaussian_size_loss for deprecated batch loaders - if 'noise_gaussian_size_loss' not in checkpoint_state: - logging.warning("Deprecated batch loader. Did not contain a " - "noise_gaussian_size_loss value. 
Setting to 0.0.") - checkpoint_state['noise_gaussian_size_loss'] = 0.0 - batch_loader = cls(dataset=dataset, model=model, log_level=new_log_level, **checkpoint_state) return batch_loader From b214f62d4ae68e968376d4ad05083a586db83b2a Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:56:16 -0400 Subject: [PATCH 04/11] Not supporting deprecated Transformer params anymore --- dwi_ml/models/projects/transformer_models.py | 23 -------------------- 1 file changed, 23 deletions(-) diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 7aba783b..eb1a418a 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -257,13 +257,6 @@ def params_for_checkpoint(self): def _load_params(cls, model_dir): params = super()._load_params(model_dir) - # Fix deprecated value - if 'embedding_key_x' in params: - logging.warning("Deprecated model. Variable 'embedding_key_x' " - "now called input_embedding_key. Renaming.") - params['input_embedding_key'] = params['embedding_key_x'] - del params['embedding_key_x'] - # d_model now a property method. if 'd_model' in params: if isinstance(cls, TransformerSrcOnlyModel): @@ -273,22 +266,6 @@ def _load_params(cls, model_dir): return params - @classmethod - def _load_state(cls, model_dir): - model_state = super()._load_state(model_dir) - - if 'embedding_layer_x.linear.weight' in model_state: - logging.warning("Deprecated variable name embedding_layer_x. Now " - "called input_embedding_layer. Fixing model " - "state at loading.") - model_state['input_embedding_layer.linear.weight'] = \ - model_state['embedding_layer_x.linear.weight'] - model_state['input_embedding_layer.linear.bias'] = \ - model_state['embedding_layer_x.linear.bias'] - del model_state['embedding_layer_x.linear.weight'] - del model_state['embedding_layer_x.linear.bias'] - return model_state - def set_context(self, context): assert context in ['training', 'validation', 'tracking', 'visu'] self._context = context From cd795e721e0aeea0bc448e67882eee8d014da63c Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:57:27 -0400 Subject: [PATCH 05/11] Not supporting deprecated Trainer params anymore --- dwi_ml/training/trainers.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index c6ecb440..297cdeeb 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -457,16 +457,6 @@ def init_from_checkpoint( """ trainer_params = checkpoint_state['params_for_init'] - # Will eventually be deprecated: - if 'tracking_phase_nb_steps_init' in trainer_params: - logging.warning( - "Model trained with an older version of dwi_ml. Param " - "tracking_phase_nb_steps_init will soon be deprecated. Now " - "called tracking_phase_nb_segments_init, with value one less.") - val = trainer_params['tracking_phase_nb_steps_init'] - del trainer_params['tracking_phase_nb_steps_init'] - trainer_params['tracking_phase_nb_segments_init'] = val - 1 - trainer = cls(model=model, experiments_path=experiments_path, experiment_name=experiment_name, batch_sampler=batch_sampler, @@ -538,12 +528,7 @@ def _update_states_from_checkpoint(self, current_states): # F. Monitors for monitor in self.monitors: - if (monitor.name == 'unclipped_grad_norm_monitor' and - 'unclipped_grad_norm_monitor_state' not in current_states): - logging.warning("Deprecated trainer. Did not contain an " - "unclipped grad monitor. 
Starting as new.") - else: - monitor.set_state(current_states[monitor.name + '_state']) + monitor.set_state(current_states[monitor.name + '_state']) def _init_comet(self): """ From a125b668a8f7431c0f6c4f2ae76fa496cb96814e Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:58:27 -0400 Subject: [PATCH 06/11] Not supporting deprecated Models params anymore --- dwi_ml/models/main_models.py | 44 ------------------------------------ 1 file changed, 44 deletions(-) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 3fc8bf05..14576222 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -400,20 +400,6 @@ def add_args_model_with_pd(p): help="If true, normalize the previous directions (before the " "embedding layer,\n if any, and before adding to the input.") - @classmethod - def _load_params(cls, model_dir): - params = super()._load_params(model_dir) - - # Will eventually be deprecated: - if 'prev_dirs_embedding_size' in params: - logging.warning( - "Deprecated param prev_dirs_embedding_size. Now called " - "prev_dirs_embedded_size. Changing") - params['prev_dirs_embedded_size'] = params['prev_dirs_embedding_size'] - del params['prev_dirs_embedding_size'] - - return params - @property def params_for_checkpoint(self): p = super().params_for_checkpoint @@ -642,36 +628,6 @@ def instantiate_nn_embedding(self): nb_features_in=input_size, nb_features_out=self.computed_input_embedded_size) - @classmethod - def _load_params(cls, model_dir): - params = super()._load_params(model_dir) - - # Will eventually be deprecated: - if 'input_embedding_size' in params: - logging.warning( - "Deprecated param input_embedding_size. Now called " - "input_embedded_size. Changing") - params['input_embedded_size'] = params['input_embedding_size'] - del params['input_embedding_size'] - - if 'input_embedding_size_ratio' in params: - if params['input_embedding_size_ratio'] is None: - logging.warning( - "Deprecated params 'input_embedding_size_ratio', but was " - "None. Ignoring") - del params['input_embedding_size_ratio'] - else: - raise ValueError("Deprecated use of " - "'input_embedding_size_ratio'. Cannot proceed.") - - # These values did not exist in older models. - if 'nb_cnn_filters' not in params: - params['nb_cnn_filters'] = None - if 'kernel_size' not in params: - params['kernel_size'] = None - - return params - @property def params_for_checkpoint(self): # Every parameter necessary to build the different layers again. From da3816ac0a2e98657549d62b7bbc746a27dcde2b Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:59:02 -0400 Subject: [PATCH 07/11] Not supporting deprecated Learn2track params anymore --- dwi_ml/models/projects/learn2track_model.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 27a36608..4e49e792 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -219,23 +219,6 @@ def params_for_checkpoint(self): return params - @classmethod - def _load_state(cls, model_dir): - model_state = super()._load_state(model_dir) - - if 'input_embedding.linear.weight' in model_state: - logging.warning("Deprecated variable name input_embedding. Now " - "called input_embedding_layer. 
Fixing model " - "state at loading.") - model_state['input_embedding_layer.linear.weight'] = \ - model_state['input_embedding.linear.weight'] - model_state['input_embedding_layer.linear.bias'] = \ - model_state['input_embedding.linear.bias'] - del model_state['input_embedding.linear.weight'] - del model_state['input_embedding.linear.bias'] - - return model_state - @property def computed_params_for_display(self): p = super().computed_params_for_display From b17217b241dac77b9294da2395c2a84cceae97db Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 11:59:35 -0400 Subject: [PATCH 08/11] Not supporting deprecated Monitoring params anymore --- dwi_ml/training/utils/monitoring.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/dwi_ml/training/utils/monitoring.py b/dwi_ml/training/utils/monitoring.py index 5d98121b..62e6dba0 100644 --- a/dwi_ml/training/utils/monitoring.py +++ b/dwi_ml/training/utils/monitoring.py @@ -130,13 +130,8 @@ def get_state(self): def set_state(self, state): self.average_per_epoch = state['average_per_epoch'] self.current_epoch = state['current_epoch'] - - if 'ever_max' in state: - self.ever_max = state['ever_max'] - self.ever_min = state['ever_min'] - else: - logging.info("Deprecated model. Ever_max and ever_min values not " - "set in {}. Ignoring".format(self.name)) + self.ever_max = state['ever_max'] + self.ever_min = state['ever_min'] class BestEpochMonitor(object): From 1faa3059e8d82c3ce30f1e0b3900ccf6b7aba4b0 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 12:00:18 -0400 Subject: [PATCH 09/11] Not supporting deprecated Hdf5 params anymore --- dwi_ml/data/dataset/multi_subject_containers.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index 719d9055..701c3684 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -427,16 +427,7 @@ def load_data(self, load_training=True, load_validation=True, # Load main attributes from hdf file, but each process calling # the collate_fn must open its own hdf_file step_size = hdf_handle.attrs['step_size'] - if 'compress' in hdf_handle.attrs: - compress = hdf_handle.attrs['compress'] - else: - # Fix deprecated usages - logger.warning( - "Using an old version of hdf database. Compression rate " - "information was not saved. This only means that if you " - "use --compress option anywhere, we will perform it " - "again.") - compress = None + compress = hdf_handle.attrs['compress'] # Can't save None in hdf5, saved a string instead. Converting. 
if step_size == 'Not defined by user': From 1da42d8e52d3ee4e4eedb40daf5ae7a63e5c825d Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 13:22:29 -0400 Subject: [PATCH 10/11] Sneak in a fix for Antoine's docstring comment --- dwi_ml/data/processing/space/neighborhood.py | 2 +- dwi_ml/unit_tests/utils/expected_values.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dwi_ml/data/processing/space/neighborhood.py b/dwi_ml/data/processing/space/neighborhood.py index 3b187daa..e040a6ec 100644 --- a/dwi_ml/data/processing/space/neighborhood.py +++ b/dwi_ml/data/processing/space/neighborhood.py @@ -80,7 +80,7 @@ def get_neighborhood_vectors_axes(radius: int, resolution: float): neighborhood_vectors : tensor of shape (N, 3) A list of vectors with last dimension = 3 (x,y,z coordinate for each neighbour per respect to the origin). The current point (0,0,0) is - included. + NOT included. """ tmp_axes = np.identity(3) unit_axes = np.concatenate((tmp_axes, -tmp_axes)) diff --git a/dwi_ml/unit_tests/utils/expected_values.py b/dwi_ml/unit_tests/utils/expected_values.py index 90afec6f..98448847 100644 --- a/dwi_ml/unit_tests/utils/expected_values.py +++ b/dwi_ml/unit_tests/utils/expected_values.py @@ -21,7 +21,7 @@ └── hdf5_file.hdf5 """ -# Values corresponding to the testing data, as of Nov 22, 22. +# Values corresponding to the testing data, as of Oct 30, 23. # Created from Emmanuelle's preprocessed WithReversed data. TEST_EXPECTED_VOLUME_GROUPS = ['input', 'wm_mask'] From ab64af747a67e390ad8a18c4fa006ad6f0c2e2cc Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 30 Oct 2023 14:23:40 -0400 Subject: [PATCH 11/11] Md5 sum updated for new test data --- dwi_ml/unit_tests/utils/data_and_models_for_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index 610f5fac..0467bf60 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -32,7 +32,7 @@ def fetch_testing_data(): name_as_dict = { 'data_for_tests_dwi_ml.zip': ['1beRWAorhaINCncttgwqVAP2rNOfx842Q', - '8bdbf051877ec5c70aace21c9dab9bb7']} + 'da6c94fbef7ac13029acdb8b94325096']} fetch_data(name_as_dict) return testing_data_dir
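
With lr_decrease_params removed (PATCH 01/11), the trainer only supports an explicit per-epoch list of learning rates, resolved each epoch as learning_rates[min(current_epoch, len(learning_rates) - 1)]. The removed exponential schedule (init_lr * exp(-x / r), with r chosen so the curve reaches value L at epoch E) can still be reproduced by precomputing such a list. A minimal Python sketch, standalone and not part of the patch; the helper name exponential_lr_list is illustrative, not a dwi_ml API:

    import numpy as np

    def exponential_lr_list(init_lr, target_lr, target_epoch, n_epochs):
        # Same formula as the removed option: lr(x) = init_lr * exp(-x / r),
        # with r set so that lr(target_epoch) == target_lr.
        r = -target_epoch / np.log(target_lr / init_lr)
        return [float(init_lr * np.exp(-epoch / r)) for epoch in range(n_epochs)]

    # Example: start at 0.001, reach 0.0001 at epoch 50, precompute 100 epochs.
    lrs = exponential_lr_list(0.001, 0.0001, 50, 100)

    # Per-epoch lookup mirroring the kept trainer code: the last value is
    # reused for any epoch beyond the end of the list.
    epoch = 120
    current_lr = lrs[min(epoch, len(lrs) - 1)]

Such a list can then be passed as the trainer's learning_rates argument in place of a single float value.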