
Pipeline extras and example configuration files #55

Merged
merged 10 commits into main from pipeline_extras on May 12, 2024

Conversation

@otavioon (Collaborator) commented May 12, 2024

Pipelines and Examples

In this PR we add a simple yet powerful generic pipeline named SimpleLightningPipeline. This class is a subclass of Pipeline, designed to work with PyTorch Lightning models (train/test/predict and evaluate) and with the jsonargparse CLI module.

The SimpleLightningPipeline receives a model and a trainer as init_args. The entry point (the run method) receives the data and the task (fit, test, predict, or evaluate) and runs the corresponding method of the trainer.

  • The fit task trains the model (calls Trainer.fit)
  • The test task tests the model (calls Trainer.test)
  • The predict task runs inference on the model (calls Trainer.predict)
  • The evaluate task evaluates the model using regression or classification metrics, passed as the init parameters regression_metrics and classification_metrics, which use the torchmetrics Metric API. This method should be further customized for complex evaluation tasks that do not fit the torchmetrics API, such as performing overlapping of samples, plotting images, creating confusion matrices, etc.
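The task dispatch and metric-based evaluation described above can be sketched as follows. This is a minimal, framework-free illustration: the class, the `mse` helper, and the `_evaluate` method are illustrative stand-ins, not the actual Minerva implementation.

```python
from typing import Callable, Dict, List, Optional


class LightningPipelineSketch:
    """Illustrative stand-in for SimpleLightningPipeline's task dispatch."""

    def __init__(self, trainer, model,
                 regression_metrics: Optional[Dict[str, Callable]] = None):
        self.trainer = trainer
        self.model = model
        self.regression_metrics = regression_metrics or {}

    def run(self, data, task: str):
        # Dispatch the task to the corresponding trainer method,
        # mirroring the fit/test/predict/evaluate description above.
        if task == "fit":
            return self.trainer.fit(self.model, data)
        if task == "test":
            return self.trainer.test(self.model, data)
        if task == "predict":
            return self.trainer.predict(self.model, data)
        if task == "evaluate":
            return self._evaluate(data)
        raise ValueError(f"Unknown task: {task}")

    def _evaluate(self, data) -> Dict[str, float]:
        # Apply each configured metric to a (predictions, targets) pair.
        preds, targets = data
        return {name: metric(preds, targets)
                for name, metric in self.regression_metrics.items()}


def mse(preds: List[float], targets: List[float]) -> float:
    # Plain-Python mean squared error, standing in for a torchmetrics Metric.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
```

In the real pipeline the metrics follow the torchmetrics Metric API; the dictionary-of-callables pattern here only illustrates how named metrics map to results.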

The SimpleLightningPipeline can be called from standard Python code or from the command line using jsonargparse.

SimpleLightningPipeline also ships with a CLI in its main code that exposes the class init arguments and the run method arguments. The CLI is generated using jsonargparse and is very flexible: it can be used to run the pipeline from the command line, and also to generate the command-line help and documentation.

All the arguments of the SimpleLightningPipeline are exposed in the CLI, and the user can pass them as command-line arguments or through a JSON/YAML file. In the configs folder, we added some configurations that are useful for the pipeline, such as the model, the trainer, the data, etc.

Example of training a model to compute seismic attributes

Suppose we have the original.zarr and envelope.zarr files, which correspond to the original and envelope versions of the F3 seismic data, respectively. We want to train a model to compute the seismic attributes of the envelope data. We can use the SimpleLightningPipeline to train the model and evaluate it.
We can create a config file config.yaml with the following content:

model:
    # The model class path and init arguments
    class_path: minerva.models.nets.unet.UNet
    init_args:
        n_channels: 1

trainer:
    # The trainer class path and init arguments
    class_path: lightning.Trainer
    init_args:
        # Train for 1 epoch, using 1 GPU (devices/accelerator)
        max_epochs: 1
        accelerator: gpu
        devices: 1

# ----------- SimpleLightningPipeline RUN ARGS (entry-point) -----------
run:
    # Set the task to fit
    task: fit
    # The data configuration
    data: 
        # The data module class path and init arguments
        class_path: minerva.data.data_modules.f3_reconstruction.F3ReconstructionDataModule
        init_args:
          # Location of input and target data
          input_path: /workspaces/seismic/data/original.zarr
          target_path: /workspaces/seismic/data/envelope.zarr
          # Desired data shape
          data_shape: [1, 951, 462]
          # Default batch size
          batch_size: 16
          # The input and target transforms to apply
          input_transform:
              # We create a pipeline of transforms, with a single transform to cast the 
              # data to float32
              class_path: minerva.transforms.transform.TransformPipeline
              init_args:
                transforms:
                - class_path: minerva.transforms.transform.CastTo
                  init_args:
                    dtype: float32
          target_transform:
              class_path: minerva.transforms.transform.TransformPipeline
              init_args:
                transforms:
                - class_path: minerva.transforms.transform.CastTo
                  init_args:
                    dtype: float32

We can run the pipeline using the following command:

python minerva/pipelines/simple_lightning_pipeline.py --config config.yaml

Or, using the existing configuration files, which are organized in a modular format:

# Train
python minerva/pipelines/simple_lightning_pipeline.py --config configs/pipelines/lightning_pipeline/unet_f3_reconstruct_train.yaml 

# Evaluate
python minerva/pipelines/simple_lightning_pipeline.py --config configs/pipelines/lightning_pipeline/unet_f3_reconstruct_evaluate.yaml 

NOTE: Paths in the config file are absolute, so you may need to change them to match your environment. This issue will be addressed in the future.

NOTE: Evaluation requires a checkpoint file to be passed as an argument. This checkpoint is generated during training and saved in the logs folder as a .ckpt file containing the model weights and other information. It is passed as an argument to the SimpleLightningPipeline.evaluate method to load the model. You may need to change the checkpoint path in the config file to match your environment. This issue (automatic checkpoint discovery) will be addressed in the future.
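As an illustration, an evaluation config might look like the sketch below. The key names (notably the checkpoint key) and paths are assumptions for illustration only; consult unet_f3_reconstruct_evaluate.yaml for the actual keys used by the pipeline.

```yaml
# Illustrative sketch only; see unet_f3_reconstruct_evaluate.yaml for
# the real key names and values.
run:
    task: evaluate
    # Hypothetical key: path to a checkpoint produced during training
    ckpt_path: /workspaces/seismic/logs/<RUN>/checkpoints/<CHECKPOINT>.ckpt
    data:
        class_path: minerva.data.data_modules.f3_reconstruction.F3ReconstructionDataModule
        init_args:
            input_path: /workspaces/seismic/data/original.zarr
            target_path: /workspaces/seismic/data/envelope.zarr
```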

Configuration Files

The configuration files are flexible and can be used to run the pipeline in different ways. We have structured them in a modular way, using the following directory structure:

  • configs/callbacks/: Contains default configurations for callbacks. This is used when instantiating the Trainer.
  • configs/data/: Contains default configurations for data modules. This is used when instantiating the DataModule, for each dataset/task.
  • configs/logger/: Contains default configurations for loggers. This is used when instantiating the Trainer.
  • configs/models/: Contains default configurations for models. This is used when instantiating the model.
  • configs/pipelines/: Contains configurations for the pipeline. This is used when instantiating the pipeline. Inside this folder, we have the following subfolders:
    • configs/pipelines/lightning_pipeline/: Contains configurations for the SimpleLightningPipeline. This is used when instantiating the pipeline.
    • configs/pipelines/other_pipeline/: Contains configurations for other pipelines. This is used when instantiating the pipeline.
  • configs/trainer/: Contains default configurations for trainers. This is used when instantiating the Trainer.

The configurations for SimpleLightningPipeline are in the configs/pipelines/lightning_pipeline/ folder. These configurations are used when instantiating the SimpleLightningPipeline and usually contain all the CLI arguments that are passed to the SimpleLightningPipeline class. Thus, you can simply run:

python minerva/pipelines/simple_lightning_pipeline.py --config configs/pipelines/lightning_pipeline/<YOUR PIPELINE CONFIG FILE>.yaml 
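A top-level pipeline config can assemble the per-component files from the directories above. The sketch below is illustrative only: whether a nested key may point at another config file depends on how the jsonargparse parser is set up (see the jsonargparse documentation), and the file paths are assumptions, not files that necessarily exist in the repository.

```yaml
# Illustrative sketch of a modular pipeline config assembling
# per-component config files; paths are hypothetical.
model: configs/models/unet.yaml
trainer: configs/trainer/default_gpu.yaml
run:
    task: fit
    data: configs/data/f3_reconstruction.yaml
```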

Others

  • Pipelines log their outputs so that their progress can be tracked and their runs reproduced.
  • Configuration files are modular. Inside each configuration file there are references to other configuration files. This allows the user to create a configuration file for each part of the pipeline (model, data, trainer, etc) and then combine them to create a full pipeline configuration file.
  • Usually, we will extend the SimpleLightningPipeline.evaluate method when the torchmetrics API is not enough to evaluate the model. This method should be further customized for complex evaluation tasks that do not fit the torchmetrics API, such as performing overlapping of samples, plotting images, creating confusion matrices, etc.
  • The class_path and init_args are used to instantiate the classes. The class_path is the path to the class, and the init_args are the arguments that are passed to the class constructor. The init_args can contain references to other configuration files. For more information about the configuration files, see the jsonargparse documentation.
  • Pipeline variables should be typed using the typing module. This lets the user know the expected type of each variable and allows the IDE to provide code completion and type checking. It also enables seamless integration with jsonargparse, which uses the type annotations to infer variable types. In fact, the jsonargparse CLI will fail if the types of the variables are not correctly defined or assigned.
  • If you only need predictions, consider using the predict task, which runs inference over the data. Note that the predict and evaluate tasks require the data module to implement the predict_dataloader method, which returns the split of the dataset used for predictions (usually the test split).
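To illustrate the typing point above, here is a minimal sketch of a fully annotated pipeline class. The class and its parameters are illustrative, not the actual Minerva API; the point is that every constructor parameter carries a type annotation that a jsonargparse-style CLI can introspect.

```python
from typing import Dict, List, Optional, get_type_hints


class TypedPipelineSketch:
    """Illustrative pipeline class: every parameter is annotated so a
    jsonargparse-style CLI can infer its type and validate config values."""

    def __init__(
        self,
        log_dir: str,
        batch_size: int = 16,
        data_shape: Optional[List[int]] = None,
        metrics: Optional[Dict[str, str]] = None,
    ):
        self.log_dir = log_dir
        self.batch_size = batch_size
        self.data_shape = data_shape or [1, 951, 462]
        self.metrics = metrics or {}


# With jsonargparse installed, one would expose this as a CLI roughly as:
#   from jsonargparse import CLI
#   CLI(TypedPipelineSketch)
# which fails at parse time if any annotation is missing or inconsistent.

# The annotations are introspectable, which is what jsonargparse relies on:
hints = get_type_hints(TypedPipelineSketch.__init__)
```

The same introspection that `get_type_hints` performs here is what lets jsonargparse map YAML values onto typed constructor arguments.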

otavioon added 9 commits May 9, 2024 19:27
Signed-off-by: Otavio Napoli <otavio.napoli@gmail.com>
…utils

Signed-off-by: Otavio Napoli <otavio.napoli@gmail.com>
Signed-off-by: Otavio Napoli <otavio.napoli@gmail.com>
…culcate seismic attributes (train and eval)

Signed-off-by: Otavio Napoli <otavio.napoli@gmail.com>
@otavioon otavioon requested a review from GabrielBG0 May 12, 2024 04:55
@otavioon otavioon self-assigned this May 12, 2024
Signed-off-by: Otavio Napoli <otavio.napoli@gmail.com>
@otavioon (Collaborator, Author)

Indeed, the elements related to the experiments should not be part of the library. In the last commit, I moved the configuration files to the minerva-seismic repository. There, we can keep the configuration files used to run experiments, the results and outputs (if needed), and other customization operations specialized for seismic data.

@otavioon (Collaborator, Author)

Note: I moved all the configuration files to the minerva-seismic repository.

@GabrielBG0 (Collaborator) left a comment

LGTM

@GabrielBG0 GabrielBG0 merged commit f66618b into main May 12, 2024
1 check passed
@GabrielBG0 GabrielBG0 deleted the pipeline_extras branch May 12, 2024 20:50
@GabrielBG0 GabrielBG0 restored the pipeline_extras branch May 12, 2024 21:09
@otavioon otavioon deleted the pipeline_extras branch May 12, 2024 21:45