Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: gnn missing edge types #257

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
@author: Anna Grim
@email: anna.grim@alleninstitute.org

Routines for running inference with a machine model that classifies edge proposals.
Routines for running inference with a machine learning model that classifies
edge proposals.

"""

from datetime import datetime
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader
from tqdm import tqdm

import networkx as nx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_gnn_features(
"nodes": run_on_nodes(
neurograph, computation_graph, img, downsample_factor
),
"edges": run_on_edges(neurograph, computation_graph),
"edge": run_on_edges(neurograph, computation_graph),
"proposals": run_on_proposals(
neurograph, img, proposals, radius, downsample_factor
),
Expand Down Expand Up @@ -459,13 +459,13 @@ def check_degenerate(voxels):


def n_node_features():
    """
    Gets the number of features stored for each node type.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        Maps each node type ("branch" and "proposal") to the length of its
        feature vector.

    """
    return {"branch": 2, "proposal": 34}


def n_edge_features():
    """
    Gets the number of features stored for each edge type.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        Maps each edge type, written as a (source, "edge", target) tuple, to
        the length of its feature vector.

    """
    return {
        ("proposal", "edge", "proposal"): 3,
        ("branch", "edge", "branch"): 3,
        ("branch", "edge", "proposal"): 3,
    }
52 changes: 31 additions & 21 deletions src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def init(neurograph, features, computation_graph):
"""
# Extract features
x_branches, _, idxs_branches = feature_generation.get_matrix(
neurograph, features["edges"]
neurograph, features["edge"]
)
x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix(
neurograph, features["proposals"]
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(

# Edges
self.init_edges()
self.check_missing_edge_type()
self.check_missing_edge_types()
self.init_edge_attrs(x_nodes)
self.n_edge_attrs = n_edge_features(x_nodes)

Expand All @@ -162,9 +162,7 @@ def init_edges(self):
# Store edges
self.data["proposal", "edge", "proposal"].edge_index = proposal_edges
self.data["branch", "edge", "branch"].edge_index = branch_edges
self.data[
"branch", "edge", "proposal"
].edge_index = branch_proposal_edges
self.data["branch", "edge", "proposal"].edge_index = branch_proposal_edges

def init_edge_attrs(self, x_nodes):
"""
Expand Down Expand Up @@ -193,22 +191,34 @@ def init_edge_attrs(self, x_nodes):
x_nodes, edge_type, self.idxs_branches, self.idxs_proposals
)

def check_missing_edge_type(self):
    """
    Makes sure the ("branch", "edge", "branch") edge type is not empty.

    When the edge_index of that edge type has length zero, two all-zero rows
    are appended to the branch feature matrix and a pair of opposite edges
    between these two dummy nodes is stored as the edge_index. Both dummy
    nodes are also recorded in self.idxs_branches["idx_to_edge"].

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    key = ("branch", "edge", "branch")
    if len(self.data[key].edge_index) != 0:
        return

    # Pad the branch features with two dummy rows
    dtype = self.data["branch"].x.dtype
    padding = torch.zeros(2, self.n_branch_features(), dtype=dtype)
    self.data["branch"].x = torch.cat((self.data["branch"].x, padding), dim=0)

    # Link the two dummy nodes in both directions
    n = self.data["branch"]["x"].size(0)
    last, prev = n - 1, n - 2
    pair = [[last, prev], [prev, last]]
    self.data[key].edge_index = gnn_util.toTensor(pair)

    # Placeholder ids with negative entries
    # NOTE(review): presumably chosen so they cannot match a real edge
    self.idxs_branches["idx_to_edge"][last] = frozenset({-1, -2})
    self.idxs_branches["idx_to_edge"][prev] = frozenset({-2, -3})
def check_missing_edge_types(self):
    """
    Makes sure the same-type edge types are not empty.

    For each node type in ("branch", "proposal"), the edge_index of the edge
    type (node_type, "edge", node_type) is checked. When it has length zero,
    two all-zero rows are appended to that node type's feature matrix and a
    pair of opposite edges between these two dummy nodes is stored as the
    edge_index. Both dummy nodes are also recorded in the "idx_to_edge"
    mapping of self.idxs_branches or self.idxs_proposals, respectively.

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    for kind in ("branch", "proposal"):
        key = (kind, "edge", kind)
        if len(self.data[key].edge_index) != 0:
            continue
        is_branch = kind == "branch"

        # Pad the node features with two dummy rows
        dtype = self.data[kind].x.dtype
        if is_branch:
            n_feats = self.n_branch_features()
        else:
            n_feats = self.n_proposal_features()
        padding = torch.zeros(2, n_feats, dtype=dtype)
        self.data[kind].x = torch.cat((self.data[kind].x, padding), dim=0)

        # Link the two dummy nodes in both directions
        n = self.data[kind]["x"].size(0)
        last, prev = n - 1, n - 2
        pair = [[last, prev], [prev, last]]
        self.data[key].edge_index = gnn_util.toTensor(pair)

        # Placeholder ids with negative entries
        # NOTE(review): presumably chosen so they cannot match a real edge
        idxs = self.idxs_branches if is_branch else self.idxs_proposals
        idxs["idx_to_edge"][last] = frozenset({-1, -2})
        idxs["idx_to_edge"][prev] = frozenset({-2, -3})

# -- Getters --
def n_branch_features(self):
Expand Down
66 changes: 33 additions & 33 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class HeteroGNN(torch.nn.Module):
def __init__(
self,
device=None,
scale_hidden_dim=2,
scale_hidden=2,
dropout=DROPOUT,
heads_1=HEADS_1,
heads_2=HEADS_2,
Expand All @@ -46,75 +46,75 @@ def __init__(
# Feature vector sizes
node_dict = ml.heterograph_feature_generation.n_node_features()
edge_dict = ml.heterograph_feature_generation.n_edge_features()
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
hidden = scale_hidden * np.max(list(node_dict.values()))

# Linear layers
output_dim = heads_1 * heads_2 * hidden_dim
self.input_nodes = nn.ModuleDict(
{key: nn.Linear(d, hidden_dim, device=device) for key, d in node_dict.items()}
)
self.input_edges = {
key: nn.Linear(d, hidden_dim, device=device) for key, d in edge_dict.items()
}
output_dim = heads_1 * heads_2 * hidden
self.input_nodes = nn.ModuleDict()
self.input_edges = dict()
for key, d in node_dict.items():
self.input_nodes[key] = nn.Linear(d, hidden, device=device)
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(d, hidden, device=device)
self.output = Linear(output_dim, 1, device=device)

# Convolutional layers
self.conv1 = HeteroConv(
{
("proposal", "edge", "proposal"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
("branch", "edge", "branch"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
("branch", "edge", "proposal"): GATConv(
(hidden_dim, hidden_dim),
hidden_dim,
(hidden, hidden),
hidden,
add_self_loops=False,
edge_dim=hidden_dim,
edge_dim=hidden,
heads=heads_1,
),
},
aggr="sum",
)
edge_dim = hidden_dim
hidden_dim = heads_1 * hidden_dim
edge_dim = hidden
hidden = heads_1 * hidden

self.conv2 = HeteroConv(
{
("proposal", "edge", "proposal"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=edge_dim,
heads=heads_2,
),
("branch", "edge", "branch"): GATConv(
-1,
hidden_dim,
hidden,
dropout=dropout,
edge_dim=edge_dim,
heads=heads_2,
),
("branch", "edge", "proposal"): GATConv(
(hidden_dim, hidden_dim),
hidden_dim,
(hidden, hidden),
hidden,
add_self_loops=False,
edge_dim=edge_dim,
heads=heads_2,
),
},
aggr="sum",
)
hidden_dim = heads_2 * hidden_dim
hidden = heads_2 * hidden

# Nonlinear activation
self.dropout = Dropout(dropout)
Expand Down Expand Up @@ -193,7 +193,7 @@ class HEATGNN(torch.nn.Module):

def __init__(
self,
hidden_dim,
hidden,
metadata,
node_dict,
edge_dict,
Expand All @@ -208,17 +208,17 @@ def __init__(
super().__init__()
# Linear layers
self.input_nodes = nn.ModuleDict(
{key: nn.Linear(d, hidden_dim) for key, d in node_dict.items()}
{key: nn.Linear(d, hidden) for key, d in node_dict.items()}
)
self.input_edges = {
key: nn.Linear(d, hidden_dim) for key, d in edge_dict.items()
key: nn.Linear(d, hidden) for key, d in edge_dict.items()
}
self.output = Linear(heads_1 * heads_2 * hidden_dim)
self.output = Linear(heads_1 * heads_2 * hidden)

# Convolutional layers
self.conv1 = HEATConv(
hidden_dim,
hidden_dim,
hidden,
hidden,
heads=heads_1,
dropout=dropout,
metadata=metadata,
Expand All @@ -234,16 +234,16 @@ def __init__(
edge_attr_emb_dim (int) – The embedding size of edge features.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
"""
hidden_dim = heads_1 * hidden_dim
hidden = heads_1 * hidden

self.conv2 = HEATConv(
hidden_dim,
hidden_dim,
hidden,
hidden,
heads=heads_2,
dropout=dropout,
metadata=metadata,
)
hidden_dim = heads_2 * hidden_dim
hidden = heads_2 * hidden

# Nonlinear activation
self.dropout = Dropout(dropout)
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/gnn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_inputs(data, device=None):
if "cuda" in device and torch.cuda.is_available():
x = toGPU(x, device)
edge_index = toGPU(edge_index, device)
edge_attr = toGPU(edge_attr, device)
edge_attr = toGPU(edge_attr, device)
return x, edge_index, edge_attr


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool):
else:
xyz_i = swc_dict["xyz"][swc_dict["idx"][i]]
xyz_j = swc_dict["xyz"][swc_dict["idx"][j]]
cur_length += geometry.dist(xyz_i, xyz_j)
cur_length += geometry.dist(xyz_i, xyz_j)

# Visit j
attrs = upd_edge_attrs(swc_dict, attrs, j)
Expand Down
5 changes: 0 additions & 5 deletions src/deep_neurographs/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@

from deep_neurographs.machine_learning import (
datasets,
feature_generation,
heterograph_datasets,
)
from deep_neurographs.machine_learning.models import (
FeedForwardNet,
MultiModalNet,
)

SUPPORTED_MODELS = ["RandomForest", "GraphNeuralNet"]

Expand Down
Loading