Trainer

The Trainer component reads the outputs of the Split Generator (whose paths are specified in the frozen config), trains a GNN model on the training set, early-stops based on validation-set performance, and finally evaluates on the test set. The training logic is implemented with PyTorch Distributed Data Parallel (DDP), which enables distributed training across multiple GPUs on multiple worker nodes.

Input

  • job_name (AppliedTaskIdentifier): which uniquely identifies an end-to-end task.
  • task_config_uri (Uri): Path which points to a "template" GbmlConfig proto yaml file.
  • resource_config_uri (Uri): Path which points to a GiGLResourceConfig yaml file.

What does it do?

Model training has two main components: (i) the Trainer, which sets up the environment, and (ii) a user-defined instance of BaseTrainer, which contains the actual training loop for the given task. For example, for node anchor-based link prediction, we have NodeAnchorBasedLinkPredictionModelingTaskSpec. Model training involves the following steps:

  • The Trainer sets up the (optionally distributed) Torch training environment.

  • The Trainer reads GraphMetadata that was generated by the Data Preprocessor.

  • The Trainer initializes the BaseTrainer instance (instance specified at the trainerClsPath field in the trainerConfig section of the frozen GbmlConfig, and with arguments at trainerArgs) and initializes the GNN model.

  • We start model training as indicated by the BaseTrainer instance (a minimal sketch of such a loop is shown after this list). This may look something like:

    • We initialize training and validation dataloaders (see: NodeAnchorBasedLinkPredictionDatasetDataloaders in dataset_metadata_utils.py).
    • We follow a standard distributed training scheme: each worker loads a batch of data and performs the usual forward and backward passes for model training in a distributed way.
    • Every fixed number of training batches (val_every_num_batches), we evaluate the current model on the validation set with a fixed number of validation batches (num_val_batches).
    • We follow a standard early-stopping strategy on the validation performance (offline metrics), with a configurable patience parameter (early_stop_patience); see the EarlyStopper utility class in early_stop.py.
    • When early stopping ends the training process, we reload the model saved at the best validation point and evaluate (test) it with a fixed number of test batches (num_test_batches).
    • At the end, we return the model and its test performance (offline metrics) back to the Trainer.
  • The Trainer persists output metadata like model parameters and offline metrics (see Output).
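
The following is a minimal, generic sketch of what such a loop with periodic validation and early stopping can look like. It is plain PyTorch, not the GiGL implementation (which lives in the task spec and the EarlyStopper utility); compute_loss is a hypothetical task-specific method, and the parameter names mirror val_every_num_batches, num_val_batches, and early_stop_patience described in the Parameters section below.

import torch

def train_with_early_stopping(
    model, optimizer, train_loader, val_loader,
    val_every_num_batches=100, num_val_batches=20, early_stop_patience=3,
):
    """Generic training loop with periodic validation and early stopping."""
    best_val_loss = float("inf")
    patience_left = early_stop_patience
    best_state = None

    model.train()
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = model.compute_loss(batch)  # hypothetical task-specific loss
        loss.backward()
        optimizer.step()

        # Periodically evaluate on a fixed number of validation batches.
        if (batch_idx + 1) % val_every_num_batches == 0:
            model.eval()
            with torch.no_grad():
                val_losses = []
                for i, val_batch in enumerate(val_loader):
                    if i >= num_val_batches:
                        break
                    val_losses.append(model.compute_loss(val_batch).item())
            val_loss = sum(val_losses) / max(len(val_losses), 1)
            model.train()

            # Standard early-stopping bookkeeping with a patience counter.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                patience_left = early_stop_patience
            else:
                patience_left -= 1
                if patience_left == 0:
                    break  # early stop: reload best model, then run the test pass

    if best_state is not None:
        model.load_state_dict(best_state)
    return model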

How do I run it?

Import GiGL

from gigl.src.training.trainer import Trainer
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier

trainer = Trainer()

trainer.run(
    applied_task_identifier=AppliedTaskIdentifier("my_gigl_job_name"),
    task_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/task_config.yaml"),
    resource_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/resource_config.yaml")
)

Note: If you are training on Vertex AI and using a custom class, you will have to provide a docker image (either cuda_docker_uri for GPU training or cpu_docker_uri for CPU training).

Command Line

python -m \
    gigl.src.training.trainer \
    --job_name my_gigl_job_name \
    --task_config_uri "gs://my-temp-assets-bucket/task_config.yaml" \
    --resource_config_uri "gs://my-temp-assets-bucket/resource_config.yaml"

Output

After the training process finishes:

  • The Trainer saves the trained model’s state_dict at the specified location (the trainedModelUri field of sharedConfig.trainedModelMetadata).

  • The Trainer logs training metrics to the trainingLogsUri field of sharedConfig.trainedModelMetadata. To view the metrics locally, you can run: tensorboard --logdir gs://tensorboard_logs_uri_here
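
Once training finishes, the saved state_dict can be loaded back into a model with the same architecture for offline use. A rough sketch, assuming the file at trainedModelUri has been copied to a local path (e.g. via gsutil cp) and using a stand-in module in place of the real GNN class:

import torch
import torch.nn as nn

# Stand-in for the trained GNN; in practice this must be the same model class and
# hyperparameters used by your ModelingTaskSpec.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))

# Load the state_dict saved at trainedModelUri (copied locally first,
# e.g. `gsutil cp <trainedModelUri> ./model.pt`).
state_dict = torch.load("model.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()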

Custom Usage

The Trainer is designed to be task-agnostic, with the model details and training logic specified in the user-provided BaseTrainer instance. Modifying the BaseTrainer instance allows maximal flexibility in changing the model architecture and training parameters; a skeletal example is sketched below.
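
The exact BaseTrainer interface is defined in the GiGL source; the skeleton below is only an illustration of the separation of concerns between the Trainer and the task spec. init_model is mentioned in the Parameters section below, but its signature here, and the train/eval method names, are assumptions, not the real API.

import torch
import torch.nn as nn

class MyCustomModelingTaskSpec:  # in practice, subclass GiGL's BaseTrainer
    """Hypothetical task spec skeleton; method signatures are illustrative."""

    def __init__(self, **trainer_args):
        # trainer_args come from trainerConfig.trainerArgs in the frozen GbmlConfig.
        self.lr = float(trainer_args.get("optim_lr", 1e-3))
        self.num_layers = int(trainer_args.get("num_layers", 2))

    def init_model(self, in_dim: int, hid_dim: int, out_dim: int) -> nn.Module:
        # Build and return the GNN; the Trainer handles the distributed wrapping.
        ...

    def train(self, model: nn.Module, device: torch.device) -> None:
        # Training loop with periodic validation and early stopping (see sketch above).
        ...

    def eval(self, model: nn.Module, device: torch.device) -> dict:
        # Evaluate on the test split and return offline metrics.
        ...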

Other

Torch Profiler

You can profile trainer performance metrics, such as GPU/CPU utilization, by adding the following to task_config.yaml:

profilerConfig:
    should_enable_profiler: true
    profiler_log_dir: gs://path_to_my_bucket  # or a local dir
    profiler_args:
        wait: '0'
        with_stack: 'True'
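
The profiler_args presumably map onto PyTorch profiler options such as wait and with_stack. For orientation only (this is not the GiGL code path), a standalone torch.profiler setup with a TensorBoard trace handler looks roughly like this; the schedule values and the tiny model are illustrative:

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

model = torch.nn.Linear(16, 16)  # stand-in for the GNN
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=0, warmup=1, active=3),
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
    with_stack=True,
) as prof:
    for step in range(8):
        loss = model(torch.randn(4, 16)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prof.step()  # advance the profiling schedule each training step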

Monitoring and logging

Once the Trainer component starts, the training process can be monitored via the gcloud console under Vertex AI Custom Jobs (https://console.cloud.google.com/vertex-ai/training/custom-jobs?project=<project_name_here>). You can also view the job name, status, job spec, and more using gcloud ai custom-jobs list --project <project_name_here>.

On the Vertex AI UI, you can see information such as the machine/accelerator configuration, CPU utilization, GPU utilization, network data, etc. Here you will also find the "View logs" tab, which opens the Stackdriver logs for your job; these capture everything your modeling task spec logs as training progresses, in real time.

If you would like to view the logs locally, you can also use: gcloud ai custom-jobs stream-logs <custom job ID> --project=<project_name_here> --region=<region here>.

Parameters

Following are all of the Trainer parameters that can be configured within the config yaml, along with short explanations.

  • Training environment parameters (number of workers for different dataloaders)

    • train_main_sample_num_workers
    • train_random_sample_num_workers
    • val_main_sample_num_workers
    • val_random_sample_num_workers
    • test_main_sample_num_workers
    • test_random_sample_num_workers

    Note that training involves multiple dataloaders simultaneously. Take care to specify these parameters in a way which avoids overburdening your machine. It is recommended to specify (train_main_sample_num_workers + train_random_sample_num_workers + val_main_sample_num_workers + val_random_sample_num_workers < num_cpus), and (test_main_sample_num_workers + test_random_sample_num_workers < num_cpus) to avoid training stalling due to contention.

  • Training parameters:

    • margin: margin for the margin loss
    • loss_function: choice of training loss function; options are margin and softmax (an illustrative sketch of both losses appears after this parameter list)
    • softmax_temperature: temperature parameter in the softmax loss
    • optim_lr: learning rate of the optimizer
    • optim_weight_decay: weight decay of the optimizer
    • val_every_num_batches: how often to run validation, measured in training batches
    • num_val_batches: number of validation batches
    • num_test_batches: number of test batches
    • early_stop_patience: patience for early stopping
    • main_sample_batch_size: training batch size
    • random_negative_sample_batch_size: random negative sample batch size for training
    • random_negative_sample_batch_size_for_evaluation: random negative sample batch size for evaluation
  • Model parameters:

    • should_l2_normalize_output: whether to apply L2 normalization to the output embeddings
    • num_layers: number of layers in the GNN (this should be the same as numHops under subgraphSamplerConfig)
    • in_dim: dimension of the input node feature
    • hid_dim: dimension of the hidden layers
    • out_dim: dimension of the output embeddings
  • Modifying the GNN model:

    • Currently the GNN models are defined here and initialized in the init_model function in ModelingTaskSpec. When trying different GNN models, it is recommended to add the new GNN architectures in the same file and declare them as is currently done; this cannot currently be done from the default GbmlConfig yaml alone. (A generic model skeleton is also sketched after this list.)
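
As a point of reference for the loss_function options above, here is a generic sketch of both losses for link prediction-style scoring; this is an illustration, not GiGL's exact formulation. A margin loss pushes each positive-pair score above its negative-pair scores by at least margin, while the softmax loss applies temperature-scaled cross-entropy with the positive pair as the correct class.

import torch
import torch.nn.functional as F

def margin_loss(pos_scores, neg_scores, margin=0.5):
    # pos_scores: [B], neg_scores: [B, N]; hinge on (margin - positive + negative).
    return F.relu(margin - pos_scores.unsqueeze(1) + neg_scores).mean()

def softmax_loss(pos_scores, neg_scores, temperature=0.1):
    # Treat each positive vs. its N negatives as an (N + 1)-way classification
    # problem where the correct class (index 0) is the positive pair.
    logits = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1) / temperature
    labels = torch.zeros(logits.size(0), dtype=torch.long)
    return F.cross_entropy(logits, labels)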
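
Similarly, to make the model parameters concrete, below is a generic multi-layer GNN skeleton in plain PyTorch (with a toy mean-aggregation layer, not the GiGL model definitions) showing how in_dim, hid_dim, out_dim, num_layers, and should_l2_normalize_output typically fit together.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGNNLayer(nn.Module):
    """Toy mean-aggregation message passing over a dense adjacency matrix."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Average neighbor features, then apply a linear transform.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        return self.linear(adj @ x / deg)

class SimpleGNN(nn.Module):
    def __init__(self, in_dim: int, hid_dim: int, out_dim: int, num_layers: int,
                 should_l2_normalize_output: bool = True):
        super().__init__()
        dims = [in_dim] + [hid_dim] * (num_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            SimpleGNNLayer(dims[i], dims[i + 1]) for i in range(num_layers)
        )
        self.should_l2_normalize_output = should_l2_normalize_output

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x, adj)
            if i < len(self.layers) - 1:
                x = F.relu(x)
        # Optionally L2-normalize the output embeddings.
        return F.normalize(x, p=2, dim=-1) if self.should_l2_normalize_output else x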

Background for distributed training

The Trainer currently uses PyTorch distributed training abstractions to enable multi-node and multi-GPU training. Some useful terminology for understanding these abstractions is listed below, followed by a minimal setup sketch.

  • WORLD: Group of processes/workers that are used for distributed training.

  • WORLD_SIZE: The number of processes/workers in the distributed training WORLD.

  • RANK: The unique id (usually index) of the process/worker in the distributed training WORLD.

  • Data loader worker: A worker used specifically for loading data; if a dataloader worker shares a thread/process with a worker in the distributed training WORLD, training execution may block, resulting in slowdowns.

  • Distributed Data Parallel: PyTorch's form of data parallelism across different processes (which may even live on different machines), used to speed up training on large datasets.

  • TORCH.DISTRIBUTED package: A torch package containing tools for distributed communication and training.

    • Defines backends for distributed communication like gloo and nccl - as an ML practitioner you do not need to worry about how these work internally, but it is important to know which devices and collective functions each supports.
    • Contains "collective functions" like torch.distributed.broadcast and torch.distributed.all_gather, which allow communication of tensors across the WORLD.
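
To make this terminology concrete, below is a minimal, generic DDP setup sketch in plain PyTorch (not the GiGL code path). It assumes the standard environment variables (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) have been set, e.g. by torchrun or by the training launcher.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # nccl is the usual backend for GPU training, gloo for CPU-only training.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()               # this worker's unique id in the WORLD
    world_size = dist.get_world_size()   # total number of workers in the WORLD

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    model = torch.nn.Linear(16, 16).to(device)  # stand-in for the GNN
    # DDP all-reduces gradients across the WORLD during backward().
    ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

    loss = ddp_model(torch.randn(4, 16, device=device)).sum()
    loss.backward()  # gradients are synchronized across all workers here

    # Example collective function: gather every worker's rank across the WORLD.
    ranks = [torch.zeros(1, device=device) for _ in range(world_size)]
    dist.all_gather(ranks, torch.tensor([float(rank)], device=device))

    dist.destroy_process_group()

if __name__ == "__main__":
    main()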