Skip to content

Refactor gnn training #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
@@ -97,8 +97,12 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-4
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
validation_split: float = 0.15
weight_decay: float = 1e-3


class Config:
4 changes: 2 additions & 2 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
@@ -295,8 +295,8 @@ def reformat(arr):

def init_idxs(idxs):
"""
Adds dictionary item called "edge_to_index" which maps an edge in a
neurograph to an that represents the edge's position in the feature
Adds dictionary item called "edge_to_index" which maps a branch/proposal
in a neurograph to an idx that represents its position in the feature
matrix.

Parameters
6 changes: 2 additions & 4 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
@@ -515,15 +515,13 @@ def stack_chunks(neurograph, features, shift=0):


# -- util --
def count_features(model_type):
def count_features():
"""
Counts number of features based on the "model_type".

Parameters
----------
model_type : str
Indication of model to be trained. Options include: AdaBoost,
RandomForest, FeedForwardNet, MultiModalNet.
None

Returns
-------
24 changes: 23 additions & 1 deletion src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
@@ -137,8 +137,10 @@ def __init__(

# Edges
self.init_edges()
self.check_missing_edge_type()
self.init_edge_attrs(x_nodes)
self.n_edge_attrs = n_edge_features(x_nodes)


def init_edges(self):
"""
@@ -192,6 +194,23 @@ def init_edge_attrs(self, x_nodes):
x_nodes, edge_type, self.idxs_branches, self.idxs_proposals
)

def check_missing_edge_type(self):
    """
    Checks whether the graph contains any ("branch", "edge", "branch")
    edges and, if it does not, appends two dummy branch nodes that are
    connected to each other so that this edge type is never empty.

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    edge_type = ("branch", "edge", "branch")
    # NOTE(review): this test relies on an empty edge set having length 0
    # along its first dimension; a tensor of shape (2, 0) would report a
    # length of 2 and skip this branch -- confirm against "init_edges".
    if len(self.data[edge_type].edge_index) == 0:
        # Add dummy features: two all-zero rows appended to the branch
        # feature matrix, using the matrix's own dtype
        dtype = self.data["branch"].x.dtype
        zeros = torch.zeros(2, self.n_branch_features(), dtype=dtype)
        self.data["branch"].x = torch.cat(
            (self.data["branch"].x, zeros), dim=0
        )

        # Update edge_index: pair the two rows that were just appended
        # (the last two indices) with each other, in both orders
        n = self.data["branch"]["x"].size(0)
        edge_index = [[n - 1, n - 2], [n - 2, n - 1]]
        self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index)

        # Register the dummy rows under negative node ids
        self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2})
        self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3})

# -- Getters --
def n_branch_features(self):
"""
@@ -342,7 +361,10 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map):
for i in range(self.data[edge_type].edge_index.size(1)):
e1, e2 = self.data[edge_type].edge_index[:, i]
v = node_intersection(idx_map, e1, e2)
attrs.append(x_nodes[v])
if v < 0:
attrs.append(torch.zeros(self.n_branch_features() + 1))
else:
attrs.append(x_nodes[v])
arrs = torch.tensor(np.array(attrs), dtype=DTYPE)
self.data[edge_type].edge_attr = arrs

Original file line number Diff line number Diff line change
@@ -16,10 +16,10 @@
from deep_neurographs.machine_learning import feature_generation as feats
from deep_neurographs.utils import img_util

WINDOW = [5, 5, 5]

N_PROFILE_PTS = 16
NODE_PROFILE_DEPTH = 16

WINDOW = [5, 5, 5]

def generate_hgnn_features(
neurograph, img, proposals_dict, radius, downsample_factor
@@ -458,3 +458,16 @@ def check_degenerate(voxels):
[voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)]
)
return voxels


def n_node_features():
    """
    Gets the number of features in the feature vector of each node type.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        Dictionary that maps each node type ("branch" and "proposal") to
        the size of its feature vector. A new dictionary is built on every
        call.

    """
    return {'branch': 2, 'proposal': 34}


def n_edge_features():
    """
    Gets the number of features in the feature vector of each edge type.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        Dictionary that maps each edge type, given as a tuple of the form
        (node type, "edge", node type), to the size of its feature vector.
        A new dictionary is built on every call.

    """
    n_edge_features_dict = {
        ('proposal', 'edge', 'proposal'): 3,
        ('branch', 'edge', 'branch'): 3,
        ('branch', 'edge', 'proposal'): 3,
    }
    return n_edge_features_dict
20 changes: 13 additions & 7 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
@@ -8,13 +8,16 @@

"""

