Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor gnn training #247

Merged
merged 11 commits into from
Sep 26, 2024
4 changes: 4 additions & 0 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-4
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
validation_split: float = 0.15
weight_decay: float = 1e-3


class Config:
Expand Down
4 changes: 2 additions & 2 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def reformat(arr):

def init_idxs(idxs):
"""
Adds dictionary item called "edge_to_index" which maps an edge in a
neurograph to an that represents the edge's position in the feature
Adds dictionary item called "edge_to_index" which maps a branch/proposal
in a neurograph to an idx that represents it's position in the feature
matrix.

Parameters
Expand Down
6 changes: 2 additions & 4 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,15 +515,13 @@ def stack_chunks(neurograph, features, shift=0):


# -- util --
def count_features(model_type):
def count_features():
"""
Counts number of features based on the "model_type".

Parameters
----------
model_type : str
Indication of model to be trained. Options include: AdaBoost,
RandomForest, FeedForwardNet, MultiModalNet.
None

Returns
-------
Expand Down
24 changes: 23 additions & 1 deletion src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ def __init__(

# Edges
self.init_edges()
self.check_missing_edge_type()
self.init_edge_attrs(x_nodes)
self.n_edge_attrs = n_edge_features(x_nodes)


def init_edges(self):
"""
Expand Down Expand Up @@ -192,6 +194,23 @@ def init_edge_attrs(self, x_nodes):
x_nodes, edge_type, self.idxs_branches, self.idxs_proposals
)

def check_missing_edge_type(self):
edge_type = ("branch", "edge", "branch")
if len(self.data[edge_type].edge_index) == 0:
# Add dummy features
dtype = self.data["branch"].x.dtype
zeros = torch.zeros(2, self.n_branch_features(), dtype=dtype)
self.data["branch"].x = torch.cat(
(self.data["branch"].x, zeros), dim=0
)

# Update edge_index
n = self.data["branch"]["x"].size(0)
edge_index = [[n - 1, n - 2], [n - 2, n - 1]]
self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index)
self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2})
self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3})

# -- Getters --
def n_branch_features(self):
"""
Expand Down Expand Up @@ -342,7 +361,10 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map):
for i in range(self.data[edge_type].edge_index.size(1)):
e1, e2 = self.data[edge_type].edge_index[:, i]
v = node_intersection(idx_map, e1, e2)
attrs.append(x_nodes[v])
if v < 0:
attrs.append(torch.zeros(self.n_branch_features() + 1))
else:
attrs.append(x_nodes[v])
arrs = torch.tensor(np.array(attrs), dtype=DTYPE)
self.data[edge_type].edge_attr = arrs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from deep_neurographs.machine_learning import feature_generation as feats
from deep_neurographs.utils import img_util

WINDOW = [5, 5, 5]

N_PROFILE_PTS = 16
NODE_PROFILE_DEPTH = 16

WINDOW = [5, 5, 5]

def generate_hgnn_features(
neurograph, img, proposals_dict, radius, downsample_factor
Expand Down Expand Up @@ -458,3 +458,16 @@ def check_degenerate(voxels):
[voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)]
)
return voxels


def n_node_features():
return {'branch': 2, 'proposal': 34}


def n_edge_features():
n_edge_features_dict = {
('proposal', 'edge', 'proposal'): 3,
('branch', 'edge', 'branch'): 3,
('branch', 'edge', 'proposal'): 3
}
return n_edge_features_dict
20 changes: 13 additions & 7 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@

"""

import numpy as np
import torch
import torch.nn.init as init
from torch import nn
from torch.nn import Dropout, LeakyReLU
from torch_geometric.nn import GATv2Conv as GATConv
from torch_geometric.nn import HEATConv, HeteroConv, Linear

from deep_neurographs.machine_learning import heterograph_feature_generation

CONV_TYPES = ["GATConv", "GCNConv"]
DROPOUT = 0.3
HEADS_1 = 1
Expand All @@ -29,9 +32,7 @@ class HeteroGNN(torch.nn.Module):

def __init__(
self,
node_dict,
edge_dict,
hidden_dim,
scale_hidden_dim=2,
dropout=DROPOUT,
heads_1=HEADS_1,
heads_2=HEADS_2,
Expand All @@ -41,6 +42,11 @@ def __init__(

"""
super().__init__()
# Feature vector sizes
node_dict = heterograph_feature_generation.n_node_features()
edge_dict = heterograph_feature_generation.n_edge_features()
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))

# Linear layers
output_dim = heads_1 * heads_2 * hidden_dim
self.input_nodes = nn.ModuleDict(
Expand Down Expand Up @@ -161,9 +167,8 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):
x_dict = self.activation(x_dict)

# Input - Edges
edge_attr_dict = {
key: f(edge_attr_dict[key]) for key, f in self.input_edges.items()
}
for key, f in self.input_edges.items():
edge_attr_dict[key] = f(edge_attr_dict[key])
edge_attr_dict = self.activation(edge_attr_dict)

# Convolutional layers
Expand Down Expand Up @@ -218,7 +223,8 @@ def __init__(
metadata=metadata,
)
"""
x in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
x in_channels (int) – Size of each input sample, or -1 to
derive the size from the first input(s) to the forward method.
x out_channels (int) – Size of each output sample.
x num_node_types (int) – The number of node types.
x num_edge_types (int) – The number of edge types.
Expand Down
6 changes: 3 additions & 3 deletions src/deep_neurographs/train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class Trainer:
def __init__(
self,
config,
model,
model_type,
criterion=None,
output_dir=None,
validation_ids=None,
validation_split=0.15,
save_model_bool=True,
):
# Check for parameter errors
Expand All @@ -43,11 +43,11 @@ def __init__(

# Set class attributes
self.idx_to_ids = list()
self.model = model
self.model_type = model_type
self.output_dir = output_dir
self.save_model_bool = save_model_bool
self.validation_ids = validation_ids
self.validation_split = validation_split

# Set data structures for training examples
self.gt_graphs = list()
Expand Down Expand Up @@ -80,7 +80,7 @@ def n_validation_samples(self):

def set_validation_idxs(self):
if self.validation_ids is None:
k = int(self.validation_split * self.n_examples())
k = int(self.ml_config.validation_split * self.n_examples())
self.validation_idxs = sample(np.arange(self.n_examples), k)
else:
self.validation_idxs = list()
Expand Down
24 changes: 12 additions & 12 deletions src/deep_neurographs/utils/gnn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@
from deep_neurographs.utils import util


def get_inputs(data, model_type):
if "Hetero" in model_type:
x = data.x_dict
edge_index = data.edge_index_dict
edge_attr_dict = data.edge_attr_dict
return x, edge_index, edge_attr_dict
else:
x = data.x
edge_index = data.edge_index
return x, edge_index


def toCPU(tensor):
"""
Moves tensor from GPU to CPU.
Expand All @@ -35,18 +47,6 @@ def toCPU(tensor):
return tensor.detach().cpu().tolist()


def get_inputs(data, model_type):
if "Hetero" in model_type:
x = data.x_dict
edge_index = data.edge_index_dict
edge_attr_dict = data.edge_attr_dict
return x, edge_index, edge_attr_dict
else:
x = data.x
edge_index = data.edge_index
return x, edge_index


def to_tensor(my_list):
"""
Converts a list to a tensor with contiguous memory.
Expand Down
Loading