From b9f1b9d7bb8bcf3fe65a5296749120af133fa9e9 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:29:22 -0700 Subject: [PATCH] feat: gnn inference and documentation (#119) Co-authored-by: anna-grim --- .../machine_learning/graph_datasets.py | 1 - .../machine_learning/graph_trainer.py | 203 +++++++++++++++++- .../machine_learning/inference.py | 25 ++- .../machine_learning/ml_utils.py | 21 ++ 4 files changed, 242 insertions(+), 8 deletions(-) diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py index bf22ef7..2cbc4a3 100644 --- a/src/deep_neurographs/machine_learning/graph_datasets.py +++ b/src/deep_neurographs/machine_learning/graph_datasets.py @@ -145,7 +145,6 @@ def upd_idxs(idxs, shift): Updated index transform dictinoary. """ - idxs["block_to_idxs"] = upd_set(idxs["block_to_idxs"], shift) idxs["idx_to_edge"] = upd_dict(idxs["idx_to_edge"], shift) return idxs diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py index a7f864d..0eba638 100644 --- a/src/deep_neurographs/machine_learning/graph_trainer.py +++ b/src/deep_neurographs/machine_learning/graph_trainer.py @@ -15,6 +15,7 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score from torch.nn.functional import sigmoid from torch.utils.tensorboard import SummaryWriter +from deep_neurographs.machine_learning import ml_utils LR = 1e-3 @@ -24,6 +25,10 @@ class GraphTrainer: + """ + Custom class that trains graph neural networks. + + """ def __init__( self, model, @@ -32,6 +37,28 @@ def __init__( n_epochs=N_EPOCHS, weight_decay=WEIGHT_DECAY, ): + """ + Constructs a GraphTrainer object. + + Parameters + ---------- + model : torch.nn.Module + Graph neural network. + criterion : torch.nn.Module._Loss + Loss function. + lr : float, optional + Learning rate. 
The default is the global variable LR. + n_epochs : int + Number of epochs. The default is the global variable N_EPOCHS. + weight_decay : float + Weight decay used in optimizer. The default is the global variable + WEIGHT_DECAY. + + Returns + ------- + None. + + """ self.model = model.to("cuda:0") self.criterion = criterion self.n_epochs = n_epochs @@ -41,6 +68,22 @@ def __init__( self.writer = SummaryWriter() def run_on_graphs(self, graph_datasets): + """ + Trains a graph neural network in the case where "graph_datasets" is a + dictionary of datasets such that each corresponds to a distinct graph. + + Parameters + ---------- + graph_datasets : dict + Dictionary where each key is a graph id and the value is the + corresponding graph dataset. + + Returns + ------- + model : torch.nn.Module + Graph neural network that has been fit onto "graph_datasets". + + """ # Initializations best_score = -np.inf best_ckpt = None @@ -73,24 +116,114 @@ def run_on_graphs(self, graph_datasets): best_ckpt = deepcopy(self.model.state_dict()) return self.model.load_state_dict(best_ckpt) + def run_on_graph(self): + """ + Trains a graph neural network in the case where "graph_dataset" is a + graph that may contain multiple connected components. + + Parameters + ---------- + graph_dataset : dict + Dictionary where each key is a graph id and the value is the + corresponding graph dataset. + + Returns + ------- + None + + """ + pass + def train(self, graph_data, epoch): + """ + Performs the forward pass and backpropagation to update the model's + weights. + + Parameters + ---------- + graph_data : GraphDataset + Graph dataset that corresponds to a single connected component. + epoch : int + Current epoch. + + Returns + ------- + y : torch.Tensor + Ground truth. + hat_y : torch.Tensor + Prediction. 
+ + """ y, hat_y = self.forward(graph_data) self.backpropagate(y, hat_y, epoch) return y, hat_y def forward(self, graph_data): + """ + Runs "graph_data" through "self.model" to generate a prediction. + + Parameters + ---------- + graph_data : GraphDataset + Graph dataset that corresponds to a single connected component. + + Returns + ------- + y : torch.Tensor + Ground truth. + hat_y : torch.Tensor + Prediction. + + """ self.optimizer.zero_grad() x, y, edge_index = toGPU(graph_data) hat_y = self.model(x, edge_index) return y, truncate(hat_y, y) def backpropagate(self, y, hat_y, epoch): + """ + Runs backpropagation to update the model's weights. + + Parameters + ---------- + y : torch.Tensor + Ground truth. + hat_y : torch.Tensor + Prediction. + epoch : int + Current epoch. + + Returns + ------- + None + + """ loss = self.criterion(hat_y, y) loss.backward() self.optimizer.step() self.writer.add_scalar("loss", loss, epoch) def compute_metrics(self, y, hat_y, prefix, epoch): + """ + Computes and logs evaluation metrics for binary classification. + + Parameters + ---------- + y : torch.Tensor + Ground truth. + hat_y : torch.Tensor + Prediction. + prefix : str + Prefix to be added to the metric names when logging. + epoch : int + Current epoch. + + Returns + ------- + f1 : float + F1 score. + + """ # Initializations y = np.array(y, dtype=int).tolist() hat_y = get_predictions(hat_y) @@ -108,7 +241,7 @@ def compute_metrics(self, y, hat_y, prefix, epoch): self.writer.add_scalar(prefix + '_precision:', precision, epoch) self.writer.add_scalar(prefix + '_recall:', recall, epoch) self.writer.add_scalar(prefix + '_f1:', f1, epoch) - return accuracy_dif + return f1 # -- utils -- @@ -132,17 +265,65 @@ def shuffler(my_list): def train_test_split(graph_ids): - n_test_examples = 1 # int(len(graph_ids) * TEST_PERCENT) + """ + Split a list of graph IDs into training and testing sets. 
+ + Parameters + ---------- + graph_ids : list[str] + A list containing unique identifiers (IDs) for graphs. + + Returns + ------- + train_ids : list + A list containing IDs for the training set. + test_ids : list + A list containing IDs for the testing set. + + """ + n_test_examples = int(len(graph_ids) * TEST_PERCENT) test_ids = sample(graph_ids, n_test_examples) train_ids = list(set(graph_ids) - set(test_ids)) return train_ids, test_ids def toCPU(tensor): + """ + Moves "tensor" from GPU to CPU and converts it to a list. + + Parameters + ---------- + tensor : torch.Tensor + Tensor to be moved to CPU. + + Returns + ------- + list + Values of "tensor" as a (possibly nested) list. + + """ return np.array(tensor.detach().cpu()).tolist() def toGPU(graph_data): + """ + Moves "graph_data" from CPU to GPU. + + Parameters + ---------- + graph_data : GraphDataset + Dataset to be moved to GPU. + + Returns + ------- + x : torch.Tensor + Matrix of node feature vectors. + y : torch.Tensor + Ground truth. + edge_index : torch.Tensor + Tensor containing edges in graph. + + """ x = graph_data.x.to("cuda:0", dtype=torch.float32) y = graph_data.y.to("cuda:0", dtype=torch.float32) edge_index = graph_data.edge_index.to("cuda:0") @@ -172,8 +353,20 @@ def truncate(hat_y, y): def get_predictions(hat_y, threshold=0.5): - return (sigmoid(np.array(hat_y)) > threshold).tolist() + """ + Generate binary predictions by thresholding the sigmoid of "hat_y". + + Parameters + ---------- + hat_y : torch.Tensor + Raw (pre-sigmoid) scores generated by "self.model". + threshold : float, optional + The threshold value for binary classification. The default is 0.5. + Returns + ------- + list[bool] + Binary predictions based on the given threshold. 
-def sigmoid(x): - return 1.0/(1.0 + np.exp(-x)) + """ + return (ml_utils.sigmoid(np.array(hat_y)) > threshold).tolist() diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 40fc886..1be9089 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -28,6 +28,7 @@ CHUNK_SHAPE = (256, 256, 256) +# -- Whole Brain Inference -- def run( neurograph, model_type, @@ -171,6 +172,7 @@ def predict( return accepts, graph +# -- Whole Brain Seed-Based Inference -- def build_from_soma( neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1 ): @@ -250,11 +252,14 @@ def ingest_subgraph(neurograph_1, neurograph_2, node_subset): return neurograph_2 +# -- Inference -- def run_model(dataset, model, model_type): - data = dataset["dataset"] - if "Net" in model_type: + if "Graph" in model_type: + return run_graph_model(dataset, model) + elif "Net" in model_type: model.eval() hat_y = [] + data = dataset["dataset"] for batch in DataLoader(data, batch_size=32, shuffle=False): # Run model with torch.no_grad(): @@ -269,6 +274,17 @@ def run_model(dataset, model, model_type): return np.array(hat_y) +def run_graph_model(graph_data, model): + # Run model + x, edge_index = toGPU(graph_data.data) + hat_y = model(x, edge_index) + + # Reformat pred + idx = graph_data.n_proposals + hat_y = ml_utils.toCPU(hat_y[0:idx, 0]) + return ml_utils.sigmoid(hat_y) + + # Utils def report_progress(current, total, chunk_size, cnt, t0, t1): eta = get_eta(current, total, chunk_size, t1) @@ -290,3 +306,8 @@ def get_runtime(current, total, chunk_size, t0, t1): total_runtime = time() - t0 + eta t, unit = utils.time_writer(total_runtime) return f"{round(t, 4)} {unit}" + +def toGPU(graph_data): + x = graph_data.x.to("cuda:0", dtype=torch.float32) + edge_index = graph_data.edge_index.to("cuda:0") + return x, edge_index diff --git 
a/src/deep_neurographs/machine_learning/ml_utils.py b/src/deep_neurographs/machine_learning/ml_utils.py index 6cb09e5..89539bd 100644 --- a/src/deep_neurographs/machine_learning/ml_utils.py +++ b/src/deep_neurographs/machine_learning/ml_utils.py @@ -175,3 +175,24 @@ def get_lengths(neurograph): for edge in neurograph.proposals.keys(): lengths.append(neurograph.proposal_length(edge)) return lengths + + +def toCPU(tensor): + return np.array(tensor.detach().cpu()) + + +def sigmoid(x): + """ + Sigmoid function. + + Parameters + ---------- + x : numpy.ndarray + Input to sigmoid. + + Returns + ------- + Sigmoid applied to "x". + + """ + return 1.0/(1.0 + np.exp(-x))