From b081e60bd0c32ed2cdf713c2c26c480895b6ccb7 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:04:03 -0700 Subject: [PATCH] Refactor gnn training (#248) * minor upds * refactor: training pipeline * feat: find gcs image path * feat: feature generation in trainer * feat: validation sets in training * bug: hgraph forward passes with missing edge types * refactor: hgnn trainer --------- Co-authored-by: anna-grim --- .../heterograph_feature_generation.py | 1 + .../machine_learning/heterograph_trainer.py | 149 +++--------------- 2 files changed, 23 insertions(+), 127 deletions(-) diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py index 557a9ef..32fc446 100644 --- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py +++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py @@ -21,6 +21,7 @@ NODE_PROFILE_DEPTH = 16 WINDOW = [5, 5, 5] + def generate_hgnn_features( neurograph, img, proposals_dict, radius, downsample_factor ): diff --git a/src/deep_neurographs/machine_learning/heterograph_trainer.py b/src/deep_neurographs/machine_learning/heterograph_trainer.py index f0a2230..8dc5866 100644 --- a/src/deep_neurographs/machine_learning/heterograph_trainer.py +++ b/src/deep_neurographs/machine_learning/heterograph_trainer.py @@ -10,7 +10,7 @@ """ from copy import deepcopy -from random import sample, shuffle +from random import shuffle import numpy as np import torch @@ -22,25 +22,16 @@ ) from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from torch_geometric.utils import subgraph from deep_neurographs.utils import gnn_util, ml_util from deep_neurographs.utils.gnn_util import toCPU -# Training -FEATURE_DTYPE = torch.float32 -MODEL_TYPE = "HeteroGNN" LR = 1e-3 N_EPOCHS = 200 SCHEDULER_GAMMA = 0.5 SCHEDULER_STEP_SIZE = 1000 -TEST_PERCENT = 0.15 WEIGHT_DECAY = 1e-3 -# Augmentation -MAX_PROPOSAL_DROPOUT = 0.1 -SCALING_FACTOR = 0.05 - class HeteroGraphTrainer: """ @@ -54,8 +45,6 @@ def __init__( criterion, lr=LR, n_epochs=N_EPOCHS, - max_proposal_dropout=MAX_PROPOSAL_DROPOUT, - scaling_factor=SCALING_FACTOR, weight_decay=WEIGHT_DECAY, ): """ @@ -90,10 +79,6 @@ def __init__( self.init_scheduler() self.writer = SummaryWriter() - # Augmentation - self.scaling_factor = scaling_factor - self.max_proposal_dropout = max_proposal_dropout - def init_scheduler(self): self.scheduler = StepLR( self.optimizer, @@ -101,16 +86,14 @@ def init_scheduler(self): gamma=SCHEDULER_GAMMA, ) - def run_on_graphs(self, datasets, augment=False): + def run(self, train_dataset_list, validation_dataset_list): """ Trains a graph neural network in the case where "datasets" is a dictionary of datasets such that each corresponds to a distinct graph. Parameters ---------- - datasets : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. + ... Returns ------- @@ -118,32 +101,36 @@ def run_on_graphs(self, datasets, augment=False): Graph neural network that has been fit onto "datasets". """ - # Initializations best_score = -np.inf best_ckpt = None - - # Main - train_ids, test_ids = train_test_split(list(datasets.keys())) for epoch in range(self.n_epochs): # Train y, hat_y = [], [] self.model.train() - for graph_id in train_ids: - print(graph_id) - y_i, hat_y_i = self.train( - datasets[graph_id], epoch, augment=augment - ) + for graph_dataset in train_dataset_list: + # Forward pass + hat_y_i, y_i = self.predict(graph_dataset.data) + loss = self.criterion(hat_y_i, y_i) + self.writer.add_scalar("loss", loss, epoch) + + # Backward pass + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Store predictions y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) + self.compute_metrics(y, hat_y, "train", epoch) self.scheduler.step() - # Test + # Validate if epoch % 10 == 0: y, hat_y = [], [] self.model.eval() - for graph_id in test_ids: - y_i, hat_y_i = self.forward(datasets[graph_id].data) + for graph_dataset in validation_dataset_list: + hat_y_i, y_i = self.predict(graph_dataset.data) y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) test_score = self.compute_metrics(y, hat_y, "val", epoch) @@ -155,52 +142,7 @@ def run_on_graphs(self, datasets, augment=False): self.model.load_state_dict(best_ckpt) return self.model - def run_on_graph(self): - """ - Trains a graph neural network in the case where "dataset" is a - graph that may contain multiple connected components. - - Parameters - ---------- - dataset : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - None - - """ - pass - - def train(self, dataset, epoch, augment=False): - """ - Performs the forward pass and backpropagation to update the model's - weights. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - epoch : int - Current epoch. - augment : bool, optional - Indication of whether to augment data. Default is False. - - Returns - ------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - - """ - # if augment: - y, hat_y = self.forward(dataset.data) - self.backpropagate(y, hat_y, epoch) - return y, hat_y - - def forward(self, data): + def predict(self, data): """ Runs "data" through "self.model" to generate a prediction. @@ -219,37 +161,13 @@ def forward(self, data): """ # Run model x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs( - data, MODEL_TYPE + data, "HeteroGNN" ) - self.optimizer.zero_grad() hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict) # Output y = data["proposal"]["y"] - return y, truncate(hat_y, y) - - def backpropagate(self, y, hat_y, epoch): - """ - Runs backpropagation to update the model's weights. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - epoch : int - Current epoch. - - Returns - ------- - None - - """ - loss = self.criterion(hat_y, y) - loss.backward() - self.optimizer.step() - self.writer.add_scalar("loss", loss, epoch) + return truncate(hat_y, y), y def compute_metrics(self, y, hat_y, prefix, epoch): """ @@ -312,29 +230,6 @@ def shuffler(my_list): return my_list -def train_test_split(graph_ids): - """ - Split a list of graph IDs into training and testing sets. - - Parameters - ---------- - graph_ids : list[str] - A list containing unique identifiers (IDs) for graphs. - - Returns - ------- - list - A list containing IDs for the training set. - list - A list containing IDs for the testing set. - - """ - n_test_examples = int(len(graph_ids) * TEST_PERCENT) - test_ids = ["block_000", "block_002"] # sample(graph_ids, n_test_examples) - train_ids = list(set(graph_ids) - set(test_ids)) - return train_ids, test_ids - - def truncate(hat_y, y): """ Truncates "hat_y" so that this tensor has the same shape as "y". Note this