From 5ee38adc398c8112bed97dcdedcaf5a38df9c156 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 5 Apr 2024 17:17:49 +0000 Subject: [PATCH 1/3] bug: groundtruth generation --- .../machine_learning/groundtruth_generation.py | 9 +++++---- src/deep_neurographs/neurograph.py | 4 ++-- src/deep_neurographs/reconstruction.py | 9 +++++---- src/deep_neurographs/visualization.py | 8 ++++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index 7b3de7b..92da7db 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -27,12 +27,13 @@ def init_targets(target_neurograph, pred_neurograph): valid_proposals = get_valid_proposals(target_neurograph, pred_neurograph) # Add best simple edges - dists = [pred_neurograph.proposal_length(p) for p in valid_proposals] - groundtruth_graph = pred_neurograph.copy_graph() + dists = [pred_neurograph.proposal_length(e) for e in valid_proposals] + graph = pred_neurograph.copy_graph() for idx in np.argsort(dists): edge = valid_proposals[idx] - if not gutils.creates_cycle(groundtruth_graph, tuple(edge)): - groundtruth_graph.add_edges_from([edge]) + created_cycle, _ = gutils.creates_cycle(graph, tuple(edge)) + if not created_cycle: + graph.add_edges_from([edge]) target_edges.add(edge) return target_edges diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index eb972f6..6e543e2 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -306,7 +306,7 @@ def generate_proposals( # Get connection (i, j) = self.xyz_to_edge[xyz] node, xyz = self.__get_connection(leaf, xyz, (i, j), radius) - if not complex_proposals and self.degree[node] >= 2: + if not complex_proposals and self.degree[node] > 1: continue # Check whether connection exists @@ -329,7 +329,7 @@ def generate_proposals( existing_connections[pair_id] = frozenset({leaf, node}) # Finish - #self.filter_nodes() + self.filter_nodes() self.init_kdtree(node_type="leaf") self.init_kdtree(node_type="proposal") if optimize: diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index 94a340f..a38b369 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -23,7 +23,6 @@ def get_accepted_propoals_blocks( neurographs, - graph, preds, blocks, block_to_idxs, @@ -44,6 +43,7 @@ def get_accepted_propoals_blocks( # Refine accepts wrt structure if structure_aware: + graph = neurographs[block_id].copy() accepts[block_id] = get_structure_aware_accepts( neurographs[block_id], graph, @@ -104,6 +104,7 @@ def threshold_preds(preds, idx_to_edge, threshold, valid_idxs=[]): predicted probability. """ + print(preds) thresholded_preds = dict() for i, pred_i in enumerate(preds): contained_bool = True if len(valid_idxs) == 0 else i in valid_idxs @@ -117,7 +118,7 @@ def get_structure_aware_accepts( ): # Add best preds best_preds, best_probs = get_best_preds(neurograph, preds, high_threshold) - accepts = check_cycles_sequential(neurograph, best_preds, best_probs) + accepts = check_cycles_sequential(graph, best_preds, best_probs) if len(best_preds) == len(preds.keys()): return accepts @@ -130,7 +131,7 @@ def get_structure_aware_accepts( good_preds.append(edge) good_probs.append(prob) - more_accepts = check_cycles_sequential(neurograph, good_preds, good_probs) + more_accepts = check_cycles_sequential(graph, good_preds, good_probs) accepts.extend(more_accepts) return accepts @@ -203,7 +204,7 @@ def check_cycles_sequential(graph, edges, probs): subgraph = get_subgraphs(graph, edges[i]) created_cycle, _ = gutils.creates_cycle(subgraph, tuple(edges[i])) if not created_cycle: - graph.add_edges_from(tuple(edges[i])) + graph.add_edges_from([tuple(edges[i])]) accepts.append(edges[i]) return accepts diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index 36e18c0..a674fe8 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -140,10 +140,10 @@ def plot_edges(graph, edges, color=None, line_width=3.5): def plot(data, title): fig = go.Figure(data=data) fig.update_layout( - title=title, - template="plotly_white", - plot_bgcolor="rgba(0, 0, 0, 0)", - scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), + # title=title, + # #template="plotly_white", + # #plot_bgcolor="rgba(0, 0, 0, 0)", + # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), width=1200, height=700, ) From d9c5c8dff9870c0075b768f06dfb780db1baa649 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 7 Apr 2024 21:21:16 +0000 Subject: [PATCH 2/3] documentation --- src/deep_neurographs/delete_merges_gt.py | 86 ++++++++++++++++++- src/deep_neurographs/densegraph.py | 43 ++++++++++ src/deep_neurographs/geometry.py | 2 +- src/deep_neurographs/intake.py | 1 - .../machine_learning/feature_generation.py | 3 +- .../machine_learning/inference.py | 10 +-- src/deep_neurographs/reconstruction.py | 7 +- src/deep_neurographs/swc_utils.py | 9 +- src/deep_neurographs/utils.py | 74 +++++++++------- src/deep_neurographs/visualization.py | 8 +- 10 files changed, 183 insertions(+), 60 deletions(-) diff --git a/src/deep_neurographs/delete_merges_gt.py b/src/deep_neurographs/delete_merges_gt.py index 033da6e..967c7d2 100644 --- a/src/deep_neurographs/delete_merges_gt.py +++ b/src/deep_neurographs/delete_merges_gt.py @@ -111,13 +111,14 @@ def detect_merges_neuron( radius : int Each node within "radius" is deleted. output_dir : str, optional - ... + Directory that merge sites are saved in swc files. The default is + None. save : bool, optional Indication of whether to save merge sites. The default is False. Returns ------- - set + delete_nodes : set Nodes that are part of a merge mistake. """ @@ -171,6 +172,34 @@ def detect_intersections(target_densegraph, graph, component): def detect_merges(target_densegraph, graph, hits, radius, output_dir, save): + """ + Detects merge mistakes in "graph" (i.e. whether "graph" is closely aligned + with two distinct connected components in "target_densegraph". + + Parameters + ---------- + target_densegraph : DenseGraph + Graph built from ground truth swc files. + graph : networkx.Graph + Graph build from a predicted swc file. + hits : dict + Dictionary that stores intersections between "target_densegraph" and + "graph", where the keys are swc ids from "target_densegraph" and + values are nodes from "graph". + radius : int + Each node within "radius" is deleted. + output_dir : str, optional + Directory that merge sites are saved in swc files. The default is + None. + save : bool, optional + Indication of whether to save merge sites. + + Returns + ------- + merge_sites : set + Nodes that are part of a merge site. + + """ merge_sites = set() if len(hits.keys()) > 0: visited = set() @@ -184,7 +213,6 @@ def detect_merges(target_densegraph, graph, hits, radius, output_dir, save): # Check for merge site min_dist, sites = locate_site(graph, hits[id_1], hits[id_2]) visited.add(pair) - print(graph.nodes[sites[0]]["xyz"], min_dist) if min_dist < MERGE_DIST_THRESHOLD: merge_nbhd = get_merged_nodes(graph, sites, radius) merge_sites = merge_sites.union(merge_nbhd) @@ -231,6 +259,24 @@ def locate_site(graph, merged_1, merged_2): def get_merged_nodes(graph, sites, radius): + """ + Gets nodes that are falsely merged. + + Parameters + ---------- + graph : networkx.Graph + Graph that contains a merge at "sites". + sites : list + Nodes in "graph" that are part of a merge mistake. + radius : int + Radius about node to be searched. + + Returns + ------- + merged_nodes : set + Nodes that are falsely merged. + + """ i, j = tuple(sites) merged_nodes = set(nx.shortest_path(graph, source=i, target=j)) merged_nodes = merged_nodes.union(get_nbhd(graph, i, radius)) @@ -239,10 +285,44 @@ def get_merged_nodes(graph, sites, radius): def get_nbhd(graph, i, radius): + """ + Gets all nodes within a path length of "radius" from node "i". + + Parameters + ---------- + graph : networkx.Graph + Graph to searched. + i : node + Node that is root of neighborhood to be returned. + radius : int + Radius about node to be searched. + + Returns + ------- + set + Nodes within a path length of "radius" from node "i". + + """ return set(nx.dfs_tree(graph, source=i, depth_limit=radius)) def get_point(graph, sites): + """ + Gets midpoint of merge site defined by the pair contained in "sites". + + Parameters + ---------- + graph : networkx.Graph + Graph that contains a merge at "sites". + sites : list + Nodes in "graph" that are part of a merge mistake. + + Returns + ------- + numpy.ndarray + Midpoint between pair of xyz coordinates in "sites". + + """ xyz_0 = graph.nodes[sites[0]]["xyz"] xyz_1 = graph.nodes[sites[1]]["xyz"] return geometry.get_midpoint(xyz_0, xyz_1) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index 3542f2a..6260799 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -113,10 +113,37 @@ def init_kdtree(self): self.kdtree = KDTree(list(self.xyz_to_swc.keys())) def get_projection(self, xyz): + """ + Projects "xyz" onto "self by finding the closest point. + + Parameters + ---------- + xyz : numpy.ndarray + xyz coordinate to be queried. + + Returns + ------- + numpy.ndarray + Projection of "xyz". + + """ _, idx = self.kdtree.query(xyz, k=1) return tuple(self.kdtree.data[idx]) def save(self, output_dir): + """ + Saves "self" to an swc file. + + Parameters + ---------- + output_dir : str + Path to directory that swc files are written to. + + Returns + ------- + None + + """ for swc_id, graph in self.graphs.items(): cnt = 0 for component in nx.connected_components(graph): @@ -128,6 +155,22 @@ def save(self, output_dir): swc_utils.write(path, entry_list) def make_entries(self, graph, component): + """ + Makes swc entries corresponding to nodes in "component". + + Parameters + ---------- + graph : networkx.Graph + Graph that "component" is a connected component of. + component : set + Connected component of "graph". + + Returns + ------- + entry_list + List of swc entries generated from nodes in "component". + + """ node_to_idx = dict() entry_list = [] for i, j in nx.dfs_edges(graph.subgraph(component)): diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 2ba17b6..79e3dfe 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -178,7 +178,7 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): def fill_path(img, path, val=-1): for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) - img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val + img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val return img diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index f9860f3..15171f6 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -8,7 +8,6 @@ """ -import os from concurrent.futures import ProcessPoolExecutor, as_completed from time import time diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index c2b3d2a..d78637f 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -21,7 +21,6 @@ import numpy as np import tensorstore as ts -from time import time from deep_neurographs import geometry, utils @@ -68,7 +67,7 @@ def run(neurograph, model_type, img_path, labels_path=None, proposals=None): """ # Initializations - img_driver = driver = "n5" if ".n5" in img_path else "zarr" + img_driver = "n5" if ".n5" in img_path else "zarr" img = utils.open_tensorstore(img_path, img_driver) if labels_path: labels_driver = "neuroglancer_precomputed" diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index eab13de..10077fb 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -9,21 +9,19 @@ """ from copy import deepcopy -from random import sample +from time import time import fastremap import networkx as nx import numpy as np import torch -from time import time from torch.nn.functional import sigmoid from torch.utils.data import DataLoader from deep_neurographs import graph_utils as gutils from deep_neurographs import reconstruction as build from deep_neurographs import utils -from deep_neurographs.machine_learning import feature_generation -from deep_neurographs.machine_learning import ml_utils +from deep_neurographs.machine_learning import feature_generation, ml_utils from deep_neurographs.neurograph import NeuroGraph BATCH_SIZE_PROPOSALS = 1000 @@ -114,7 +112,6 @@ def run_without_seeds( chunk_size = max(int(n_batches * 0.02), 1) for i, batch in enumerate(batches): # Prediction - t2 = time() proposals_i = [proposals[j] for j in batch] accepts_i = predict( neurograph, @@ -128,7 +125,6 @@ def run_without_seeds( ) # Merge proposals - t2 = time() neurograph = build.fuse_branches(neurograph, accepts_i) accepts.extend(accepts_i) @@ -153,7 +149,6 @@ def predict( confidence_threshold=0.7, ): # Generate features - t3 = time() features = feature_generation.run( neurograph, model_type, @@ -164,7 +159,6 @@ def predict( dataset = ml_utils.init_dataset(neurograph, features, model_type) # Run model - t3 = time() proposal_probs = run_model(dataset, model, model_type) accepts = build.get_accepted_proposals( neurograph, diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index a38b369..4fdcd5e 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -8,11 +8,10 @@ """ -import networkx as nx import os from concurrent.futures import ProcessPoolExecutor, as_completed -from random import sample +import networkx as nx import numpy as np from deep_neurographs import graph_utils as gutils @@ -64,7 +63,7 @@ def get_accepted_proposals( high_threshold=0.9, low_threshold=0.6, structure_aware=True, -): +): # Get positive edge predictions preds = threshold_preds(preds, idx_to_edge, low_threshold) if structure_aware: @@ -132,7 +131,7 @@ def get_structure_aware_accepts( good_probs.append(prob) more_accepts = check_cycles_sequential(graph, good_preds, good_probs) - accepts.extend(more_accepts) + accepts.extend(more_accepts) return accepts diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 962cf8a..86bf91d 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -341,13 +341,6 @@ def save_edge(path, xyz_1, xyz_2, color=None, radius=6): f.write(make_simple_entry(2, 1, xyz_2, radius=radius)) -def set_radius(graph, i): - try: - return graph[i]["radius"] - except: - return 2 - - def make_entry(graph, i, parent, node_to_idx): """ Makes an entry to be written in an swc file. @@ -368,7 +361,7 @@ def make_entry(graph, i, parent, node_to_idx): ... """ - r = set_radius(graph, i) + r = graph[i]["radius"] x, y, z = tuple(graph.nodes[i]["xyz"]) node_to_idx[i] = len(node_to_idx) + 1 entry = f"{node_to_idx[i]} 2 {x} {y} {z} {r} {node_to_idx[parent]}" diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index f8a3b13..e9462ca 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -12,7 +12,7 @@ import math import os import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from io import BytesIO from random import sample @@ -392,7 +392,7 @@ def read_tensorstore_bbox(img, bbox): start = bbox["min"] end = bbox["max"] return ( - img[start[0] : end[0], start[1] : end[1], start[2] : end[2]] + img[start[0]: end[0], start[1]: end[1], start[2]: end[2]] .read() .result() ) @@ -401,7 +401,7 @@ def read_tensorstore_bbox(img, bbox): def get_chunk(arr, xyz, shape, from_center=True): start, end = get_start_end(xyz, shape, from_center=from_center) return deepcopy( - arr[start[0] : end[0], start[1] : end[1], start[2] : end[2]] + arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]] ) @@ -580,6 +580,47 @@ def sample_singleton(my_container): return sample(my_container, 1)[0] +# --- runtime --- +def init_timers(): + """ + Initializes two timers. + + Parameters + ---------- + None + + Returns + ------- + time.time + Timer. + time.time + Timer. + + """ + return time(), time() + + +def progress_bar(current, total, bar_length=50, eta=None, runtime=None): + progress = int(current / total * bar_length) + n_completed = f"Completed: {current}/{total}" + bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]" + eta = f"Time Remaining: {eta}" if eta else "" + runtime = f"Estimated Total Runtime: {runtime}" if runtime else "" + print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True) + + +def time_writer(t, unit="seconds"): + assert unit in ["seconds", "minutes", "hours"] + upd_unit = {"seconds": "minutes", "minutes": "hours"} + if t < 60 or unit == "hours": + return t, unit + else: + t /= 60 + unit = upd_unit[unit] + t, unit = time_writer(t, unit=unit) + return t, unit + + # --- miscellaneous --- def get_img_bbox(origin, shape): """ @@ -638,31 +679,6 @@ def get_memory_usage(): return psutil.virtual_memory().used / 1e9 -def init_timers(): - return time(), time() - - -def progress_bar(current, total, bar_length=50, eta=None, runtime=None): - progress = int(current / total * bar_length) - n_completed = f"Completed: {current}/{total}" - bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]" - eta = f"Time Remaining: {eta}" if eta else "" - runtime = f"Estimated Total Runtime: {runtime}" if runtime else "" - print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True) - - -def time_writer(t, unit="seconds"): - assert unit in ["seconds", "minutes", "hours"] - upd_unit = {"seconds": "minutes", "minutes": "hours"} - if t < 60 or unit == "hours": - return t, unit - else: - t /= 60 - unit = upd_unit[unit] - t, unit = time_writer(t, unit=unit) - return t, unit - - def get_batches(iterable, batch_size): for start in range(0, len(iterable), batch_size): - yield iterable[start : min(start + batch_size, len(iterable))] + yield iterable[start: min(start + batch_size, len(iterable))] diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index a674fe8..2cc2394 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -140,10 +140,10 @@ def plot_edges(graph, edges, color=None, line_width=3.5): def plot(data, title): fig = go.Figure(data=data) fig.update_layout( - # title=title, - # #template="plotly_white", - # #plot_bgcolor="rgba(0, 0, 0, 0)", - # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), + # title=title, + # #template="plotly_white", + # #plot_bgcolor="rgba(0, 0, 0, 0)", + # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), width=1200, height=700, ) From b4109407f993a0f66d29405fce942d9a9f378cc4 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 14 Apr 2024 04:00:26 +0000 Subject: [PATCH 3/3] upds: reconstruction bugs, documentation --- src/deep_neurographs/geometry.py | 72 ++++++++++++++----- .../machine_learning/graph_datasets.py | 2 + .../machine_learning/graph_models.py | 34 +++++++-- .../machine_learning/graph_trainer.py | 28 +++++--- .../machine_learning/inference.py | 3 +- .../machine_learning/ml_utils.py | 2 +- .../machine_learning/trainer.py | 4 +- src/deep_neurographs/reconstruction.py | 25 ++++--- src/deep_neurographs/utils.py | 9 ++- 9 files changed, 128 insertions(+), 51 deletions(-) diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 79e3dfe..69cb788 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -20,21 +20,38 @@ def get_directional(neurograph, i, origin, window_size): branches = neurograph.get_branches(i) branches = translate_branches(branches, origin) if len(branches) == 1: - return compute_tangent(get_sub_branch(branches[0], window_size)) + return compute_tangent(get_subarray(branches[0], window_size)) elif len(branches) == 2: - branch_1 = get_sub_branch(branches[0], window_size) - branch_2 = get_sub_branch(branches[1], window_size) + branch_1 = get_subarray(branches[0], window_size) + branch_2 = get_subarray(branches[1], window_size) branch = np.concatenate((branch_1, branch_2)) return compute_tangent(branch) else: return np.array([0, 0, 0]) -def get_sub_branch(branch, window_size): - if branch.shape[0] < window_size: - return branch +def get_subarray(arr, window_size): + """ + Extracts a sub-array of a specified window size from a given input array. + + Parameters + ---------- + branch : numpy.ndarray + Array from which the sub-branch will be extracted. + window_size : int + Size of the window to extract from "arr". + + Returns + ------- + numpy.ndarray + A sub-array of the specified window size. If the input array is + smaller than the window size, the entire branch array is returned. + + """ + if arr.shape[0] < window_size: + return arr else: - return branch[0:window_size, :] + return arr[0:window_size, :] def compute_svd(xyz): @@ -65,6 +82,22 @@ def compute_svd(xyz): def compute_tangent(xyz): + """ + Computes the tangent vector at a given point or along a curve defined by + an array of points. + + Parameters + ---------- + xyz : numpy.ndarray + Array containing either two xyz coordinates or an arbitrary number of + defining a curve. + + Returns + ------- + numpy.ndarray + Tangent vector at the specified point or along the curve. + + """ if xyz.shape[0] == 2: tangent = (xyz[1] - xyz[0]) / dist(xyz[1], xyz[0]) else: @@ -74,6 +107,21 @@ def compute_tangent(xyz): def compute_normal(xyz): + """ + Computes the normal vector of a plane defined by an array of xyz + coordinates using Singular Value Decomposition (SVD). + + Parameters + ---------- + xyz : numpy.ndarray + An array of xyz coordinates that normal vector is to be computed of. + + Returns + ------- + numpy.ndarray + The normal vector of the array "xyz". + + """ U, S, VT = compute_svd(xyz) normal = VT[-1] return normal / np.linalg.norm(normal) @@ -150,16 +198,6 @@ def fit_spline(xyz, s=None): return spline_x, spline_y, spline_z -def sample_path(path, n_points): - if len(path) > 5: - t = np.linspace(0, 1, n_points) - spline_x, spline_y, spline_z = fit_spline(path) - path = np.column_stack((spline_x(t), spline_y(t), spline_z(t))) - else: - path = make_line(path[0], path[-1], 10) - return path.astype(int) - - # Image feature extraction def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): profile = [] diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py index 2cbc4a3..4aaf4c0 100644 --- a/src/deep_neurographs/machine_learning/graph_datasets.py +++ b/src/deep_neurographs/machine_learning/graph_datasets.py @@ -78,6 +78,7 @@ class GraphDataset: Custom dataset for homogenous graphs. """ + def __init__( self, neurograph, @@ -106,6 +107,7 @@ class HeteroGraphDataset: Custom dataset for heterogenous graphs. """ + def __init__( self, neurograph, diff --git a/src/deep_neurographs/machine_learning/graph_models.py b/src/deep_neurographs/machine_learning/graph_models.py index 8194e57..ca7eefe 100644 --- a/src/deep_neurographs/machine_learning/graph_models.py +++ b/src/deep_neurographs/machine_learning/graph_models.py @@ -11,14 +11,14 @@ import torch import torch.nn.functional as F from torch.nn import ELU, Linear -from torch_geometric.nn import GCNConv +from torch_geometric.nn import GATConv, GCNConv class GCN(torch.nn.Module): def __init__(self, input_channels): super().__init__() - self.conv1 = GCNConv(input_channels, input_channels // 2) - self.conv2 = GCNConv(input_channels // 2, input_channels // 2) + self.conv1 = GCNConv(input_channels, input_channels) + self.conv2 = GCNConv(input_channels, input_channels // 2) self.conv3 = GCNConv(input_channels // 2, 1) self.ELU = ELU() @@ -38,11 +38,35 @@ def forward(self, x, edge_index): return x +class GAT(torch.nn.Module): + def __init__(self, input_channels): + super().__init__() + self.conv1 = GATConv(input_channels, input_channels) + self.conv2 = GATConv(input_channels, input_channels // 2) + self.conv3 = GATConv(input_channels // 2, 1) + self.ELU = ELU() + + def forward(self, x, edge_index): + # Layer 1 + x = self.conv1(x, edge_index) + # x = self.ELU(x) + # x = F.dropout(x, p=0.25) + + # Layer 2 + # x = self.conv2(x, edge_index) + # x = self.ELU(x) + # x = F.dropout(x, p=0.25) + + # Layer 3 + # x = self.conv3(x, edge_index) + return x + + class MLP(torch.nn.Module): def __init__(self, input_channels): super().__init__() - self.linear1 = Linear(input_channels, input_channels // 2) - self.linear2 = Linear(input_channels // 2, input_channels // 2) + self.linear1 = Linear(input_channels, input_channels) + self.linear2 = Linear(input_channels, input_channels // 2) self.linear3 = Linear(input_channels // 2, 1) self.ELU = ELU() diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py index 0eba638..9e556b5 100644 --- a/src/deep_neurographs/machine_learning/graph_trainer.py +++ b/src/deep_neurographs/machine_learning/graph_trainer.py @@ -8,15 +8,20 @@ """ +from copy import deepcopy from random import sample, shuffle + import numpy as np import torch -from copy import deepcopy -from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score -from torch.nn.functional import sigmoid +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, +) from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.machine_learning import ml_utils +from deep_neurographs.machine_learning import ml_utils LR = 1e-3 N_EPOCHS = 1000 @@ -29,6 +34,7 @@ class GraphTrainer: Custom class that trains graph neural networks. """ + def __init__( self, model, @@ -98,7 +104,7 @@ def run_on_graphs(self, graph_datasets): y_i, hat_y_i = self.train(graph_datasets[graph_id].data, epoch) y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) - train_score = self.compute_metrics(y, hat_y, "train", epoch) + self.compute_metrics(y, hat_y, "train", epoch) # Test if epoch % 10 == 0: @@ -236,11 +242,11 @@ def compute_metrics(self, y, hat_y, prefix, epoch): f1 = f1_score(y, hat_y) # Log - self.writer.add_scalar(prefix + '_accuracy:', accuracy, epoch) - self.writer.add_scalar(prefix + '_accuracy_df:', accuracy_dif, epoch) - self.writer.add_scalar(prefix + '_precision:', precision, epoch) - self.writer.add_scalar(prefix + '_recall:', recall, epoch) - self.writer.add_scalar(prefix + '_f1:', f1, epoch) + self.writer.add_scalar(prefix + "_accuracy:", accuracy, epoch) + self.writer.add_scalar(prefix + "_accuracy_df:", accuracy_dif, epoch) + self.writer.add_scalar(prefix + "_precision:", precision, epoch) + self.writer.add_scalar(prefix + "_recall:", recall, epoch) + self.writer.add_scalar(prefix + "_f1:", f1, epoch) return f1 @@ -349,7 +355,7 @@ def truncate(hat_y, y): Truncated "hat_y". """ - return hat_y[0: y.size(0), 0] + return hat_y[: y.size(0), 0] def get_predictions(hat_y, threshold=0.5): diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 7a639a4..e747c95 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -9,7 +9,6 @@ """ from copy import deepcopy -from time import time import fastremap import networkx as nx @@ -268,7 +267,7 @@ def run_model(dataset, model, model_type): # Postprocess hat_y_i = np.array(hat_y_i) - hat_y.extend(hat_y_i.tolist()) + hat_y.extend(hat_y_i[:, 0].tolist()) else: data = dataset["dataset"] hat_y = model.predict_proba(data["inputs"])[:, 1] diff --git a/src/deep_neurographs/machine_learning/ml_utils.py b/src/deep_neurographs/machine_learning/ml_utils.py index 89539bd..7c2f1bc 100644 --- a/src/deep_neurographs/machine_learning/ml_utils.py +++ b/src/deep_neurographs/machine_learning/ml_utils.py @@ -195,4 +195,4 @@ def sigmoid(x): Sigmoid applied to "x". """ - return 1.0/(1.0 + np.exp(-x)) + return 1.0 / (1.0 + np.exp(-x)) diff --git a/src/deep_neurographs/machine_learning/trainer.py b/src/deep_neurographs/machine_learning/trainer.py index d1c179a..8e02e27 100644 --- a/src/deep_neurographs/machine_learning/trainer.py +++ b/src/deep_neurographs/machine_learning/trainer.py @@ -87,9 +87,9 @@ def fit_deep_model( # Fit model pylightning_trainer = pl.Trainer( - accelerator="gpu", + # accelerator="gpu", callbacks=[ckpt_callback], - devices=1, + # devices=1, enable_model_summary=False, enable_progress_bar=False, logger=logger, diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index 1bd668b..2e4eaa5 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -30,23 +30,26 @@ def get_accepted_propoals_blocks( ): accepts = dict() for block_id in blocks: + # Threshold prediction + preds_upd = threshold_preds( + preds, + idx_to_edge, + low_threshold, + valid_idxs=block_to_idxs[block_id], + ) + # Get accepts if structure_aware: graph = neurographs[block_id].copy() - accepts[block_id] = get_structure_aware_accepts( + accepts[block_id], _ = get_structure_aware_accepts( neurographs[block_id], graph, - preds, + preds_upd, high_threshold=high_threshold, low_threshold=low_threshold, ) else: - preds = threshold_preds( - preds, - idx_to_edge, - low_threshold, - valid_idxs=block_to_idxs[block_id], - ) + accepts[block_id] = preds.keys() return accepts @@ -114,7 +117,7 @@ def get_structure_aware_accepts( best_preds, best_probs = get_best_preds(neurograph, preds, high_threshold) accepts, graph = check_cycles_sequential(graph, best_preds, best_probs) if len(best_preds) == len(preds.keys()): - return accepts + return accepts, graph # Add remaining preds best_preds = set(best_preds) @@ -156,7 +159,7 @@ def get_subgraphs(graph, edge): subgraph = nx.union(subgraph_1, subgraph_2) return subgraph except: - return False + return False def check_cycles_parallelized(graph, edge_list): @@ -201,6 +204,7 @@ def check_cycles_parallelized(graph, edge_list): def check_cycles_sequential(graph, edges, probs): accepts = [] for i in np.argsort(probs): + print(i, edges) subgraph = get_subgraphs(graph, edges[i]) if subgraph: created_cycle, _ = gutils.creates_cycle(subgraph, tuple(edges[i])) @@ -241,6 +245,7 @@ def save_prediction(neurograph, accepted_proposals, output_dir): save_corrections(neurograph, accepted_proposals, corrections_dir) save_connections(neurograph, accepted_proposals, connections_path) + def save_corrections(neurograph, accepted_proposals, output_dir): for cnt, (i, j) in enumerate(accepted_proposals): # Info diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 5b26bbe..2440bef 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -22,6 +22,7 @@ import numpy as np import psutil import tensorstore as ts +import torch import zarr from skimage.color import label2rgb @@ -620,10 +621,11 @@ def time_writer(t, unit="seconds"): t, unit = time_writer(t, unit=unit) return t, unit + def report_progress(current, total, chunk_size, cnt, t0, t1): eta = get_eta(current, total, chunk_size, t1) runtime = get_runtime(current, total, chunk_size, t0, t1) - utils.progress_bar(current, total, eta=eta, runtime=runtime) + progress_bar(current, total, eta=eta, runtime=runtime) return cnt + 1, time() @@ -631,16 +633,17 @@ def get_eta(current, total, chunk_size, t0, return_str=True): chunk_runtime = time() - t0 remaining = total - current eta = remaining * (chunk_runtime / chunk_size) - t, unit = utils.time_writer(eta) + t, unit = time_writer(eta) return f"{round(t, 4)} {unit}" if return_str else eta def get_runtime(current, total, chunk_size, t0, t1): eta = get_eta(current, total, chunk_size, t1, return_str=False) total_runtime = time() - t0 + eta - t, unit = utils.time_writer(total_runtime) + t, unit = time_writer(total_runtime) return f"{round(t, 4)} {unit}" + def toGPU(graph_data): x = graph_data.x.to("cuda:0", dtype=torch.float32) edge_index = graph_data.edge_index.to("cuda:0")