From 5ee38adc398c8112bed97dcdedcaf5a38df9c156 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 5 Apr 2024 17:17:49 +0000 Subject: [PATCH 1/2] bug: groundtruth generation --- .../machine_learning/groundtruth_generation.py | 9 +++++---- src/deep_neurographs/neurograph.py | 4 ++-- src/deep_neurographs/reconstruction.py | 9 +++++---- src/deep_neurographs/visualization.py | 8 ++++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index 7b3de7b..92da7db 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -27,12 +27,13 @@ def init_targets(target_neurograph, pred_neurograph): valid_proposals = get_valid_proposals(target_neurograph, pred_neurograph) # Add best simple edges - dists = [pred_neurograph.proposal_length(p) for p in valid_proposals] - groundtruth_graph = pred_neurograph.copy_graph() + dists = [pred_neurograph.proposal_length(e) for e in valid_proposals] + graph = pred_neurograph.copy_graph() for idx in np.argsort(dists): edge = valid_proposals[idx] - if not gutils.creates_cycle(groundtruth_graph, tuple(edge)): - groundtruth_graph.add_edges_from([edge]) + created_cycle, _ = gutils.creates_cycle(graph, tuple(edge)) + if not created_cycle: + graph.add_edges_from([edge]) target_edges.add(edge) return target_edges diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index eb972f6..6e543e2 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -306,7 +306,7 @@ def generate_proposals( # Get connection (i, j) = self.xyz_to_edge[xyz] node, xyz = self.__get_connection(leaf, xyz, (i, j), radius) - if not complex_proposals and self.degree[node] >= 2: + if not complex_proposals and self.degree[node] > 1: continue # Check whether connection exists @@ -329,7 +329,7 @@ def generate_proposals( existing_connections[pair_id] = frozenset({leaf, node}) # Finish - #self.filter_nodes() + self.filter_nodes() self.init_kdtree(node_type="leaf") self.init_kdtree(node_type="proposal") if optimize: diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index 94a340f..a38b369 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -23,7 +23,6 @@ def get_accepted_propoals_blocks( neurographs, - graph, preds, blocks, block_to_idxs, @@ -44,6 +43,7 @@ def get_accepted_propoals_blocks( # Refine accepts wrt structure if structure_aware: + graph = neurographs[block_id].copy() accepts[block_id] = get_structure_aware_accepts( neurographs[block_id], graph, @@ -104,6 +104,7 @@ def threshold_preds(preds, idx_to_edge, threshold, valid_idxs=[]): predicted probability. 
""" + print(preds) thresholded_preds = dict() for i, pred_i in enumerate(preds): contained_bool = True if len(valid_idxs) == 0 else i in valid_idxs @@ -117,7 +118,7 @@ def get_structure_aware_accepts( ): # Add best preds best_preds, best_probs = get_best_preds(neurograph, preds, high_threshold) - accepts = check_cycles_sequential(neurograph, best_preds, best_probs) + accepts = check_cycles_sequential(graph, best_preds, best_probs) if len(best_preds) == len(preds.keys()): return accepts @@ -130,7 +131,7 @@ def get_structure_aware_accepts( good_preds.append(edge) good_probs.append(prob) - more_accepts = check_cycles_sequential(neurograph, good_preds, good_probs) + more_accepts = check_cycles_sequential(graph, good_preds, good_probs) accepts.extend(more_accepts) return accepts @@ -203,7 +204,7 @@ def check_cycles_sequential(graph, edges, probs): subgraph = get_subgraphs(graph, edges[i]) created_cycle, _ = gutils.creates_cycle(subgraph, tuple(edges[i])) if not created_cycle: - graph.add_edges_from(tuple(edges[i])) + graph.add_edges_from([tuple(edges[i])]) accepts.append(edges[i]) return accepts diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index 36e18c0..a674fe8 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -140,10 +140,10 @@ def plot_edges(graph, edges, color=None, line_width=3.5): def plot(data, title): fig = go.Figure(data=data) fig.update_layout( - title=title, - template="plotly_white", - plot_bgcolor="rgba(0, 0, 0, 0)", - scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), + # title=title, + # #template="plotly_white", + # #plot_bgcolor="rgba(0, 0, 0, 0)", + # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), width=1200, height=700, ) From d9c5c8dff9870c0075b768f06dfb780db1baa649 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 7 Apr 2024 21:21:16 +0000 Subject: [PATCH 2/2] documentation --- src/deep_neurographs/delete_merges_gt.py | 86 ++++++++++++++++++- src/deep_neurographs/densegraph.py | 43 ++++++++++ src/deep_neurographs/geometry.py | 2 +- src/deep_neurographs/intake.py | 1 - .../machine_learning/feature_generation.py | 3 +- .../machine_learning/inference.py | 10 +-- src/deep_neurographs/reconstruction.py | 7 +- src/deep_neurographs/swc_utils.py | 9 +- src/deep_neurographs/utils.py | 74 +++++++++------- src/deep_neurographs/visualization.py | 8 +- 10 files changed, 183 insertions(+), 60 deletions(-) diff --git a/src/deep_neurographs/delete_merges_gt.py b/src/deep_neurographs/delete_merges_gt.py index 033da6e..967c7d2 100644 --- a/src/deep_neurographs/delete_merges_gt.py +++ b/src/deep_neurographs/delete_merges_gt.py @@ -111,13 +111,14 @@ def detect_merges_neuron( radius : int Each node within "radius" is deleted. output_dir : str, optional - ... + Directory that merge sites are saved in swc files. The default is + None. save : bool, optional Indication of whether to save merge sites. The default is False. Returns ------- - set + delete_nodes : set Nodes that are part of a merge mistake. """ @@ -171,6 +172,34 @@ def detect_intersections(target_densegraph, graph, component): def detect_merges(target_densegraph, graph, hits, radius, output_dir, save): + """ + Detects merge mistakes in "graph" (i.e. whether "graph" is closely aligned + with two distinct connected components in "target_densegraph". + + Parameters + ---------- + target_densegraph : DenseGraph + Graph built from ground truth swc files. 
+ graph : networkx.Graph + Graph build from a predicted swc file. + hits : dict + Dictionary that stores intersections between "target_densegraph" and + "graph", where the keys are swc ids from "target_densegraph" and + values are nodes from "graph". + radius : int + Each node within "radius" is deleted. + output_dir : str, optional + Directory that merge sites are saved in swc files. The default is + None. + save : bool, optional + Indication of whether to save merge sites. + + Returns + ------- + merge_sites : set + Nodes that are part of a merge site. + + """ merge_sites = set() if len(hits.keys()) > 0: visited = set() @@ -184,7 +213,6 @@ def detect_merges(target_densegraph, graph, hits, radius, output_dir, save): # Check for merge site min_dist, sites = locate_site(graph, hits[id_1], hits[id_2]) visited.add(pair) - print(graph.nodes[sites[0]]["xyz"], min_dist) if min_dist < MERGE_DIST_THRESHOLD: merge_nbhd = get_merged_nodes(graph, sites, radius) merge_sites = merge_sites.union(merge_nbhd) @@ -231,6 +259,24 @@ def locate_site(graph, merged_1, merged_2): def get_merged_nodes(graph, sites, radius): + """ + Gets nodes that are falsely merged. + + Parameters + ---------- + graph : networkx.Graph + Graph that contains a merge at "sites". + sites : list + Nodes in "graph" that are part of a merge mistake. + radius : int + Radius about node to be searched. + + Returns + ------- + merged_nodes : set + Nodes that are falsely merged. + + """ i, j = tuple(sites) merged_nodes = set(nx.shortest_path(graph, source=i, target=j)) merged_nodes = merged_nodes.union(get_nbhd(graph, i, radius)) @@ -239,10 +285,44 @@ def get_merged_nodes(graph, sites, radius): def get_nbhd(graph, i, radius): + """ + Gets all nodes within a path length of "radius" from node "i". + + Parameters + ---------- + graph : networkx.Graph + Graph to searched. + i : node + Node that is root of neighborhood to be returned. + radius : int + Radius about node to be searched. + + Returns + ------- + set + Nodes within a path length of "radius" from node "i". + + """ return set(nx.dfs_tree(graph, source=i, depth_limit=radius)) def get_point(graph, sites): + """ + Gets midpoint of merge site defined by the pair contained in "sites". + + Parameters + ---------- + graph : networkx.Graph + Graph that contains a merge at "sites". + sites : list + Nodes in "graph" that are part of a merge mistake. + + Returns + ------- + numpy.ndarray + Midpoint between pair of xyz coordinates in "sites". + + """ xyz_0 = graph.nodes[sites[0]]["xyz"] xyz_1 = graph.nodes[sites[1]]["xyz"] return geometry.get_midpoint(xyz_0, xyz_1) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index 3542f2a..6260799 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -113,10 +113,37 @@ def init_kdtree(self): self.kdtree = KDTree(list(self.xyz_to_swc.keys())) def get_projection(self, xyz): + """ + Projects "xyz" onto "self by finding the closest point. + + Parameters + ---------- + xyz : numpy.ndarray + xyz coordinate to be queried. + + Returns + ------- + numpy.ndarray + Projection of "xyz". + + """ _, idx = self.kdtree.query(xyz, k=1) return tuple(self.kdtree.data[idx]) def save(self, output_dir): + """ + Saves "self" to an swc file. + + Parameters + ---------- + output_dir : str + Path to directory that swc files are written to. 
+ + Returns + ------- + None + + """ for swc_id, graph in self.graphs.items(): cnt = 0 for component in nx.connected_components(graph): @@ -128,6 +155,22 @@ def save(self, output_dir): swc_utils.write(path, entry_list) def make_entries(self, graph, component): + """ + Makes swc entries corresponding to nodes in "component". + + Parameters + ---------- + graph : networkx.Graph + Graph that "component" is a connected component of. + component : set + Connected component of "graph". + + Returns + ------- + entry_list + List of swc entries generated from nodes in "component". + + """ node_to_idx = dict() entry_list = [] for i, j in nx.dfs_edges(graph.subgraph(component)): diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 2ba17b6..79e3dfe 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -178,7 +178,7 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): def fill_path(img, path, val=-1): for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) - img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val + img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val return img diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index f9860f3..15171f6 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -8,7 +8,6 @@ """ -import os from concurrent.futures import ProcessPoolExecutor, as_completed from time import time diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index c2b3d2a..d78637f 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -21,7 +21,6 @@ import numpy as np import tensorstore as ts -from time import time from deep_neurographs import geometry, utils @@ -68,7 +67,7 @@ def run(neurograph, model_type, img_path, labels_path=None, proposals=None): """ # Initializations - img_driver = driver = "n5" if ".n5" in img_path else "zarr" + img_driver = "n5" if ".n5" in img_path else "zarr" img = utils.open_tensorstore(img_path, img_driver) if labels_path: labels_driver = "neuroglancer_precomputed" diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index eab13de..10077fb 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -9,21 +9,19 @@ """ from copy import deepcopy -from random import sample +from time import time import fastremap import networkx as nx import numpy as np import torch -from time import time from torch.nn.functional import sigmoid from torch.utils.data import DataLoader from deep_neurographs import graph_utils as gutils from deep_neurographs import reconstruction as build from deep_neurographs import utils -from deep_neurographs.machine_learning import feature_generation -from deep_neurographs.machine_learning import ml_utils +from deep_neurographs.machine_learning import feature_generation, ml_utils from deep_neurographs.neurograph import NeuroGraph BATCH_SIZE_PROPOSALS = 1000 @@ -114,7 +112,6 @@ def run_without_seeds( chunk_size = max(int(n_batches * 0.02), 1) for i, batch in enumerate(batches): # Prediction - t2 = time() proposals_i = [proposals[j] for j in batch] accepts_i = predict( neurograph, @@ -128,7 +125,6 @@ def run_without_seeds( ) # Merge proposals - t2 = time() neurograph = build.fuse_branches(neurograph, accepts_i) 
accepts.extend(accepts_i) @@ -153,7 +149,6 @@ def predict( confidence_threshold=0.7, ): # Generate features - t3 = time() features = feature_generation.run( neurograph, model_type, @@ -164,7 +159,6 @@ def predict( dataset = ml_utils.init_dataset(neurograph, features, model_type) # Run model - t3 = time() proposal_probs = run_model(dataset, model, model_type) accepts = build.get_accepted_proposals( neurograph, diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index a38b369..4fdcd5e 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -8,11 +8,10 @@ """ -import networkx as nx import os from concurrent.futures import ProcessPoolExecutor, as_completed -from random import sample +import networkx as nx import numpy as np from deep_neurographs import graph_utils as gutils @@ -64,7 +63,7 @@ def get_accepted_proposals( high_threshold=0.9, low_threshold=0.6, structure_aware=True, -): +): # Get positive edge predictions preds = threshold_preds(preds, idx_to_edge, low_threshold) if structure_aware: @@ -132,7 +131,7 @@ def get_structure_aware_accepts( good_probs.append(prob) more_accepts = check_cycles_sequential(graph, good_preds, good_probs) - accepts.extend(more_accepts) + accepts.extend(more_accepts) return accepts diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 962cf8a..86bf91d 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -341,13 +341,6 @@ def save_edge(path, xyz_1, xyz_2, color=None, radius=6): f.write(make_simple_entry(2, 1, xyz_2, radius=radius)) -def set_radius(graph, i): - try: - return graph[i]["radius"] - except: - return 2 - - def make_entry(graph, i, parent, node_to_idx): """ Makes an entry to be written in an swc file. @@ -368,7 +361,7 @@ def make_entry(graph, i, parent, node_to_idx): ... """ - r = set_radius(graph, i) + r = graph[i]["radius"] x, y, z = tuple(graph.nodes[i]["xyz"]) node_to_idx[i] = len(node_to_idx) + 1 entry = f"{node_to_idx[i]} 2 {x} {y} {z} {r} {node_to_idx[parent]}" diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index f8a3b13..e9462ca 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -12,7 +12,7 @@ import math import os import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from io import BytesIO from random import sample @@ -392,7 +392,7 @@ def read_tensorstore_bbox(img, bbox): start = bbox["min"] end = bbox["max"] return ( - img[start[0] : end[0], start[1] : end[1], start[2] : end[2]] + img[start[0]: end[0], start[1]: end[1], start[2]: end[2]] .read() .result() ) @@ -401,7 +401,7 @@ def read_tensorstore_bbox(img, bbox): def get_chunk(arr, xyz, shape, from_center=True): start, end = get_start_end(xyz, shape, from_center=from_center) return deepcopy( - arr[start[0] : end[0], start[1] : end[1], start[2] : end[2]] + arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]] ) @@ -580,6 +580,47 @@ def sample_singleton(my_container): return sample(my_container, 1)[0] +# --- runtime --- +def init_timers(): + """ + Initializes two timers. + + Parameters + ---------- + None + + Returns + ------- + time.time + Timer. + time.time + Timer. 
+ + """ + return time(), time() + + +def progress_bar(current, total, bar_length=50, eta=None, runtime=None): + progress = int(current / total * bar_length) + n_completed = f"Completed: {current}/{total}" + bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]" + eta = f"Time Remaining: {eta}" if eta else "" + runtime = f"Estimated Total Runtime: {runtime}" if runtime else "" + print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True) + + +def time_writer(t, unit="seconds"): + assert unit in ["seconds", "minutes", "hours"] + upd_unit = {"seconds": "minutes", "minutes": "hours"} + if t < 60 or unit == "hours": + return t, unit + else: + t /= 60 + unit = upd_unit[unit] + t, unit = time_writer(t, unit=unit) + return t, unit + + # --- miscellaneous --- def get_img_bbox(origin, shape): """ @@ -638,31 +679,6 @@ def get_memory_usage(): return psutil.virtual_memory().used / 1e9 -def init_timers(): - return time(), time() - - -def progress_bar(current, total, bar_length=50, eta=None, runtime=None): - progress = int(current / total * bar_length) - n_completed = f"Completed: {current}/{total}" - bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]" - eta = f"Time Remaining: {eta}" if eta else "" - runtime = f"Estimated Total Runtime: {runtime}" if runtime else "" - print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True) - - -def time_writer(t, unit="seconds"): - assert unit in ["seconds", "minutes", "hours"] - upd_unit = {"seconds": "minutes", "minutes": "hours"} - if t < 60 or unit == "hours": - return t, unit - else: - t /= 60 - unit = upd_unit[unit] - t, unit = time_writer(t, unit=unit) - return t, unit - - def get_batches(iterable, batch_size): for start in range(0, len(iterable), batch_size): - yield iterable[start : min(start + batch_size, len(iterable))] + yield iterable[start: min(start + batch_size, len(iterable))] diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index a674fe8..2cc2394 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -140,10 +140,10 @@ def plot_edges(graph, edges, color=None, line_width=3.5): def plot(data, title): fig = go.Figure(data=data) fig.update_layout( - # title=title, - # #template="plotly_white", - # #plot_bgcolor="rgba(0, 0, 0, 0)", - # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), + # title=title, + # #template="plotly_white", + # #plot_bgcolor="rgba(0, 0, 0, 0)", + # #scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), width=1200, height=700, )
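
Note on the cycle checks touched by both patches: gutils.creates_cycle() is now unpacked as a tuple whose first element flags whether adding the proposal would create a cycle (rather than being treated as a bare bool), and the caller only adds the edge when that flag is False. The snippet below is a minimal, self-contained sketch of that shortest-first, cycle-aware acceptance pattern (as used in init_targets() and check_cycles_sequential()); the creates_cycle() helper and the accept_shortest_acyclic() name here are illustrative stand-ins rather than the repository's graph_utils implementation, and only networkx and numpy are assumed.

import networkx as nx
import numpy as np


def creates_cycle(graph, edge):
    # Stand-in check: adding "edge" closes a cycle exactly when its endpoints
    # are already connected by some path in "graph".
    i, j = edge
    if graph.has_node(i) and graph.has_node(j) and nx.has_path(graph, i, j):
        return True, edge
    return False, None


def accept_shortest_acyclic(graph, proposals, dists):
    # Visit proposals from shortest to longest and keep each one that does
    # not create a cycle, mirroring the greedy loop in init_targets().
    accepts = set()
    for idx in np.argsort(dists):
        edge = tuple(proposals[idx])
        created_cycle, _ = creates_cycle(graph, edge)
        if not created_cycle:
            graph.add_edges_from([edge])
            accepts.add(frozenset(edge))
    return accepts


if __name__ == "__main__":
    # Two fragments (0-1 and 2-3) and two competing proposals; only the
    # shorter proposal (1, 2) is kept because (0, 3) would then close a cycle.
    fragments = nx.Graph([(0, 1), (2, 3)])
    print(accept_shortest_acyclic(fragments, [(1, 2), (0, 3)], [5.0, 12.0]))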