Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: gnn dropout #123

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 165 additions & 4 deletions src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,32 @@

# Directional Vectors
def get_directional(neurograph, i, origin, window_size):
"""
Computes the directional vector of a branch or bifurcation in a neurograph
relative to a specified origin.

Parameters
----------
neurograph : Neurograph
The neurograph object containing the branches.
i : int
The index of the branch or bifurcation in the neurograph.
origin : numpy.ndarray
The origin point xyz relative to which the directional vector is
computed.
window_size : numpy.ndarry
The size of the window around the branch or bifurcation to consider
for computing the directional vector.

Returns
-------
numpy.ndarray
The directional vector of the branch or bifurcation relative to the
specified origin.

"""
branches = neurograph.get_branches(i)
branches = translate_branches(branches, origin)
branches = shift_branches(branches, origin)
if len(branches) == 1:
return compute_tangent(get_subarray(branches[0], window_size))
elif len(branches) == 2:
Expand Down Expand Up @@ -200,6 +224,31 @@ def fit_spline(xyz, s=None):

# Image feature extraction
def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]):
"""
Computes the maximum intensity profile along a list of 3D coordinates
in a given image.

Parameters
----------
img : numpy.ndarray
The image volume or TensorStore object from which to extract intensity
profiles.
xyz_arr : numpy.ndarray
Array of 3D coordinates xyz representing points in the image volume.
process_id : int or None, optional
An optional identifier for the process. Default is None.
window : numpy.ndarray, optional
The size of the window around each coordinate for profile extraction.
Default is [5, 5, 5].

Returns
-------
list, tuple
If "process_id" is provided, returns a tuple containing the process_id
and the intensity profile list. If "process_id" is not provided,
returns only the intensity profile list.

"""
profile = []
for xyz in xyz_arr:
if type(img) == ts.TensorStore:
Expand All @@ -214,6 +263,25 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]):


def fill_path(img, path, val=-1):
"""
Fills a given path in a 3D image array with a specified value.

Parameters
----------
img : numpy.ndarray
The 3D image array to fill the path in.
path : iterable
A list or iterable containing 3D coordinates (x, y, z) representing
the path.
val : int, optional
The value to fill the path with. Default is -1.

Returns
-------
numpy.ndarray
The modified image array with the path filled with the specified value.

"""
for xyz in path:
x, y, z = tuple(np.floor(xyz).astype(int))
img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val
Expand Down Expand Up @@ -415,17 +483,73 @@ def check_dists(xyz_1, xyz_2, xyz_3, radius):


def make_line(xyz_1, xyz_2, n_steps):
"""
Generates a series of points representing a straight line between two 3D
coordinates.

Parameters
----------
xyz_1 : tuple or array-like
The starting 3D coordinate (x, y, z) of the line.
xyz_2 : tuple or array-like
The ending 3D coordinate (x, y, z) of the line.
n_steps : int
The number of steps to interpolate between the two coordinates.

Returns
-------
numpy.ndarray
An array of shape (n_steps, 3) containing the interpolated 3D
coordinates representing the straight line between xyz_1 and xyz_2.

"""
xyz_1 = np.array(xyz_1)
xyz_2 = np.array(xyz_2)
t_steps = np.linspace(0, 1, n_steps)
return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps], dtype=int)


def normalize(vec, norm="l2"):
return vec / abs(dist(np.zeros((3)), vec, metric=norm))
def normalize(vector, norm="l2"):
"""
Normalizes a vector to have unit length with respect to a specified norm.

Parameters
----------
vector : numpy.ndarray
The input vector to be normalized.
norm : str, optional
The norm to use for normalization. Default is "l2".

Returns
-------
numpy.ndarray
The normalized vector with unit length with respect to the specified
norm.

"""
return vector / abs(dist(np.zeros((3)), vector, metric=norm))


def nearest_neighbor(xyz_arr, xyz):
"""
Finds the nearest neighbor in a list of 3D coordinates to a given target
coordinate.

Parameters
----------
xyz_arr : numpy.ndarray
Array of 3D coordinates to search for the nearest neighbor.
xyz : numpy.ndarray
The target 3D coordinate xyz to find the nearest neighbor to.

Returns
-------
tuple[int, float]
A tuple containing the index of the nearest neighbor in "xyz_arr" and
the distance between the target coordinate `xyz` and its nearest
neighbor.

"""
min_dist = np.inf
idx = None
for i, xyz_i in enumerate(xyz_arr):
Expand All @@ -436,12 +560,49 @@ def nearest_neighbor(xyz_arr, xyz):
return idx, min_dist


def translate_branches(branches, shift):
def shift_branches(branches, shift):
"""
Shifts the coordinates of branches in a list of arrays by a specified
shift vector.

Parameters
----------
branches : list
A list containing arrays of 3D coordinates representing branches.
shift : numpy.ndarray
The shift vector (dx, dy, dz) by which to shift the coordinates.

Returns
-------
list
A list containing arrays of shifted 3D coordinates representing the
branches.

"""
for i, branch in enumerate(branches):
branches[i] = branch - shift
return branches


