Skip to content

Commit

Permalink
Refactor gnn training (#248)
Browse files Browse the repository at this point in the history
* minor upds

* refactor: training pipeline

* feat: find gcs image path

* feat: feature generation in trainer

* feat: validation sets in training

* bug: hgraph forward passes with missing edge types

* refactor: hgnn trainer

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Sep 26, 2024
1 parent 94a52fc commit b081e60
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
NODE_PROFILE_DEPTH = 16
WINDOW = [5, 5, 5]


def generate_hgnn_features(
neurograph, img, proposals_dict, radius, downsample_factor
):
Expand Down
149 changes: 22 additions & 127 deletions src/deep_neurographs/machine_learning/heterograph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from copy import deepcopy
from random import sample, shuffle
from random import shuffle

import numpy as np
import torch
Expand All @@ -22,25 +22,16 @@
)
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.utils import subgraph

from deep_neurographs.utils import gnn_util, ml_util
from deep_neurographs.utils.gnn_util import toCPU

# Training
FEATURE_DTYPE = torch.float32
MODEL_TYPE = "HeteroGNN"
LR = 1e-3
N_EPOCHS = 200
SCHEDULER_GAMMA = 0.5
SCHEDULER_STEP_SIZE = 1000
TEST_PERCENT = 0.15
WEIGHT_DECAY = 1e-3

# Augmentation
MAX_PROPOSAL_DROPOUT = 0.1
SCALING_FACTOR = 0.05


class HeteroGraphTrainer:
"""
Expand All @@ -54,8 +45,6 @@ def __init__(
criterion,
lr=LR,
n_epochs=N_EPOCHS,
max_proposal_dropout=MAX_PROPOSAL_DROPOUT,
scaling_factor=SCALING_FACTOR,
weight_decay=WEIGHT_DECAY,
):
"""
Expand Down Expand Up @@ -90,60 +79,58 @@ def __init__(
self.init_scheduler()
self.writer = SummaryWriter()

# Augmentation
self.scaling_factor = scaling_factor
self.max_proposal_dropout = max_proposal_dropout

def init_scheduler(self):
    """
    Initializes a step-decay learning rate scheduler bound to this
    trainer's optimizer and stores it on "self.scheduler".

    Returns
    -------
    None
    """
    decay_kwargs = {
        "step_size": SCHEDULER_STEP_SIZE,
        "gamma": SCHEDULER_GAMMA,
    }
    self.scheduler = StepLR(self.optimizer, **decay_kwargs)

def run_on_graphs(self, datasets, augment=False):
def run(self, train_dataset_list, validation_dataset_list):
"""
Trains a graph neural network in the case where "datasets" is a
dictionary of datasets such that each corresponds to a distinct graph.
Parameters
----------
datasets : dict
Dictionary where each key is a graph id and the value is the
corresponding graph dataset.
...
Returns
-------
torch.nn.Module
Graph neural network that has been fit onto "datasets".
"""
# Initializations
best_score = -np.inf
best_ckpt = None

# Main
train_ids, test_ids = train_test_split(list(datasets.keys()))
for epoch in range(self.n_epochs):
# Train
y, hat_y = [], []
self.model.train()
for graph_id in train_ids:
print(graph_id)
y_i, hat_y_i = self.train(
datasets[graph_id], epoch, augment=augment
)
for graph_dataset in train_dataset_list:
# Forward pass
hat_y_i, y_i = self.predict(graph_dataset.data)
loss = self.criterion(hat_y_i, y_i)
self.writer.add_scalar("loss", loss, epoch)

# Backward pass
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

# Store predictions
y.extend(toCPU(y_i))
hat_y.extend(toCPU(hat_y_i))

self.compute_metrics(y, hat_y, "train", epoch)
self.scheduler.step()

# Test
# Validate
if epoch % 10 == 0:
y, hat_y = [], []
self.model.eval()
for graph_id in test_ids:
y_i, hat_y_i = self.forward(datasets[graph_id].data)
for graph_dataset in validation_dataset_list:
hat_y_i, y_i = self.predict(graph_dataset.data)
y.extend(toCPU(y_i))
hat_y.extend(toCPU(hat_y_i))
test_score = self.compute_metrics(y, hat_y, "val", epoch)
Expand All @@ -155,52 +142,7 @@ def run_on_graphs(self, datasets, augment=False):
self.model.load_state_dict(best_ckpt)
return self.model

def run_on_graph(self):
    """
    Trains a graph neural network on a single graph that may contain
    multiple connected components.

    NOTE(review): not implemented — the body is a stub. The previous
    docstring described a dictionary-of-datasets parameter, which matches
    "run_on_graphs" rather than this zero-argument method; confirm the
    intended input signature before implementing.

    Returns
    -------
    None
    """
    pass

def train(self, dataset, epoch, augment=False):
    """
    Performs the forward pass and backpropagation to update the model's
    weights on a single graph dataset.

    Parameters
    ----------
    dataset : GraphDataset
        Graph dataset that corresponds to a single connected component.
    epoch : int
        Current epoch, used for logging the loss.
    augment : bool, optional
        Indication of whether to augment data. Currently unused — data
        augmentation is not implemented. Default is False.

    Returns
    -------
    y : torch.Tensor
        Ground truth.
    hat_y : torch.Tensor
        Prediction.
    """
    # NOTE(review): "augment" is accepted but ignored (a dangling
    # "# if augment:" stub was removed here) — confirm whether
    # augmentation should be wired in or the parameter retired.
    y, hat_y = self.forward(dataset.data)
    self.backpropagate(y, hat_y, epoch)
    return y, hat_y

def forward(self, data):
def predict(self, data):
"""
Runs "data" through "self.model" to generate a prediction.
Expand All @@ -219,37 +161,13 @@ def forward(self, data):
"""
# Run model
x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs(
data, MODEL_TYPE
data, "HeteroGNN"
)
self.optimizer.zero_grad()
hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict)

