diff --git a/source/0_general_summary.rst b/source/0_general_summary.rst index e51018f5..55fd6e45 100644 --- a/source/0_general_summary.rst +++ b/source/0_general_summary.rst @@ -20,9 +20,9 @@ Our repository includes: .. toctree:: :maxdepth: 3 - :caption: How to use DWI_ML + :caption: Detailed table of content: 1_A_model 2_A_creating_the_hdf5.rst - training - tracking \ No newline at end of file + 3_A_training + 4_tracking \ No newline at end of file diff --git a/source/2_A_creating_the_hdf5.rst b/source/2_A_creating_the_hdf5.rst index d44547c5..7400d171 100644 --- a/source/2_A_creating_the_hdf5.rst +++ b/source/2_A_creating_the_hdf5.rst @@ -144,6 +144,13 @@ Exemple of use: (See also please_copy_and_adapt/ALL_STEPS.sh) $dwi_ml_folder $hdf5_folder $config_file \ $training_subjs $validation_subjs $testing_subjs +.. toctree:: + :maxdepth: 1 + :caption: Detailed explanations for developers: + + 2_C_advanced_hdf5_organization + + P.S How to get data? ******************** diff --git a/source/devel/hdf5_organization.rst b/source/2_C_advanced_hdf5_organization.rst similarity index 100% rename from source/devel/hdf5_organization.rst rename to source/2_C_advanced_hdf5_organization.rst diff --git a/source/3_A_training.rst b/source/3_A_training.rst new file mode 100644 index 00000000..188734cf --- /dev/null +++ b/source/3_A_training.rst @@ -0,0 +1,51 @@ +3. Training your model +====================== + +Even tough training depends on your own model, we have prepared Trainers that can probably be used in any case. + +3.1. Our trainers +----------------- + +- They have a ``train_and_validate`` method that can be used to iterate on epochs (until a maximum number of iteration is reached, or a maximum number of bad epochs based on some loss). +- They save a checkpoint folder after each epoch, containing all information to resume the training any time. +- When a minimum loss value is reached, the model's parameters and states are save in a best_model folder. +- They save a good quantity of logs, both as numpy arrays (.npy logs) and online using Comet.ml. +- They know how to deal with the ``BatchSampler`` (which samples a list of streamlines to get for each batch) and with the ``BatchLoader`` (which gets data and performs data augmentation operations, if any). +- They prepare torch's optimizer (ex, Adam, SGD, RAdam), define the learning rate, etc. + +3.2. Our Batch samplers and loaders +----------------------------------- + +.. toctree:: + :maxdepth: 2 + + 3_B_MultisubjectDataset + 3_C_BatchSampler + 3_D_BatchLoader + + +3.3. Putting it all together +---------------------------- + +This class's main method is *train_and_validate()*: + +- Creates torch DataLoaders from the data_loaders. Collate_fn will be the sampler.load_batch() method, and the dataset will be sampler.source_data. + +- Trains each epoch by using compute_batch_loss, which should be implemented in each project's child class, on each batch. Saves the loss evolution and gradient norm in a log file. + +- Validates each epoch (also by using compute_batch_loss on each batch, but skipping the backpropagation step). Saves the loss evolution in a log file. + +After each epoch, a checkpoint is saved with current parameters. Training can be continued from a checkpoint using the script resume_training_from_checkpoint.py. + +3.4. Visualizing logs +--------------------- + +You can run "visualize_logs.py your_experiment" to see the evolution of the losses and gradient norm. + +You can also use COMET to save results (code to be improved). + +3.5. Trainer with generation +---------------------------- + +toDO + diff --git a/source/3_B_MultisubjectDataset.rst b/source/3_B_MultisubjectDataset.rst new file mode 100644 index 00000000..89ecea10 --- /dev/null +++ b/source/3_B_MultisubjectDataset.rst @@ -0,0 +1,11 @@ +MultisubjectDataset +=================== + + + + +.. toctree:: + :maxdepth: 1 + :caption: Detailed explanations for developers: + + 3_E_advanced_data_containers diff --git a/source/3_C_BatchSampler.rst b/source/3_C_BatchSampler.rst new file mode 100644 index 00000000..1a30a953 --- /dev/null +++ b/source/3_C_BatchSampler.rst @@ -0,0 +1,26 @@ + +Batch sampler +============= + +These classes defines how to sample the streamlines available in the +MultiSubjectData. + +**AbstractBatchSampler:** + +- Defines the __iter__ method: + + - Finds a list of streamlines ids and associated subj that you can later load in your favorite way. + +- Define the load_batch method: + + - Loads the streamlines associated to sampled ids. Can resample them. + + - Performs data augmentation (on-the-fly to avoid having to multiply data on disk) (ex: splitting, reversing, adding noise). + +Child class : **BatchStreamlinesSamplerOneInput:** + +- Redefines the load_batch method: + + - Now also loads the input data under each point of the streamline (and possibly its neighborhood), for one input volume. + +You are encouraged to contribute to dwi_ml by adding any child class here. diff --git a/source/3_D_BatchLoader.rst b/source/3_D_BatchLoader.rst new file mode 100644 index 00000000..a0bc1a7b --- /dev/null +++ b/source/3_D_BatchLoader.rst @@ -0,0 +1,2 @@ +Batch loader +============ diff --git a/source/data_containers.rst b/source/3_E_advanced_data_containers.rst similarity index 100% rename from source/data_containers.rst rename to source/3_E_advanced_data_containers.rst diff --git a/source/tracking.rst b/source/4_tracking.rst similarity index 100% rename from source/tracking.rst rename to source/4_tracking.rst diff --git a/source/adapting_dwiml.rst b/source/adapting_dwiml.rst deleted file mode 100644 index fa68637b..00000000 --- a/source/adapting_dwiml.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. _ref_processing: - -Creating your own project based on DWI_ML -========================================= - -1. Create your own repository for your project. - -2. Copy our scripts from ``please_copy_and_adapt``. Adapt them based on your needs. Instructions are in each file. Generally, you will need to implement child classes of the abstract classes we have prepared in dwi_ml. - diff --git a/source/index.rst b/source/index.rst index 360d20ac..30e076cd 100644 --- a/source/index.rst +++ b/source/index.rst @@ -17,21 +17,8 @@ This website is a guide to the github repository from the SCIL-VITAL organisatio 0_general_summary 1_A_model 2_A_creating_the_hdf5 - training - tracking - -.. toctree:: - :maxdepth: 1 - :caption: Detailed explanations for developers: - - devel/hdf5_organization - -.. toctree:: - :maxdepth: 1 - :caption: Creating your own project: - - adapting_dwiml - data_containers + 3_A_training + 4_tracking .. toctree:: :maxdepth: 1 diff --git a/source/requirements_doc.txt b/source/requirements_doc.txt deleted file mode 100644 index 483a4e96..00000000 --- a/source/requirements_doc.txt +++ /dev/null @@ -1 +0,0 @@ -sphinx_rtd_theme diff --git a/source/training.rst b/source/training.rst deleted file mode 100644 index 3e19003b..00000000 --- a/source/training.rst +++ /dev/null @@ -1,57 +0,0 @@ -6. Training your model -====================== - -Even tough training depends on your own model, most of the necessary code has been prepared here to deal with the data in the hdf5 file and create a batch sampler that can get streamlines and their associated inputs. All you need to do now is implement a model and its forward method. - -The data from the hdf5 file created before will be loaded through the MultisubjectDataset. For more information on this, read page :ref:ref_data_containers. - -In the please_copy_and_adapt folder, adapt the train_model.py script. Choose or implement a child version of the classes described below to fit your needs. - -Batch sampler -------------- - -These classes defines how to sample the streamlines available in the -MultiSubjectData. - -**AbstractBatchSampler:** - -- Defines the __iter__ method: - - - Finds a list of streamlines ids and associated subj that you can later load in your favorite way. - -- Define the load_batch method: - - - Loads the streamlines associated to sampled ids. Can resample them. - - - Performs data augmentation (on-the-fly to avoid having to multiply data on disk) (ex: splitting, reversing, adding noise). - -Child class : **BatchStreamlinesSamplerOneInput:** - -- Redefines the load_batch method: - - - Now also loads the input data under each point of the streamline (and possibly its neighborhood), for one input volume. - -You are encouraged to contribute to dwi_ml by adding any child class here. - -Trainer -------- - -**DWIMLAbstractTrainer:** - -This class's main method is *train_and_validate()*: - -- Creates DataLoaders from the data_loaders. Collate_fn will be the sampler.load_batch() method, and the dataset will be sampler.source_data. - -- Trains each epoch by using compute_batch_loss, which should be implemented in each project's child class, on each batch. Saves the loss evolution and gradient norm in a log file. - -- Validates each epoch (also by using compute_batch_loss on each batch, but skipping the backpropagation step). Saves the loss evolution in a log file. - -After each epoch, a checkpoint is saved with current parameters. Training can be continued from a checkpoint using the script resume_training_from_checkpoint.py. - -Visualizing training --------------------- - -You can run "visualize_logs.py your_experiment" to see the evolution of the losses and gradient norm. - -You can also use COMET to save results (code to be improved). -