def query_ball(kdtree, xyz, radius):
"""
Queries a KD-tree for points within a given radius from a target point.

Parameters
----------
kdtree : scipy.spatial.cKDTree
The KD-tree data structure containing the points to query.
xyz : numpy.ndarray
The target 3D coordinate (x, y, z) around which to search for points.
radius : float
The radius within which to search for points.

Returns
-------
numpy.ndarray
An array containing the points within the specified radius from the
target coordinate.

"""
idxs = kdtree.query_ball_point(xyz, radius, return_sorted=True)
return kdtree.data[idxs]
10 changes: 9 additions & 1 deletion src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,10 @@ def generate_skel_features(neurograph, proposals):
neurograph.proposal_length(proposal),
neurograph.degree[i],
neurograph.degree[j],
n_nearby_leafs(neurograph, proposal),
get_radii(neurograph, proposal),
get_avg_radii(neurograph, proposal),
get_avg_branch_lens(neurograph, proposal),
get_directionals(neurograph, proposal, 8),
get_directionals(neurograph, proposal, 16),
get_directionals(neurograph, proposal, 32),
Expand Down Expand Up @@ -363,12 +365,18 @@ def avg_branch_radii(neurograph, edge):
return np.array([np.mean(neurograph.edges[edge]["radius"])])


def n_nearby_leafs(neurograph, proposal):
xyz = neurograph.proposal_midpoint(proposal)
leafs = neurograph.query_kdtree(xyz, 25, node_type="leaf")
return len(leafs)


# --- Edge Feature Generation --
def generate_branch_features(neurograph, edges):
features = dict()
for (i, j) in edges:
edge = frozenset((i, j))
features[edge] = np.zeros((31))
features[edge] = np.zeros((34))

temp = np.concatenate(
(
Expand Down
58 changes: 43 additions & 15 deletions src/deep_neurographs/machine_learning/graph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,83 @@

import torch
import torch.nn.functional as F
from torch.nn import ELU, Linear
from torch_geometric.nn import GATConv, GCNConv
from torch.nn import ELU, Dropout, Linear
import torch.nn.init as init
from torch_geometric.nn import GATv2Conv as GATConv
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
def __init__(self, input_channels):
super().__init__()
self.conv1 = GCNConv(input_channels, input_channels)
self.conv2 = GCNConv(input_channels, input_channels // 2)
self.conv3 = GCNConv(input_channels // 2, 1)
self.input = Linear(input_channels, input_channels)
self.conv1 = GCNConv(input_channels, 2 * input_channels)
self.conv2 = GCNConv(2 * input_channels, input_channels)
self.conv3 = GCNConv(input_channels, input_channels // 2)
self.dropout = Dropout(0.3)
self.ELU = ELU()
self.output = Linear(input_channels // 2, 1)

# Initialize weights
self.init_weights()

def init_weights(self):
layers = [self.conv1, self.conv2, self.conv3]
#, self.input, self.output]
for layer in layers:
for param in layer.parameters():
if len(param.shape) > 1:
# Initialize weights using Glorot uniform initialization
init.xavier_uniform_(param)
else:
# Initialize biases to zeros
init.zeros_(param)

def forward(self, x, edge_index):
# Input
x = self.input(x)

# Layer 1
x = self.conv1(x, edge_index)
x = self.ELU(x)
x = F.dropout(x, p=0.25)
x = self.dropout(x)

# Layer 2
x = self.conv2(x, edge_index)
x = self.ELU(x)
x = F.dropout(x, p=0.25)
x = self.dropout(x)

# Layer 3
x = self.conv3(x, edge_index)

# Output
x = self.output(x)

return x


class GAT(torch.nn.Module):
def __init__(self, input_channels):
super().__init__()
self.conv1 = GATConv(input_channels, input_channels)
self.conv2 = GATConv(input_channels, input_channels // 2)
self.conv1 = GATConv(input_channels, 2 * input_channels)
self.conv2 = GATConv(2 * input_channels, input_channels // 2)
self.conv3 = GATConv(input_channels // 2, 1)
self.dropout = Dropout(0.3)
self.ELU = ELU()

def forward(self, x, edge_index):
# Layer 1
x = self.conv1(x, edge_index)
# x = self.ELU(x)
# x = F.dropout(x, p=0.25)
x = self.ELU(x)
x = self.dropout(x)

# Layer 2
# x = self.conv2(x, edge_index)
# x = self.ELU(x)
# x = F.dropout(x, p=0.25)
x = self.conv2(x, edge_index)
x = self.ELU(x)
x = self.dropout(x)

# Layer 3
# x = self.conv3(x, edge_index)
x = self.conv3(x, edge_index)
return x


Expand Down
Loading
Loading