Skip to content

Commit

Permalink
bug: gnn missing edge types (#257)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Oct 1, 2024
1 parent 774b03f commit 84a8a73
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 68 deletions.
4 changes: 2 additions & 2 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
@author: Anna Grim
@email: anna.grim@alleninstitute.org
Routines for running inference with a machine model that classifies edge proposals.
Routines for running inference with a machine model that classifies edge
proposals.
"""

from datetime import datetime
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader
from tqdm import tqdm

import networkx as nx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_gnn_features(
"nodes": run_on_nodes(
neurograph, computation_graph, img, downsample_factor
),
"edges": run_on_edges(neurograph, computation_graph),
"edge": run_on_edges(neurograph, computation_graph),
"proposals": run_on_proposals(
neurograph, img, proposals, radius, downsample_factor
),
Expand Down Expand Up @@ -459,13 +459,13 @@ def check_degenerate(voxels):


def n_node_features():
return {'branch': 2, 'proposal': 34}
return {"branch": 2, "proposal": 34}


def n_edge_features():
n_edge_features_dict = {
('proposal', 'edge', 'proposal'): 3,
('branch', 'edge', 'branch'): 3,
('branch', 'edge', 'proposal'): 3
("proposal", "edge", "proposal"): 3,
("branch", "edge", "branch"): 3,
("branch", "edge", "proposal"): 3
}
return n_edge_features_dict
52 changes: 31 additions & 21 deletions src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def init(neurograph, features, computation_graph):
"""
# Extract features
x_branches, _, idxs_branches = feature_generation.get_matrix(
neurograph, features["edges"]
neurograph, features["edge"]
)
x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix(
neurograph, features["proposals"]
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(

# Edges
self.init_edges()
self.check_missing_edge_type()
self.check_missing_edge_types()
self.init_edge_attrs(x_nodes)
self.n_edge_attrs = n_edge_features(x_nodes)

Expand All @@ -162,9 +162,7 @@ def init_edges(self):
# Store edges
self.data["proposal", "edge", "proposal"].edge_index = proposal_edges
self.data["branch", "edge", "branch"].edge_index = branch_edges
self.data[
"branch", "edge", "proposal"
].edge_index = branch_proposal_edges
self.data["branch", "edge", "proposal"].edge_index = branch_proposal_edges

def init_edge_attrs(self, x_nodes):
"""
Expand Down Expand Up @@ -193,22 +191,34 @@ def init_edge_attrs(self, x_nodes):
x_nodes, edge_type, self.idxs_branches, self.idxs_proposals
)

def check_missing_edge_type(self):
edge_type = ("branch", "edge", "branch")
if len(self.data[edge_type].edge_index) == 0:
# Add dummy features
dtype = self.data["branch"].x.dtype
zeros = torch.zeros(2, self.n_branch_features(), dtype=dtype)
self.data["branch"].x = torch.cat(
(self.data["branch"].x, zeros), dim=0
)

# Update edge_index
n = self.data["branch"]["x"].size(0)
edge_index = [[n - 1, n - 2], [n - 2, n - 1]]
self.data[edge_type].edge_index = gnn_util.toTensor(edge_index)
self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2})
self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3})
def check_missing_edge_types(self):
for node_type in ["branch", "proposal"]:
edge_type = (node_type, "edge", node_type)
if len(self.data[edge_type].edge_index) == 0:
# Add dummy features
dtype = self.data[node_type].x.dtype
if node_type == "branch":
d = self.n_branch_features()
else:
d = self.n_proposal_features()

zeros = torch.zeros(2, d, dtype=dtype)
self.data[node_type].x = torch.cat(
(self.data[node_type].x, zeros), dim=0
)

# Update edge_index
n = self.data[node_type]["x"].size(0)
e_1 = frozenset({-1, -2})
e_2 = frozenset({-2, -3})
edges = [[n - 1, n - 2], [n - 2, n - 1]]
self.data[edge_type].edge_index = gnn_util.toTensor(edges)
if node_type == "branch":
self.idxs_branches["idx_to_edge"][n - 1] = e_1
self.idxs_branches["idx_to_edge"][n - 2] = e_2
else:
self.idxs_proposals["idx_to_edge"][n - 1] = e_1
self.idxs_proposals["idx_to_edge"][n - 2] = e_2

# -- Getters --
def n_branch_features(self):
Expand Down
66 changes: 33 additions & 33 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class HeteroGNN(torch.nn.Module):
def __init__(
self,
device=None,
scale_hidden_dim=2,
scale_hidden=2,
dropout=DROPOUT,
heads_1=HEADS_1,
heads_2=HEADS_2,
Expand All @@ -46,75 +46,75 @@ def __init__(
# Feature vector sizes
node_dict = ml.heterograph_feature_generation.n_node_features()
edge_dict = ml.heterograph_feature_generation.n_edge_features()
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
hidden = scale_hidden * np.max(list(node_dict.values()))

# Linear layers
output_dim = heads_1 * heads_2 * hidden_dim
self.input_nodes = nn.ModuleDict(
{key: nn.Linear(d, hidden_dim, device=device) for key, d in node_dict.items()}
)
self.input_edges = {
key: nn.Linear(d, hidden_dim, device=device) for key, d in edge_dict.items()
}
output_dim = heads_1 * heads_2 * hidden
self.input_nodes = nn.ModuleDict()
self.input_edges = dict()
for key, d in node_dict.items():
self.input_nodes[key] = nn.Linear(d, hidden, device=device)
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(d, hidden, device=device)
self.output = Linear(output_dim, 1, device=device)

# Convolutional layers
self.conv1 = HeteroConv(
{
("proposal", "edge", "proposal"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
("branch", "edge", "branch"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
("branch", "edge", "proposal"): GATConv(
(hidden_dim, hidden_dim),
hidden_dim,
(hidden, hidden),
hidden,
add_self_loops=False,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
},
aggr="sum",
)
edge_dim = hidden_dim
hidden_dim = heads_1 * hidden_dim
edge_dim = hidden
hidden = heads_1 * hidden

self.conv2 = HeteroConv(
{
("proposal", "edge", "proposal"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=edge_dim,
heads=heads_2,
),
("branch", "edge", "branch"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=edge_dim,
heads=heads_2,
),
("branch", "edge", "proposal"): GATConv(
(hidden_dim, hidden_dim),
hidden_dim,
(hidden, hidden),
hidden,
add_self_loops=False,
edge_dim=edge_dim,
heads=heads_2,
),
},
aggr="sum",
)
hidden_dim = heads_2 * hidden_dim
hidden = heads_2 * hidden

# Nonlinear activation
self.dropout = Dropout(dropout)
Expand Down Expand Up @@ -193,7 +193,7 @@ class HEATGNN(torch.nn.Module):

def __init__(
self,
hidden_dim,
hidden,
metadata,
node_dict,
edge_dict,
Expand All @@ -208,17 +208,17 @@ def __init__(
super().__init__()
# Linear layers
self.input_nodes = nn.ModuleDict(
{key: nn.Linear(d, hidden_dim) for key, d in node_dict.items()}
{key: nn.Linear(d, hidden) for key, d in node_dict.items()}
)
self.input_edges = {
key: nn.Linear(d, hidden_dim) for key, d in edge_dict.items()
key: nn.Linear(d, hidden) for key, d in edge_dict.items()
}
self.output = Linear(heads_1 * heads_2 * hidden_dim)
self.output = Linear(heads_1 * heads_2 * hidden)

# Convolutional layers
self.conv1 = HEATConv(
hidden_dim,
hidden_dim,
hidden,
hidden,
heads=heads_1,
dropout=dropout,
metadata=metadata,
Expand All @@ -234,16 +234,16 @@ def __init__(
edge_attr_emb_dim (int) – The embedding size of edge features.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
"""
hidden_dim = heads_1 * hidden_dim
hidden = heads_1 * hidden

self.conv2 = HEATConv(
hidden_dim,
hidden_dim,
hidden,
hidden,
heads=heads_2,
dropout=dropout,
metadata=metadata,
)
hidden_dim = heads_2 * hidden_dim
hidden = heads_2 * hidden

# Nonlinear activation
self.dropout = Dropout(dropout)
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/gnn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_inputs(data, device=None):
if "cuda" in device and torch.cuda.is_available():
x = toGPU(x, device)
edge_index = toGPU(edge_index, device)
edge_attr = toGPU(edge_attr, device)
edge_attr = toGPU(edge_attr, device)
return x, edge_index, edge_attr


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool):
else:
xyz_i = swc_dict["xyz"][swc_dict["idx"][i]]
xyz_j = swc_dict["xyz"][swc_dict["idx"][j]]
cur_length += geometry.dist(xyz_i, xyz_j)
cur_length += geometry.dist(xyz_i, xyz_j)

# Visit j
attrs = upd_edge_attrs(swc_dict, attrs, j)
Expand Down
5 changes: 0 additions & 5 deletions src/deep_neurographs/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@

from deep_neurographs.machine_learning import (
datasets,
feature_generation,
heterograph_datasets,
)
from deep_neurographs.machine_learning.models import (
FeedForwardNet,
MultiModalNet,
)

SUPPORTED_MODELS = ["RandomForest", "GraphNeuralNet"]

Expand Down

0 comments on commit 84a8a73

Please sign in to comment.