# Output
y = data["proposal"]["y"]
return y, truncate(hat_y, y)

def backpropagate(self, y, hat_y, epoch):
    """
    Performs a single optimization step: computes the loss between the
    prediction and ground truth, logs it, then updates the model weights.

    Parameters
    ----------
    y : torch.Tensor
        Ground truth.
    hat_y : torch.Tensor
        Prediction.
    epoch : int
        Current epoch, used as the logging step.

    Returns
    -------
    None
    """
    loss = self.criterion(hat_y, y)
    # Record the scalar loss for TensorBoard before stepping the optimizer.
    self.writer.add_scalar("loss", loss, epoch)
    loss.backward()
    self.optimizer.step()
return truncate(hat_y, y), y

def compute_metrics(self, y, hat_y, prefix, epoch):
"""
Expand Down Expand Up @@ -312,29 +230,6 @@ def shuffler(my_list):
return my_list


def train_test_split(graph_ids):
    """
    Splits a list of graph IDs into training and testing sets.

    Parameters
    ----------
    graph_ids : list[str]
        A list containing unique identifiers (IDs) for graphs.

    Returns
    -------
    list
        A list containing IDs for the training set.
    list
        A list containing IDs for the testing set.
    """
    # Local import keeps this fix self-contained; "sample" is no longer in
    # the module-level imports.
    from random import sample

    # The test split was previously hard-coded to ["block_000", "block_002"]
    # (a leftover debug override), which ignored the computed split size and
    # could return IDs absent from "graph_ids". Sample the test set instead.
    n_test_examples = int(len(graph_ids) * TEST_PERCENT)
    test_ids = sample(graph_ids, n_test_examples)
    train_ids = list(set(graph_ids) - set(test_ids))
    return train_ids, test_ids


def truncate(hat_y, y):
"""
Truncates "hat_y" so that this tensor has the same shape as "y". Note this
Expand Down

0 comments on commit b081e60

Please sign in to comment.