diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 69cb788..9065917 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -17,8 +17,32 @@ # Directional Vectors def get_directional(neurograph, i, origin, window_size): + """ + Computes the directional vector of a branch or bifurcation in a neurograph + relative to a specified origin. + + Parameters + ---------- + neurograph : Neurograph + The neurograph object containing the branches. + i : int + The index of the branch or bifurcation in the neurograph. + origin : numpy.ndarray + The origin point xyz relative to which the directional vector is + computed. + window_size : numpy.ndarray + The size of the window around the branch or bifurcation to consider + for computing the directional vector. + + Returns + ------- + numpy.ndarray + The directional vector of the branch or bifurcation relative to the + specified origin. + + """ branches = neurograph.get_branches(i) - branches = translate_branches(branches, origin) + branches = shift_branches(branches, origin) if len(branches) == 1: return compute_tangent(get_subarray(branches[0], window_size)) elif len(branches) == 2: @@ -200,6 +224,31 @@ def fit_spline(xyz, s=None): # Image feature extraction def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): + """ + Computes the maximum intensity profile along a list of 3D coordinates + in a given image. + + Parameters + ---------- + img : numpy.ndarray + The image volume or TensorStore object from which to extract intensity + profiles. + xyz_arr : numpy.ndarray + Array of 3D coordinates xyz representing points in the image volume. + process_id : int or None, optional + An optional identifier for the process. Default is None. + window : numpy.ndarray, optional + The size of the window around each coordinate for profile extraction. + Default is [5, 5, 5]. 
+ + Returns + ------- + list, tuple + If "process_id" is provided, returns a tuple containing the process_id + and the intensity profile list. If "process_id" is not provided, + returns only the intensity profile list. + + """ profile = [] for xyz in xyz_arr: if type(img) == ts.TensorStore: @@ -214,6 +263,25 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): def fill_path(img, path, val=-1): + """ + Fills a given path in a 3D image array with a specified value. + + Parameters + ---------- + img : numpy.ndarray + The 3D image array to fill the path in. + path : iterable + A list or iterable containing 3D coordinates (x, y, z) representing + the path. + val : int, optional + The value to fill the path with. Default is -1. + + Returns + ------- + numpy.ndarray + The modified image array with the path filled with the specified value. + + """ for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val @@ -415,17 +483,73 @@ def check_dists(xyz_1, xyz_2, xyz_3, radius): def make_line(xyz_1, xyz_2, n_steps): + """ + Generates a series of points representing a straight line between two 3D + coordinates. + + Parameters + ---------- + xyz_1 : tuple or array-like + The starting 3D coordinate (x, y, z) of the line. + xyz_2 : tuple or array-like + The ending 3D coordinate (x, y, z) of the line. + n_steps : int + The number of steps to interpolate between the two coordinates. + + Returns + ------- + numpy.ndarray + An array of shape (n_steps, 3) containing the interpolated 3D + coordinates representing the straight line between xyz_1 and xyz_2. 
+ + """ xyz_1 = np.array(xyz_1) xyz_2 = np.array(xyz_2) t_steps = np.linspace(0, 1, n_steps) return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps], dtype=int) -def normalize(vec, norm="l2"): - return vec / abs(dist(np.zeros((3)), vec, metric=norm)) +def normalize(vector, norm="l2"): + """ + Normalizes a vector to have unit length with respect to a specified norm. + + Parameters + ---------- + vector : numpy.ndarray + The input vector to be normalized. + norm : str, optional + The norm to use for normalization. Default is "l2". + + Returns + ------- + numpy.ndarray + The normalized vector with unit length with respect to the specified + norm. + + """ + return vector / abs(dist(np.zeros((3)), vector, metric=norm)) def nearest_neighbor(xyz_arr, xyz): + """ + Finds the nearest neighbor in a list of 3D coordinates to a given target + coordinate. + + Parameters + ---------- + xyz_arr : numpy.ndarray + Array of 3D coordinates to search for the nearest neighbor. + xyz : numpy.ndarray + The target 3D coordinate xyz to find the nearest neighbor to. + + Returns + ------- + tuple[int, float] + A tuple containing the index of the nearest neighbor in "xyz_arr" and + the distance between the target coordinate `xyz` and its nearest + neighbor. + + """ min_dist = np.inf idx = None for i, xyz_i in enumerate(xyz_arr): @@ -436,12 +560,49 @@ def nearest_neighbor(xyz_arr, xyz): return idx, min_dist -def translate_branches(branches, shift): +def shift_branches(branches, shift): + """ + Shifts the coordinates of branches in a list of arrays by a specified + shift vector. + + Parameters + ---------- + branches : list + A list containing arrays of 3D coordinates representing branches. + shift : numpy.ndarray + The shift vector (dx, dy, dz) by which to shift the coordinates. + + Returns + ------- + list + A list containing arrays of shifted 3D coordinates representing the + branches. 
+ + """ for i, branch in enumerate(branches): branches[i] = branch - shift return branches def query_ball(kdtree, xyz, radius): + """ + Queries a KD-tree for points within a given radius from a target point. + + Parameters + ---------- + kdtree : scipy.spatial.cKDTree + The KD-tree data structure containing the points to query. + xyz : numpy.ndarray + The target 3D coordinate (x, y, z) around which to search for points. + radius : float + The radius within which to search for points. + + Returns + ------- + numpy.ndarray + An array containing the points within the specified radius from the + target coordinate. + + """ idxs = kdtree.query_ball_point(xyz, radius, return_sorted=True) return kdtree.data[idxs] diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 855da25..f158299 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -294,8 +294,10 @@ def generate_skel_features(neurograph, proposals): neurograph.proposal_length(proposal), neurograph.degree[i], neurograph.degree[j], + n_nearby_leafs(neurograph, proposal), get_radii(neurograph, proposal), get_avg_radii(neurograph, proposal), + get_avg_branch_lens(neurograph, proposal), get_directionals(neurograph, proposal, 8), get_directionals(neurograph, proposal, 16), get_directionals(neurograph, proposal, 32), @@ -363,12 +365,18 @@ def avg_branch_radii(neurograph, edge): return np.array([np.mean(neurograph.edges[edge]["radius"])]) +def n_nearby_leafs(neurograph, proposal): + xyz = neurograph.proposal_midpoint(proposal) + leafs = neurograph.query_kdtree(xyz, 25, node_type="leaf") + return len(leafs) + + # --- Edge Feature Generation -- def generate_branch_features(neurograph, edges): features = dict() for (i, j) in edges: edge = frozenset((i, j)) - features[edge] = np.zeros((31)) + features[edge] = np.zeros((34)) temp = np.concatenate( ( diff --git 
a/src/deep_neurographs/machine_learning/graph_models.py b/src/deep_neurographs/machine_learning/graph_models.py index ca7eefe..49b74d4 100644 --- a/src/deep_neurographs/machine_learning/graph_models.py +++ b/src/deep_neurographs/machine_learning/graph_models.py @@ -10,55 +10,83 @@ import torch import torch.nn.functional as F -from torch.nn import ELU, Linear -from torch_geometric.nn import GATConv, GCNConv +from torch.nn import ELU, Dropout, Linear +import torch.nn.init as init +from torch_geometric.nn import GATv2Conv as GATConv +from torch_geometric.nn import GCNConv class GCN(torch.nn.Module): def __init__(self, input_channels): super().__init__() - self.conv1 = GCNConv(input_channels, input_channels) - self.conv2 = GCNConv(input_channels, input_channels // 2) - self.conv3 = GCNConv(input_channels // 2, 1) + self.input = Linear(input_channels, input_channels) + self.conv1 = GCNConv(input_channels, 2 * input_channels) + self.conv2 = GCNConv(2 * input_channels, input_channels) + self.conv3 = GCNConv(input_channels, input_channels // 2) + self.dropout = Dropout(0.3) self.ELU = ELU() + self.output = Linear(input_channels // 2, 1) + + # Initialize weights + self.init_weights() + + def init_weights(self): + layers = [self.conv1, self.conv2, self.conv3] + #, self.input, self.output] + for layer in layers: + for param in layer.parameters(): + if len(param.shape) > 1: + # Initialize weights using Glorot uniform initialization + init.xavier_uniform_(param) + else: + # Initialize biases to zeros + init.zeros_(param) def forward(self, x, edge_index): + # Input + x = self.input(x) + # Layer 1 x = self.conv1(x, edge_index) x = self.ELU(x) - x = F.dropout(x, p=0.25) + x = self.dropout(x) # Layer 2 x = self.conv2(x, edge_index) x = self.ELU(x) - x = F.dropout(x, p=0.25) + x = self.dropout(x) # Layer 3 x = self.conv3(x, edge_index) + + # Output + x = self.output(x) + return x class GAT(torch.nn.Module): def __init__(self, input_channels): super().__init__() - self.conv1 = 
GATConv(input_channels, input_channels) - self.conv2 = GATConv(input_channels, input_channels // 2) + self.conv1 = GATConv(input_channels, 2 * input_channels) + self.conv2 = GATConv(2 * input_channels, input_channels // 2) self.conv3 = GATConv(input_channels // 2, 1) + self.dropout = Dropout(0.3) self.ELU = ELU() def forward(self, x, edge_index): # Layer 1 x = self.conv1(x, edge_index) - # x = self.ELU(x) - # x = F.dropout(x, p=0.25) + x = self.ELU(x) + x = self.dropout(x) # Layer 2 - # x = self.conv2(x, edge_index) - # x = self.ELU(x) - # x = F.dropout(x, p=0.25) + x = self.conv2(x, edge_index) + x = self.ELU(x) + x = self.dropout(x) # Layer 3 - # x = self.conv3(x, edge_index) + x = self.conv3(x, edge_index) return x diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py index 9e556b5..dab4df0 100644 --- a/src/deep_neurographs/machine_learning/graph_trainer.py +++ b/src/deep_neurographs/machine_learning/graph_trainer.py @@ -19,6 +19,7 @@ precision_score, recall_score, ) +from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter from deep_neurographs.machine_learning import ml_utils @@ -26,7 +27,7 @@ LR = 1e-3 N_EPOCHS = 1000 TEST_PERCENT = 0.15 -WEIGHT_DECAY = 5e-4 +WEIGHT_DECAY = 5e-3 class GraphTrainer: @@ -93,6 +94,7 @@ def run_on_graphs(self, graph_datasets): # Initializations best_score = -np.inf best_ckpt = None + scheduler = StepLR(self.optimizer, step_size=500, gamma=0.5) # Main train_ids, test_ids = train_test_split(list(graph_datasets.keys())) @@ -105,6 +107,7 @@ def run_on_graphs(self, graph_datasets): y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) self.compute_metrics(y, hat_y, "train", epoch) + scheduler.step() # Test if epoch % 10 == 0: @@ -120,7 +123,8 @@ def run_on_graphs(self, graph_datasets): if test_score > best_score: best_score = test_score best_ckpt = deepcopy(self.model.state_dict()) - return self.model.load_state_dict(best_ckpt) + 
self.model.load_state_dict(best_ckpt) + return self.model def run_on_graph(self): """ @@ -288,7 +292,7 @@ def train_test_split(graph_ids): """ n_test_examples = int(len(graph_ids) * TEST_PERCENT) - test_ids = sample(graph_ids, n_test_examples) + test_ids = ["block_007", "block_010"] # sample(graph_ids, n_test_examples) train_ids = list(set(graph_ids) - set(test_ids)) return train_ids, test_ids diff --git a/src/deep_neurographs/machine_learning/trainer.py b/src/deep_neurographs/machine_learning/trainer.py index 8e02e27..d1c179a 100644 --- a/src/deep_neurographs/machine_learning/trainer.py +++ b/src/deep_neurographs/machine_learning/trainer.py @@ -87,9 +87,9 @@ def fit_deep_model( # Fit model pylightning_trainer = pl.Trainer( - # accelerator="gpu", + accelerator="gpu", callbacks=[ckpt_callback], - # devices=1, + devices=1, enable_model_summary=False, enable_progress_bar=False, logger=logger, diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index 2e4eaa5..df31fd0 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -9,8 +9,6 @@ """ import os -from concurrent.futures import ProcessPoolExecutor, as_completed - import networkx as nx import numpy as np @@ -40,13 +38,12 @@ def get_accepted_propoals_blocks( # Get accepts if structure_aware: - graph = neurographs[block_id].copy() + graph = neurographs[block_id].copy_graph() accepts[block_id], _ = get_structure_aware_accepts( neurographs[block_id], graph, preds_upd, high_threshold=high_threshold, - low_threshold=low_threshold, ) else: @@ -67,11 +64,7 @@ def get_accepted_proposals( preds = threshold_preds(preds, idx_to_edge, low_threshold) if structure_aware: return get_structure_aware_accepts( - neurograph, - graph, - preds, - high_threshold=high_threshold, - low_threshold=low_threshold, + neurograph, graph, preds, high_threshold=high_threshold ) else: return preds.keys() @@ -110,9 +103,7 @@ def threshold_preds(preds, idx_to_edge, 
threshold, valid_idxs=[]): return thresholded_preds -def get_structure_aware_accepts( - neurograph, graph, preds, high_threshold=0.9, low_threshold=0.6 -): +def get_structure_aware_accepts(neurograph, graph, preds, high_threshold=0.9): # Add best preds best_preds, best_probs = get_best_preds(neurograph, preds, high_threshold) accepts, graph = check_cycles_sequential(graph, best_preds, best_probs) @@ -162,49 +153,9 @@ def get_subgraphs(graph, edge): return False -def check_cycles_parallelized(graph, edge_list): - """ - Checks whether each edge in "edge_list" creates a cycle in "graph" in a - with a parallelized algorithm. - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched. - edge_list : list - List of edges to be checked. - - Returns - ------- - graph : networkx.Graph - Graph with each edge in "edge_list" added - fail : bool - Indication of whether a cycle was created due to parallelization. - - """ - # Assign processes - with ProcessPoolExecutor() as executor: - processes = [] - for edge in edge_list: - subgraph = get_subgraphs(graph, edge) - executor.submit(gutils.creates_cycle, subgraph, edge) - - # Store result - accepts = [] - for process in as_completed(processes): - created_cycle, edge = process.result() - if not created_cycle: - accepts.append(edge) - graph.add_edge_from([edge]) - - fail = True if gutils.cycle_exists(graph) else False - return accepts, fail - - def check_cycles_sequential(graph, edges, probs): accepts = [] for i in np.argsort(probs): - print(i, edges) subgraph = get_subgraphs(graph, edges[i]) if subgraph: created_cycle, _ = gutils.creates_cycle(subgraph, tuple(edges[i]))