The Trainer component reads the outputs of the Split Generator (whose paths are specified in the frozen config), trains a GNN model on the training set, early-stops based on performance on the validation set, and finally evaluates on the test set. The training logic is implemented with PyTorch Distributed Data Parallel (DDP) training, which enables distributed training on multiple GPUs across multiple worker nodes.
- job_name (AppliedTaskIdentifier): uniquely identifies an end-to-end task.
- task_config_uri (Uri): path pointing to a "template" GbmlConfig proto yaml file.
- resource_config_uri (Uri): path pointing to a GiGLResourceConfig yaml file.
Model training consists of two main components: (i) the Trainer, which sets up the environment, and (ii) a user-defined instance of BaseTrainer, which contains the actual training loop for the given task. For example, for node anchor-based link prediction, we have NodeAnchorBasedLinkPredictionModelingTaskSpec. Model training involves the following steps:
- The Trainer sets up the (optionally distributed) Torch training environment.
- The Trainer reads the GraphMetadata that was generated by the Data Preprocessor.
- The Trainer initializes the BaseTrainer instance (the class specified in the trainerClsPath field of the trainerConfig section of the frozen GbmlConfig, with the arguments given in trainerArgs) and initializes the GNN model.
- Model training starts as defined by the BaseTrainer instance. This may look something like:
  - We initialize training and validation dataloaders (see NodeAnchorBasedLinkPredictionDatasetDataloaders in dataset_metadata_utils.py).
  - We follow a standard distributed training scheme: each worker loads a batch of data and performs the normal forward and backward passes for model training in a distributed way.
  - Every fixed number of training batches (val_every_num_batches), we evaluate the current model on the validation set with a fixed number of validation batches (num_val_batches).
  - We follow a standard early-stopping strategy based on validation performance on offline metrics, with a configurable patience parameter (early_stop_patience); see the EarlyStopper utility class in early_stop.py.
  - When early stopping is triggered to end the training process, we reload the model saved at the best validation batch and evaluate (test) it on a fixed number of test batches (num_test_batches).
  - At the end, we return the model and its test performance (offline metrics) to the Trainer.
- The Trainer persists output metadata like model parameters and offline metrics (see Output).
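The BaseTrainer loop described in the steps above can be sketched as follows. This is a minimal, single-process, framework-free sketch under stated assumptions: train_step, evaluate, snapshot, and restore are hypothetical callables supplied by the caller, a higher validation score is assumed to be better, and EarlyStopperSketch is a hypothetical stand-in, not GiGL's EarlyStopper from early_stop.py. In a real DDP run, validation metrics would also need to be aggregated across ranks before deciding to stop.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class EarlyStopperSketch:
    """Hypothetical early stopper: stop after `patience` validation rounds
    without improvement. Assumes a higher validation score is better."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best_score: Optional[float] = None
        self.rounds_without_improvement = 0

    def step(self, score: float) -> Tuple[bool, bool]:
        """Record one validation score; return (improved, should_stop)."""
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.rounds_without_improvement = 0
            return True, False
        self.rounds_without_improvement += 1
        return False, self.rounds_without_improvement >= self.patience


def train_with_early_stopping(
    train_step: Callable[[int], None],
    evaluate: Callable[[str, int], float],
    snapshot: Callable[[], Dict[str, Any]],
    restore: Callable[[Dict[str, Any]], None],
    max_batches: int,
    val_every_num_batches: int,
    num_val_batches: int,
    num_test_batches: int,
    early_stop_patience: int,
) -> float:
    """Train, validate periodically, early-stop, reload the best state, then test."""
    stopper = EarlyStopperSketch(early_stop_patience)
    best_state: Optional[Dict[str, Any]] = None
    for batch_idx in range(1, max_batches + 1):
        # Caller-defined step, e.g. forward, backward, and optimizer update on one batch.
        train_step(batch_idx)
        if batch_idx % val_every_num_batches != 0:
            continue
        val_score = evaluate("val", num_val_batches)
        improved, should_stop = stopper.step(val_score)
        if improved:
            best_state = snapshot()  # keep a copy of the best model seen so far
        if should_stop:
            break
    if best_state is not None:
        restore(best_state)  # reload the model from the best validation batch
    return evaluate("test", num_test_batches)
```

For example, with a toy "model" {"w": 0.0} whose train_step adds 1 to w and whose score is -(w - 5) ** 2, setting val_every_num_batches=1 and early_stop_patience=2 stops after 7 batches and restores w = 5.0 before testing.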
Import GiGL
from gigl.src.training.trainer import Trainer
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier
trainer = Trainer()
trainer.run(
applied_task_identifier=AppliedTaskIdentifier("my_gigl_job_name"),
task_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/task_config.yaml"),
resource_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/resource_config.yaml")
)
Note: If you are training on Vertex AI and using a custom class, you will have to provide a Docker image (either cuda_docker_uri for GPU training or cpu_docker_uri for CPU training).
Command Line
python -m \
gigl.src.training.trainer \
--job_name my_gigl_job_name \
--task_config_uri "gs://my-temp-assets-bucket/task_config.yaml" \
--resource_config_uri "gs://my-temp-assets-bucket/resource_config.yaml"
After the training process finishes:
- The Trainer saves the trained model's state_dict at the specified location (the trainedModelUri field of sharedConfig.trainedModelMetadata).
- The Trainer logs training metrics to the location given in the trainingLogsUri field of sharedConfig.trainedModelMetadata. To view the metrics locally, you can run: tensorboard --logdir gs://tensorboard_logs_uri_here
The Trainer is designed to be task-agnostic, with the detailed model and training logic specified in the user-provided BaseTrainer instance. Modifying the BaseTrainer instance allows maximal flexibility in changing the model architecture and training parameters.
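Because the Trainer only knows the user's class through the trainerClsPath string, it has to turn that dotted path into a class and instantiate it with the configured arguments. A common Python pattern for this is importlib plus getattr; the sketch below assumes that pattern and uses hypothetical helper names (load_class, build_trainer). GiGL's own loader may differ.

```python
import importlib
from typing import Any, Dict


def load_class(class_path: str) -> type:
    """Resolve a dotted path such as 'package.module.ClassName' to the class object."""
    module_path, _, class_name = class_path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected a dotted class path, got {class_path!r}")
    module = importlib.import_module(module_path)  # raises ImportError if missing
    cls = getattr(module, class_name)  # raises AttributeError if missing
    if not isinstance(cls, type):
        raise TypeError(f"{class_path!r} does not refer to a class")
    return cls


def build_trainer(class_path: str, trainer_args: Dict[str, Any]) -> Any:
    """Instantiate the user-provided trainer class with its configured keyword arguments."""
    cls = load_class(class_path)
    return cls(**trainer_args)
```

Keeping the class path in config means swapping in a different task-specific trainer requires no change to the orchestration code, which is what lets the Trainer stay task-agnostic.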
You can profile Trainer performance metrics, such as GPU/CPU utilization, by adding the following to task_config.yaml:
profilerConfig:
  should_enable_profiler: true
  profiler_log_dir: gs://path_to_my_bucket  # or a local dir
  profiler_args:
    wait: '0'
    with_stack: 'True'
Once the Trainer component starts, the training process can be monitored in the Google Cloud console under Vertex AI Custom Jobs (https://console.cloud.google.com/vertex-ai/training/custom-jobs?project=<project_name_here>). You can also view the job name, status, job spec, and more using: gcloud ai custom-jobs list --project <project_name_here>
The Vertex AI UI shows information such as machine/accelerator details, CPU utilization, GPU utilization, and network data. There you will also find the "View logs" tab, which opens Stackdriver (Cloud Logging) for your job; it logs everything from your modeling task spec in real time as training progresses.
If you would like to view the logs locally, you can also use: gcloud ai custom-jobs stream-logs <custom job ID> --project=<project_name_here> --region=<region here>
The following are all of the Trainer parameters that can be configured in the config yaml, along with short explanations.
Training environment parameters (number of workers for the different dataloaders):
- train_main_sample_num_workers
- train_random_sample_num_workers
- val_main_sample_num_workers
- val_random_sample_num_workers
- test_main_sample_num_workers
- test_random_sample_num_workers
Note that training involves multiple dataloaders simultaneously. Take care to set these parameters in a way that avoids overburdening your machine. To avoid training stalls due to resource contention, it is recommended to ensure that
train_main_sample_num_workers + train_random_sample_num_workers + val_main_sample_num_workers + val_random_sample_num_workers < num_cpus
and
test_main_sample_num_workers + test_random_sample_num_workers < num_cpus
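One way to apply the recommendation above before launching a job is to sum the configured worker counts and compare them with the machine's CPU count. The helper below is hypothetical; it assumes the worker counts are available as a flat mapping (values may be ints or numeric strings) and defaults num_cpus to os.cpu_count(), which reports the machine's CPUs and may overstate what a container is actually allowed to use.

```python
import os
from typing import List, Mapping, Optional, Union

TRAIN_VAL_WORKER_KEYS = (
    "train_main_sample_num_workers",
    "train_random_sample_num_workers",
    "val_main_sample_num_workers",
    "val_random_sample_num_workers",
)
TEST_WORKER_KEYS = ("test_main_sample_num_workers", "test_random_sample_num_workers")


def check_dataloader_workers(
    worker_config: Mapping[str, Union[int, str]], num_cpus: Optional[int] = None
) -> List[str]:
    """Return human-readable warnings for each worker budget that is not below num_cpus."""
    if num_cpus is None:
        num_cpus = os.cpu_count() or 1  # os.cpu_count() may return None
    warnings: List[str] = []
    for label, keys in (("train+val", TRAIN_VAL_WORKER_KEYS), ("test", TEST_WORKER_KEYS)):
        total = sum(int(worker_config.get(key, 0)) for key in keys)
        if total >= num_cpus:
            warnings.append(
                f"{label} dataloader workers total {total}; expected fewer than num_cpus={num_cpus}"
            )
    return warnings
```

For instance, 2 + 2 + 1 + 1 train/val workers pass on an 8-CPU machine but produce one warning on a 6-CPU machine, since the budget must be strictly below num_cpus.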
Training parameters:
- margin: margin for the margin loss
- loss_function: choice of training loss function; options are margin and softmax
- softmax_temperature: temperature parameter in the softmax loss
- optim_lr: learning rate of the optimizer
- optim_weight_decay: weight decay of the optimizer
- val_every_num_batches: validation frequency, in training batches
- num_val_batches: number of validation batches
- num_test_batches: number of test batches
- early_stop_patience: patience for early stopping
- main_sample_batch_size: training batch size
- random_negative_sample_batch_size: random negative sample batch size for training
- random_negative_sample_batch_size_for_evaluation: random negative sample batch size for evaluation
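The exact formulas behind the margin and softmax options of loss_function are not spelled out here. As a hedged sketch, assuming the common link-prediction formulations (a hinge loss that pushes each negative score at least margin below the positive score, and a temperature-scaled softmax cross-entropy over one positive and several negatives), the two options behave like this. GiGL's actual implementation may differ in the scoring function, reduction, or how negatives are batched.

```python
import math
from typing import Sequence


def margin_loss(pos_score: float, neg_scores: Sequence[float], margin: float) -> float:
    """Hinge loss averaged over negatives: mean(max(0, margin - pos + neg))."""
    if not neg_scores:
        raise ValueError("need at least one negative score")
    return sum(max(0.0, margin - pos_score + neg) for neg in neg_scores) / len(neg_scores)


def softmax_loss(pos_score: float, neg_scores: Sequence[float], softmax_temperature: float) -> float:
    """Cross-entropy of picking the positive among (positive + negatives),
    with every score divided by the temperature first."""
    logits = [pos_score / softmax_temperature] + [s / softmax_temperature for s in neg_scores]
    m = max(logits)  # subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[0]
```

Lowering softmax_temperature sharpens the distribution: with a positive score of 1.0 and one negative at 0.0, the loss drops from about 0.313 at temperature 1.0 to about 4.5e-5 at temperature 0.1.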
Model parameters:
- should_l2_normalize_output: whether to apply L2 normalization to the output embeddings
- num_layers: number of layers in the GNN (this should be the same as numHops under subgraphSamplerConfig)
- in_dim: dimension of the input node features
- hid_dim: dimension of the hidden layers
- out_dim: dimension of the output embeddings
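To make the relationship between these model parameters concrete, here is a hypothetical container (not GiGL's API). It assumes the common layer layout in_dim -> hid_dim -> ... -> hid_dim -> out_dim, checks num_layers against the sampler's numHops (each message-passing layer aggregates one more hop of neighbors), and shows what should_l2_normalize_output does to an output embedding.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class ModelParamsSketch:
    """Hypothetical container for the model parameters listed above."""

    in_dim: int
    hid_dim: int
    out_dim: int
    num_layers: int
    should_l2_normalize_output: bool = True

    def layer_dims(self) -> List[Tuple[int, int]]:
        """(input, output) sizes per GNN layer, assuming in -> hid -> ... -> hid -> out."""
        if self.num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        dims = [self.in_dim] + [self.hid_dim] * (self.num_layers - 1) + [self.out_dim]
        return list(zip(dims[:-1], dims[1:]))

    def check_against_sampler(self, num_hops: int) -> None:
        """Each layer aggregates one hop, so the layer count should equal numHops."""
        if self.num_layers != num_hops:
            raise ValueError(
                f"num_layers={self.num_layers} but subgraphSamplerConfig numHops={num_hops}"
            )

    def postprocess(self, embedding: Sequence[float]) -> List[float]:
        """Optionally scale the output embedding to unit L2 norm."""
        if not self.should_l2_normalize_output:
            return list(embedding)
        norm = math.sqrt(sum(x * x for x in embedding))
        return [x / norm for x in embedding] if norm > 0 else list(embedding)
```

With in_dim=16, hid_dim=32, out_dim=8, and num_layers=3, the sketch yields layers (16, 32), (32, 32), (32, 8), and a sampler configured with a different numHops is rejected.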
Modifying the GNN model:
- Currently the GNN models are defined here and initialized in the init_model function in ModelingTaskSpec. When trying different GNN models, it is recommended to add the new GNN architectures to the same file and declare them the same way as the existing ones. This cannot currently be done from the default GbmlConfig yaml.
The Trainer currently uses PyTorch distributed training abstractions to enable multi-node and multi-GPU training. Some useful terminology for these abstractions is listed below.
- WORLD: the group of processes/workers used for distributed training.
- WORLD_SIZE: the number of processes/workers in the distributed training WORLD.
- RANK: the unique ID (usually an index) of a process/worker in the distributed training WORLD.
- Data loader worker: a worker used specifically for loading data. If a dataloader worker runs on the same thread/process as a worker in the distributed training WORLD, it can block training execution, resulting in slowdowns.
- Distributed Data Parallel: PyTorch's version of data parallelism across different processes (which may even run on different machines), used to speed up training on large datasets.
- torch.distributed package: a PyTorch package containing tools for distributed communication and training.
  - Defines backends for distributed communication, such as gloo and nccl. As an ML practitioner you need not worry about how these work, but it is important to know which devices and collective functions they support.
  - Contains "collective functions" such as torch.distributed.broadcast, torch.distributed.all_gather, etc., which allow communication of tensors across the WORLD.
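To make RANK, WORLD_SIZE, and collective communication concrete, here is a framework-free simulation of one data-parallel step for a one-parameter linear model. The helper names are hypothetical: shard_indices mirrors the default non-shuffled sharding of torch.utils.data.DistributedSampler (pad by wrapping around, then stride by WORLD_SIZE), and all_reduce_mean stands in for the gradient averaging DDP performs with an all-reduce during backward().

```python
import math
from typing import List, Sequence


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Indices read by process `rank`, mimicking DistributedSampler(shuffle=False, drop_last=False):
    pad by wrapping around so every rank gets equally many samples, then take every
    world_size-th index starting at rank."""
    if dataset_len <= 0 or not 0 <= rank < world_size:
        raise ValueError("need dataset_len > 0 and 0 <= rank < world_size")
    total_size = math.ceil(dataset_len / world_size) * world_size
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:world_size]


def local_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) over one rank's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(per_rank_values: Sequence[float]) -> float:
    """What every rank holds after an all-reduce SUM followed by division by WORLD_SIZE."""
    return sum(per_rank_values) / len(per_rank_values)


def ddp_step_grad(w: float, xs: Sequence[float], ys: Sequence[float], world_size: int) -> float:
    """Each RANK computes a gradient on its own shard; the all-reduce then averages them."""
    grads = []
    for rank in range(world_size):
        idx = shard_indices(len(xs), rank, world_size)
        grads.append(local_grad(w, [xs[i] for i in idx], [ys[i] for i in idx]))
    return all_reduce_mean(grads)
```

With 8 samples on 4 ranks, each rank gets 2 samples and the averaged gradient equals the full-batch gradient. With uneven sizes (e.g., 10 samples on 4 ranks), the wrap-around padding duplicates a few samples, so the result differs slightly.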