Skip to content

Commit

Permalink
feat: gnn inference and documentation (#119)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Apr 13, 2024
1 parent cfeb75c commit b9f1b9d
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 8 deletions.
1 change: 0 additions & 1 deletion src/deep_neurographs/machine_learning/graph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def upd_idxs(idxs, shift):
Updated index transform dictinoary.
"""
idxs["block_to_idxs"] = upd_set(idxs["block_to_idxs"], shift)
idxs["idx_to_edge"] = upd_dict(idxs["idx_to_edge"], shift)
return idxs

Expand Down
203 changes: 198 additions & 5 deletions src/deep_neurographs/machine_learning/graph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.nn.functional import sigmoid
from torch.utils.tensorboard import SummaryWriter
from deep_neurographs.machine_learning import ml_utils


LR = 1e-3
Expand All @@ -24,6 +25,10 @@


class GraphTrainer:
"""
Custom class that trains graph neural networks.
"""
def __init__(
self,
model,
Expand All @@ -32,6 +37,28 @@ def __init__(
n_epochs=N_EPOCHS,
weight_decay=WEIGHT_DECAY,
):
"""
Constructs a GraphTrainer object.
Parameters
----------
model : torch.nn.Module
Graph neural network.
criterion : torch.nn.Module._Loss
Loss function.
lr : float, optional
Learning rate. The default is the global variable LR.
n_epochs : int
Number of epochs. The default is the global variable N_EPOCHS.
weight_decay : float
Weight decay used in optimizer. The default is the global variable
WEIGHT_DECAY.
Returns
-------
None.
"""
self.model = model.to("cuda:0")
self.criterion = criterion
self.n_epochs = n_epochs
Expand All @@ -41,6 +68,22 @@ def __init__(
self.writer = SummaryWriter()

def run_on_graphs(self, graph_datasets):
"""
Trains a graph neural network in the case where "graph_datasets" is a
dictionary of datasets such that each corresponds to a distinct graph.
Parameters
----------
graph_datasets : dict
Dictionary where each key is a graph id and the value is the
corresponding graph dataset.
Returns
-------
model : torch.nn.Module
Graph neural network that has been fit onto "graph_datasets".
"""
# Initializations
best_score = -np.inf
best_ckpt = None
Expand Down Expand Up @@ -73,24 +116,114 @@ def run_on_graphs(self, graph_datasets):
best_ckpt = deepcopy(self.model.state_dict())
return self.model.load_state_dict(best_ckpt)

def run_on_graph(self):
"""
Trains a graph neural network in the case where "graph_dataset" is a
graph that may contain multiple connected components.
Parameters
----------
graph_dataset : dict
Dictionary where each key is a graph id and the value is the
corresponding graph dataset.
Returns
-------
None
"""
pass

def train(self, graph_data, epoch):
"""
Performs the forward pass and backpropagation to update the model's
weights.
Parameters
----------
graph_data : GraphDataset
Graph dataset that corresponds to a single connected component.
epoch : int
Current epoch.
Returns
-------
y : torch.Tensor
Ground truth.
hat_y : torch.Tensor
Prediction.
"""
y, hat_y = self.forward(graph_data)
self.backpropagate(y, hat_y, epoch)
return y, hat_y

def forward(self, graph_data):
"""
Runs "graph_data" through "self.model" to generate a prediction.
Parameters
----------
graph_data : GraphDataset
Graph dataset that corresponds to a single connected component.
Returns
-------
y : torch.Tensor
Ground truth.
hat_y : torch.Tensor
Prediction.
"""
self.optimizer.zero_grad()
x, y, edge_index = toGPU(graph_data)
hat_y = self.model(x, edge_index)
return y, truncate(hat_y, y)

def backpropagate(self, y, hat_y, epoch):
"""
Runs backpropagation to update the model's weights.
Parameters
----------
y : torch.Tensor
Ground truth.
hat_y : torch.Tensor
Prediction.
epoch : int
Current epoch.
Returns
-------
None
"""
loss = self.criterion(hat_y, y)
loss.backward()
self.optimizer.step()
self.writer.add_scalar("loss", loss, epoch)

def compute_metrics(self, y, hat_y, prefix, epoch):
"""
Computes and logs evaluation metrics for binary classification.
Parameters
----------
y : torch.Tensor
Ground truth.
hat_y : torch.Tensor
Prediction.
prefix : str
Prefix to be added to the metric names when logging.
epoch : int
Current epoch.
Returns
-------
f1 : float
F1 score.
"""
# Initializations
y = np.array(y, dtype=int).tolist()
hat_y = get_predictions(hat_y)
Expand All @@ -108,7 +241,7 @@ def compute_metrics(self, y, hat_y, prefix, epoch):
self.writer.add_scalar(prefix + '_precision:', precision, epoch)
self.writer.add_scalar(prefix + '_recall:', recall, epoch)
self.writer.add_scalar(prefix + '_f1:', f1, epoch)
return accuracy_dif
return f1


# -- utils --
Expand All @@ -132,17 +265,65 @@ def shuffler(my_list):


def train_test_split(graph_ids):
n_test_examples = 1 # int(len(graph_ids) * TEST_PERCENT)
"""
Split a list of graph IDs into training and testing sets.
Parameters
----------
graph_ids : list[str]
A list containing unique identifiers (IDs) for graphs.
Returns
-------
train_ids : list
A list containing IDs for the training set.
test_ids : list
A list containing IDs for the testing set.
"""
n_test_examples = int(len(graph_ids) * TEST_PERCENT)
test_ids = sample(graph_ids, n_test_examples)
train_ids = list(set(graph_ids) - set(test_ids))
return train_ids, test_ids


def toCPU(tensor):
"""
Moves "tensor" from GPU to CPU.
Parameters
----------
tensor : torch.Tensor
Dataset to be moved to GPU.
Returns
-------
numpy.ndarray
Array.
"""
return np.array(tensor.detach().cpu()).tolist()


def toGPU(graph_data):
"""
Moves "graph_data" from CPU to GPU.
Parameters
----------
graph_data : GraphDataset
Dataset to be moved to GPU.
Returns
-------
x : torch.Tensor
Matrix of node feature vectors.
y : torch.Tensor
Ground truth.
edge_idx : torch.Tensor
Tensor containing edges in graph.
"""
x = graph_data.x.to("cuda:0", dtype=torch.float32)
y = graph_data.y.to("cuda:0", dtype=torch.float32)
edge_index = graph_data.edge_index.to("cuda:0")
Expand Down Expand Up @@ -172,8 +353,20 @@ def truncate(hat_y, y):


def get_predictions(hat_y, threshold=0.5):
return (sigmoid(np.array(hat_y)) > threshold).tolist()
"""
Generate binary predictions based on the input probabilities.
Parameters
----------
hat_y : torch.Tensor
Predicted probabilities generated by "self.model".
threshold : float, optional
The threshold value for binary classification. The default is 0.5.
Returns
-------
list[int]
Binary predictions based on the given threshold.
def sigmoid(x):
return 1.0/(1.0 + np.exp(-x))
"""
return (ml_utils.sigmoid(np.array(hat_y)) > threshold).tolist()
25 changes: 23 additions & 2 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CHUNK_SHAPE = (256, 256, 256)


# -- Whole Brain Inference --
def run(
neurograph,
model_type,
Expand Down Expand Up @@ -171,6 +172,7 @@ def predict(
return accepts, graph


# -- Whole Brain Seed-Based Inference --
def build_from_soma(
neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1
):
Expand Down Expand Up @@ -250,11 +252,14 @@ def ingest_subgraph(neurograph_1, neurograph_2, node_subset):
return neurograph_2


# -- Inference --
def run_model(dataset, model, model_type):
data = dataset["dataset"]
if "Net" in model_type:
if "Graph" in model_type:
return run_graph_model(dataset, model)
elif "Net" in model_type:
model.eval()
hat_y = []
data = dataset["dataset"]
for batch in DataLoader(data, batch_size=32, shuffle=False):
# Run model
with torch.no_grad():
Expand All @@ -269,6 +274,17 @@ def run_model(dataset, model, model_type):
return np.array(hat_y)


def run_graph_model(graph_data, model):
# Run model
x, edge_index = toGPU(graph_data.data)
hat_y = model(x, edge_index)

# Reformat pred
idx = graph_data.n_proposals
hat_y = ml_utils.toCPU(hat_y[0:idx, 0])
return ml_utils.sigmoid(hat_y)


# Utils
def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
Expand All @@ -290,3 +306,8 @@ def get_runtime(current, total, chunk_size, t0, t1):
total_runtime = time() - t0 + eta
t, unit = utils.time_writer(total_runtime)
return f"{round(t, 4)} {unit}"

def toGPU(graph_data):
x = graph_data.x.to("cuda:0", dtype=torch.float32)
edge_index = graph_data.edge_index.to("cuda:0")
return x, edge_index
21 changes: 21 additions & 0 deletions src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,24 @@ def get_lengths(neurograph):
for edge in neurograph.proposals.keys():
lengths.append(neurograph.proposal_length(edge))
return lengths


def toCPU(tensor):
return np.array(tensor.detach().cpu())


def sigmoid(x):
"""
Sigmoid function.
Parameters
----------
x : numpy.ndarray
Input to sigmoid.
Return
------
Sigmoid applied to "x".
"""
return 1.0/(1.0 + np.exp(-x))

0 comments on commit b9f1b9d

Please sign in to comment.