import numpy as np
import torch
import torch.nn.init as init
from torch import nn
from torch.nn import Dropout, LeakyReLU
from torch_geometric.nn import GATv2Conv as GATConv
from torch_geometric.nn import HEATConv, HeteroConv, Linear

from deep_neurographs.machine_learning import heterograph_feature_generation

CONV_TYPES = ["GATConv", "GCNConv"]
DROPOUT = 0.3
HEADS_1 = 1
@@ -29,9 +32,7 @@ class HeteroGNN(torch.nn.Module):

def __init__(
self,
node_dict,
edge_dict,
hidden_dim,
scale_hidden_dim=2,
dropout=DROPOUT,
heads_1=HEADS_1,
heads_2=HEADS_2,
@@ -41,6 +42,11 @@ def __init__(

"""
super().__init__()
# Feature vector sizes
node_dict = heterograph_feature_generation.n_node_features()
edge_dict = heterograph_feature_generation.n_edge_features()
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))

# Linear layers
output_dim = heads_1 * heads_2 * hidden_dim
self.input_nodes = nn.ModuleDict(
@@ -161,9 +167,8 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):
x_dict = self.activation(x_dict)

# Input - Edges
edge_attr_dict = {
key: f(edge_attr_dict[key]) for key, f in self.input_edges.items()
}
for key, f in self.input_edges.items():
edge_attr_dict[key] = f(edge_attr_dict[key])
edge_attr_dict = self.activation(edge_attr_dict)

# Convolutional layers
@@ -218,7 +223,8 @@ def __init__(
metadata=metadata,
)
"""
x in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
x in_channels (int) – Size of each input sample, or -1 to
derive the size from the first input(s) to the forward method.
x out_channels (int) – Size of each output sample.
x num_node_types (int) – The number of node types.
x num_edge_types (int) – The number of edge types.
6 changes: 3 additions & 3 deletions src/deep_neurographs/train_pipeline.py
Original file line number Diff line number Diff line change
@@ -30,11 +30,11 @@ class Trainer:
def __init__(
self,
config,
model,
model_type,
criterion=None,
output_dir=None,
validation_ids=None,
validation_split=0.15,
save_model_bool=True,
):
# Check for parameter errors
@@ -43,11 +43,11 @@ def __init__(

# Set class attributes
self.idx_to_ids = list()
self.model = model
self.model_type = model_type
self.output_dir = output_dir
self.save_model_bool = save_model_bool
self.validation_ids = validation_ids
self.validation_split = validation_split

# Set data structures for training examples
self.gt_graphs = list()
@@ -80,7 +80,7 @@ def n_validation_samples(self):

def set_validation_idxs(self):
if self.validation_ids is None:
k = int(self.validation_split * self.n_examples())
k = int(self.ml_config.validation_split * self.n_examples())
self.validation_idxs = sample(np.arange(self.n_examples), k)
else:
self.validation_idxs = list()
24 changes: 12 additions & 12 deletions src/deep_neurographs/utils/gnn_util.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,18 @@
from deep_neurographs.utils import util


def get_inputs(data, model_type):
    """
    Extracts the inputs that are passed to a model from a graph dataset.

    Parameters
    ----------
    data : object
        Graph dataset. Must expose the attributes "x_dict",
        "edge_index_dict" and "edge_attr_dict" when "model_type" contains
        "Hetero"; otherwise must expose "x" and "edge_index".
    model_type : str
        Name of the model that the inputs are extracted for.

    Returns
    -------
    tuple
        (x_dict, edge_index_dict, edge_attr_dict) if "model_type" contains
        "Hetero"; otherwise (x, edge_index).

    """
    if "Hetero" in model_type:
        return data.x_dict, data.edge_index_dict, data.edge_attr_dict
    return data.x, data.edge_index


def toCPU(tensor):
"""
Moves tensor from GPU to CPU.
@@ -35,18 +47,6 @@ def toCPU(tensor):
return tensor.detach().cpu().tolist()


def get_inputs(data, model_type):
if "Hetero" in model_type:
x = data.x_dict
edge_index = data.edge_index_dict
edge_attr_dict = data.edge_attr_dict
return x, edge_index, edge_attr_dict
else:
x = data.x
edge_index = data.edge_index
return x, edge_index


def to_tensor(my_list):
"""
Converts a list to a tensor with contiguous memory.
Loading