diff --git a/src/deep_neurographs/fragment_filtering.py b/src/deep_neurographs/fragment_filtering.py index d77c492..cb35e91 100644 --- a/src/deep_neurographs/fragment_filtering.py +++ b/src/deep_neurographs/fragment_filtering.py @@ -8,83 +8,98 @@ other from a FragmentsGraph. """ + from collections import defaultdict +import networkx as nx import numpy as np -from networkx import connected_components from tqdm import tqdm from deep_neurographs import geometry -from deep_neurographs.utils import util -COLOR = "1.0 0.0 0.0" QUERY_DIST = 15 # --- Curvy Removal --- -def remove_curvy(graph, max_length, ratio=0.5): +def remove_curvy(fragments_graph, max_length, ratio=0.5): + """ + Removes connected components with 2 nodes from "fragments_graph" that are + "curvy" fragments, based on a specified ratio of endpoint distance to edge + length and a maximum length threshold. + + Parameters + ---------- + fragments_graph : FragmentsGraph + Graph generated from fragments of a predicted segmentation. + max_length : float + The maximum allowable length (in microns) for an edge to be considered + for removal. + ratio : float, optional + Threshold ratio of endpoint distance to edge length. Components with a + ratio below this value are considered "curvy" and are removed. The + default is 0.5. + + Returns + ------- + int + Number of fragments removed from the graph. + + """ deleted_ids = set() - components = [c for c in connected_components(graph) if len(c) == 2] - for nodes in tqdm(components, desc="Filter Curvy Fragment"): - if len(nodes) == 2: - i, j = tuple(nodes) - length = graph.edges[i, j]["length"] - endpoint_dist = graph.dist(i, j) - if endpoint_dist / length < ratio and length < max_length: - deleted_ids.add(graph.edges[i, j]["swc_id"]) - delete_fragment(graph, i, j) + components = get_line_components(fragments_graph) + for nodes in tqdm(components, desc="Filter Curvy Fragments"): + i, j = tuple(nodes) + length = fragments_graph.edges[i, j]["length"] + endpoint_dist = fragments_graph.dist(i, j) + if endpoint_dist / length < ratio and length < max_length: + deleted_ids.add(fragments_graph.edges[i, j]["swc_id"]) + delete_fragment(fragments_graph, i, j) return len(deleted_ids) # --- Doubles Removal --- -def remove_doubles(graph, max_length, node_spacing, output_dir=None): +def remove_doubles(fragments_graph, max_length, node_spacing): """ - Removes connected components from "neurgraph" that are likely to be a - double. + Removes connected components from "fragments_graph" that are likely to be + a double -- caused by ghosting in the image. Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph to be searched for doubles. max_length : int Maximum size of connected components to be searched. node_spacing : int - Expected distance in microns between nodes in "graph". - output_dir : str or None, optional - Directory that doubles will be written to. The default is None. + Expected distance (in microns) between nodes in "fragments_graph". Returns ------- - graph - Graph with doubles removed. + int + Number of fragments removed from graph. """ # Initializations - components = [c for c in connected_components(graph) if len(c) == 2] + components = get_line_components(fragments_graph) deleted_ids = set() - kdtree = graph.get_kdtree() - if output_dir: - util.mkdir(output_dir, delete=True) + kdtree = fragments_graph.get_kdtree() # Main - desc = "Filter Doubled Fragment" + desc = "Filter Doubled Fragments" for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc): i, j = tuple(components[idx]) - swc_id = graph.nodes[i]["swc_id"] + swc_id = fragments_graph.nodes[i]["swc_id"] if swc_id not in deleted_ids: - if graph.edges[i, j]["length"] < max_length: + if fragments_graph.edges[i, j]["length"] < max_length: # Check doubles criteria - n_points = len(graph.edges[i, j]["xyz"]) - hits = compute_projections(graph, kdtree, (i, j)) + n_points = len(fragments_graph.edges[i, j]["xyz"]) + hits = compute_projections(fragments_graph, kdtree, (i, j)) if check_doubles_criteria(hits, n_points): - if output_dir: - graph.to_swc(output_dir, [i, j], color=COLOR) - delete_fragment(graph, i, j) + delete_fragment(fragments_graph, i, j) deleted_ids.add(swc_id) return len(deleted_ids) -def compute_projections(graph, kdtree, edge): +def compute_projections(fragments_graph, kdtree, edge): """ Given a fragment defined by "edge", this routine iterates of every xyz in the fragment and projects it onto the closest fragment. For each detected @@ -93,11 +108,11 @@ def compute_projections(graph, kdtree, edge): Parameters ---------- - graph : graph + fragments_graph : graph Graph that contains "edge". kdtree : KDTree KD-Tree that contains all xyz coordinates of every fragment in - "graph". + "fragments_graph". edge : tuple Pair of leaf nodes that define a fragment. @@ -109,13 +124,13 @@ def compute_projections(graph, kdtree, edge): """ hits = defaultdict(list) - query_id = graph.edges[edge]["swc_id"] - for i, xyz in enumerate(graph.edges[edge]["xyz"]): + query_id = fragments_graph.edges[edge]["swc_id"] + for i, xyz in enumerate(fragments_graph.edges[edge]["xyz"]): # Compute projections best_id = None best_dist = np.inf for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST): - hit_id = graph.xyz_to_swc(hit_xyz) + hit_id = fragments_graph.xyz_to_swc(hit_xyz) if hit_id is not None and hit_id != query_id: if geometry.dist(hit_xyz, xyz) < best_dist: best_dist = geometry.dist(hit_xyz, xyz) @@ -157,15 +172,15 @@ def check_doubles_criteria(hits, n_points): return False -def delete_fragment(graph, i, j): +def delete_fragment(fragments_graph, i, j): """ - Deletes nodes "i" and "j" from "graph", where these nodes form a connected - component. + Deletes nodes "i" and "j" from "fragments_graph", where these nodes form a + connected component. Parameters ---------- - graph : FragmentsGraph - Graph that contains nodes to be deleted. + fragments_graph : FragmentsGraph + Graph that contains nodes to be removed. i : int Node to be removed. j : int @@ -173,28 +188,28 @@ def delete_fragment(graph, i, j): Returns ------- - graph + fragments_graph Graph with nodes removed. """ - graph = remove_xyz_entries(graph, i, j) - graph.swc_ids.remove(graph.nodes[i]["swc_id"]) - graph.remove_nodes_from([i, j]) + fragments_graph = remove_xyz_entries(fragments_graph, i, j) + fragments_graph.swc_ids.remove(fragments_graph.nodes[i]["swc_id"]) + fragments_graph.remove_nodes_from([i, j]) -def remove_xyz_entries(graph, i, j): +def remove_xyz_entries(fragments_graph, i, j): """ - Removes dictionary entries from "graph.xyz_to_edge" corresponding to - the edge {i, j}. + Removes dictionary entries from "fragments_graph.xyz_to_edge" + corresponding to the edge {i, j}. Parameters ---------- - graph : graph + fragments_graph : graph Graph to be updated. i : int - Node in "graph". + Node in graph. j : int - Node in "graph". + Node in graph. Returns ------- @@ -202,9 +217,9 @@ def remove_xyz_entries(graph, i, j): Updated graph. """ - for xyz in graph.edges[i, j]["xyz"]: - del graph.xyz_to_edge[tuple(xyz)] - return graph + for xyz in fragments_graph.edges[i, j]["xyz"]: + del fragments_graph.xyz_to_edge[tuple(xyz)] + return fragments_graph def upd_hits(hits, key, value): @@ -215,8 +230,8 @@ def upd_hits(hits, key, value): Parameters ---------- hits : dict - Stores swd_ids of fragments that are within a certain distance a query - fragment along with the corresponding distances. + Stores swd_ids of fragments within a certain distance a query fragment + along with the corresponding distances. key : str swc id of some fragment. value : float @@ -229,9 +244,30 @@ def upd_hits(hits, key, value): Updated version of hits. """ - if key in hits.keys(): + if key in hits: if value < hits[key]: hits[key] = value else: hits[key] = value return hits + + +# --- utils --- +def get_line_components(graph): + """ + Identifies and returns all line components in the given graph. A line + component is defined as a connected component with exactly two nodes. + + Parameters + ---------- + graph : networkx.Graph + Input graph in which line components are to be identified. + + Returns + ------- + List[set] + List of sets, where each set contains two nodes representing a + connected component with exactly two nodes. + + """ + return [c for c in nx.connected_components(graph) if len(c) == 2] diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/fragments_graph.py similarity index 100% rename from src/deep_neurographs/neurograph.py rename to src/deep_neurographs/fragments_graph.py diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index 1f23c0d..eb531ff 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -21,7 +21,7 @@ def run( - graph, + fragments_graph, radius, complex_bool=False, long_range_bool=True, @@ -29,11 +29,11 @@ def run( trim_endpoints_bool=True, ): """ - Generates proposals emanating from "leaf". + Generates proposals for fragments graph. Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph that proposals will be generated for. radius : float Maximum Euclidean distance between endpoints of proposal. @@ -58,29 +58,29 @@ def run( """ # Initializations connections = dict() - kdtree = init_kdtree(graph, complex_bool) + kdtree = init_kdtree(fragments_graph, complex_bool) radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0 if progress_bar: - iterable = tqdm(graph.get_leafs(), desc="Proposals") + iterable = tqdm(fragments_graph.get_leafs(), desc="Proposals") else: - iterable = graph.get_leafs() + iterable = fragments_graph.get_leafs() # Main for leaf in iterable: # Generate potential proposals candidates = get_candidates( - graph, + fragments_graph, leaf, kdtree, radius, - graph.proposals_per_leaf, + fragments_graph.proposals_per_leaf, complex_bool, ) # Generate long range proposals (if applicable) if len(candidates) == 0 and long_range_bool: candidates = get_candidates( - graph, + fragments_graph, leaf, kdtree, radius * RADIUS_SCALING_FACTOR, @@ -90,33 +90,34 @@ def run( # Determine which potential proposals to keep for i in candidates: - leaf_swc_id = graph.nodes[leaf]["swc_id"] - pair_id = frozenset((leaf_swc_id, graph.nodes[i]["swc_id"])) + leaf_swc_id = fragments_graph.nodes[leaf]["swc_id"] + node_swc_id = fragments_graph.nodes[i]["swc_id"] + pair_id = frozenset((leaf_swc_id, node_swc_id)) if pair_id in connections.keys(): cur_proposal = connections[pair_id] - cur_dist = graph.proposal_length(cur_proposal) - if graph.dist(leaf, i) < cur_dist: - graph.remove_proposal(cur_proposal) + cur_dist = fragments_graph.proposal_length(cur_proposal) + if fragments_graph.dist(leaf, i) < cur_dist: + fragments_graph.remove_proposal(cur_proposal) del connections[pair_id] else: continue # Add proposal - graph.add_proposal(leaf, i) + fragments_graph.add_proposal(leaf, i) connections[pair_id] = frozenset({leaf, i}) # Trim endpoints (if applicable) n_trimmed = 0 if trim_endpoints_bool: radius /= RADIUS_SCALING_FACTOR - long_range, in_range = separate_proposals(graph, radius) - graph, n_trimmed_1 = run_trimming(graph, long_range, radius) - graph, n_trimmed_2 = run_trimming(graph, in_range, radius) - n_trimmed = n_trimmed_1 + n_trimmed_2 + long_range, in_range = partition_proposals(fragments_graph, radius) + cnt_1 = run_trimming(fragments_graph, long_range, radius) + cnt_2 = run_trimming(fragments_graph, in_range, radius) + n_trimmed = cnt_1 + cnt_2 return n_trimmed -def init_kdtree(graph, complex_bool): +def init_kdtree(fragments_graph, complex_bool): """ Initializes a KD-Tree used to generate proposals. @@ -130,13 +131,14 @@ def init_kdtree(graph, complex_bool): Returns ------- scipy.spatial.cKDTree - kdtree. + kdtree built from all xyz coordinates across edges in graph if + complex_bool is True; otherwise, only built from leaf nodes. """ if complex_bool: - return graph.get_kdtree() + return fragments_graph.get_kdtree() else: - return graph.get_kdtree(node_type="leaf") + return fragments_graph.get_kdtree(node_type="leaf") def get_candidates( @@ -157,16 +159,16 @@ def get_candidates( return list() if max_proposals < 0 else candidates -def search_kdtree(graph, leaf, kdtree, radius, max_proposals): +def search_kdtree(fragments_graph, leaf, kdtree, radius, max_proposals): """ - Generates proposals for node "leaf" in "graph" by finding candidate - xyz points on distinct connected components nearby. + Generates proposals emanating from node "leaf" by finding candidate xyz + points on distinct connected components nearby. Parameters ---------- - graph : FragmentsGraph - Graph built from swc files. - kdtree : ... + fragments_graph : FragmentsGraph + Graph that proposals will be generated for. + kdtree : scipy.spatial.cKDTree ... leaf : int Leaf node that proposals are to be generated from. @@ -183,10 +185,10 @@ def search_kdtree(graph, leaf, kdtree, radius, max_proposals): """ # Generate candidates candidates = dict() - leaf_xyz = graph.nodes[leaf]["xyz"] + leaf_xyz = fragments_graph.nodes[leaf]["xyz"] for xyz in geometry.query_ball(kdtree, leaf_xyz, radius): - swc_id = graph.xyz_to_swc(xyz) - if swc_id != graph.nodes[leaf]["swc_id"]: + swc_id = fragments_graph.xyz_to_swc(xyz) + if swc_id != fragments_graph.nodes[leaf]["swc_id"]: d = geometry.dist(leaf_xyz, xyz) if swc_id not in candidates.keys(): candidates[swc_id] = {"dist": d, "xyz": tuple(xyz)} @@ -231,13 +233,13 @@ def get_best(candidates, max_proposals): return list_candidates_xyz(candidates) -def get_connecting_node(graph, leaf, xyz, radius, complex_bool): +def get_connecting_node(fragments_graph, leaf, xyz, radius, complex_bool): """ - Gets the node that proposal with leaf will connect to. + Gets node that proposal emanating from "leaf" will connect to. Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph containing "leaf". leaf : int Leaf node. @@ -247,28 +249,28 @@ def get_connecting_node(graph, leaf, xyz, radius, complex_bool): Returns ------- int - Node id. + Node id that proposal will connect to. """ - edge = graph.xyz_to_edge[xyz] - node = get_closer_endpoint(graph, edge, xyz) - if graph.dist(leaf, node) < radius: + edge = fragments_graph.xyz_to_edge[xyz] + node = get_closer_endpoint(fragments_graph, edge, xyz) + if fragments_graph.dist(leaf, node) < radius: return node elif complex_bool: - attrs = graph.get_edge_data(*edge) + attrs = fragments_graph.get_edge_data(*edge) idx = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0][0] if type(idx) is int: - return graph.split_edge(edge, attrs, idx) + return fragments_graph.split_edge(edge, attrs, idx) return None -def get_closer_endpoint(graph, edge, xyz): +def get_closer_endpoint(fragments_graph, edge, xyz): """ - Gets the node from "edge" that is closer to "xyz". + Gets node from "edge" that is closer to "xyz". Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph containing "edge". edge : tuple Edge to be checked. @@ -277,51 +279,66 @@ def get_closer_endpoint(graph, edge, xyz): Returns ------- - tuple - Node id and its distance from "xyz". + int + Node closer to "xyz". """ i, j = tuple(edge) - d_i = geometry.dist(graph.nodes[i]["xyz"], xyz) - d_j = geometry.dist(graph.nodes[j]["xyz"], xyz) + d_i = geometry.dist(fragments_graph.nodes[i]["xyz"], xyz) + d_j = geometry.dist(fragments_graph.nodes[j]["xyz"], xyz) return i if d_i < d_j else j -def separate_proposals(graph, radius): +def partition_proposals(fragments_graph, radius): + """ + Partitions proposals in "fragments_graph" into long-range and in-range + categories based on a specified length threshold. + + Parameters + ---------- + fragments_graph : FragmentsGraph + Graph with proposals to be partitioned. + radius : float + Length threshold used to partition proposals. Proposals with length + greater than "radius" are said to be long-range; otherwise, in-range. + + Returns + ------- + list, list + Lists of long-range and in-range proposals. + + """ long_range_proposals = list() - proposals = list() - for proposal in graph.proposals: - i, j = tuple(proposal) - if graph.dist(i, j) > radius: - long_range_proposals.append(proposal) + in_range_proposals = list() + for p in fragments_graph.proposals: + if fragments_graph.proposal_length(p) > radius: + long_range_proposals.append(p) else: - proposals.append(proposal) - return long_range_proposals, proposals + in_range_proposals.append(p) + return long_range_proposals, in_range_proposals # --- Trim Endpoints --- -def run_trimming(graph, proposals, radius): - n_endpoints_trimmed = 0 +def run_trimming(fragments_graph, proposals, radius): + n_trimmed = 0 long_radius = radius * RADIUS_SCALING_FACTOR - for proposal in deepcopy(proposals): - i, j = tuple(proposal) - is_simple = graph.is_simple(proposal) - is_single = graph.is_single_proposal(proposal) + for p in deepcopy(proposals): + is_simple = fragments_graph.is_simple(p) + is_single = fragments_graph.is_single_proposal(p) trim_bool = False if is_simple and is_single: - graph, trim_bool = trim_endpoints( - graph, i, j, long_radius - ) - elif graph.dist(i, j) > radius: - graph.remove_proposal(proposal) - n_endpoints_trimmed += 1 if trim_bool else 0 - return graph, n_endpoints_trimmed + trim_bool = trim_endpoints(fragments_graph, p, long_radius) + elif fragments_graph.proposal_length(p) > radius: + fragments_graph.remove_proposal(p) + n_trimmed += 1 if trim_bool else 0 + return n_trimmed -def trim_endpoints(graph, i, j, radius): +def trim_endpoints(fragments_graph, proposal, radius): # Initializations - branch_i = graph.branch(i) - branch_j = graph.branch(j) + i, j = tuple(proposal) + branch_i = fragments_graph.branch(i) + branch_j = fragments_graph.branch(j) # Check both orderings idx_i, idx_j = trim_endpoints_ordered(branch_i, branch_j) @@ -334,14 +351,14 @@ def trim_endpoints(graph, i, j, radius): # Update branches (if applicable) if min(d1, d2) > radius: - graph.remove_proposal(frozenset((i, j))) - return graph, False + fragments_graph.remove_proposal(frozenset((i, j))) + return False elif min(d1, d2) + 2 < geometry.dist(branch_i[0], branch_j[0]): if compute_dot(branch_i, branch_j, idx_i, idx_j) < DOT_THRESHOLD: - graph = trim_to_idx(graph, i, idx_i) - graph = trim_to_idx(graph, j, idx_j) - return graph, True - return graph, False + fragments_graph = trim_to_idx(fragments_graph, i, idx_i) + fragments_graph = trim_to_idx(fragments_graph, j, idx_j) + return True + return False def trim_endpoints_ordered(branch_1, branch_2): @@ -376,13 +393,13 @@ def trim_endpoint(branch_1, branch_2): return 0 if best_idx is None else best_idx -def trim_to_idx(graph, i, idx): +def trim_to_idx(fragments_graph, i, idx): """ Trims the branch emanating from "i". Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph containing node "i" i : int Leaf node. @@ -395,21 +412,21 @@ def trim_to_idx(graph, i, idx): """ # Update node - branch_xyz = graph.branch(i, key="xyz") - branch_radii = graph.branch(i, key="radius") - graph.nodes[i]["xyz"] = branch_xyz[idx] - graph.nodes[i]["radius"] = branch_radii[idx] + branch_xyz = fragments_graph.branch(i, key="xyz") + branch_radii = fragments_graph.branch(i, key="radius") + fragments_graph.nodes[i]["xyz"] = branch_xyz[idx] + fragments_graph.nodes[i]["radius"] = branch_radii[idx] # Update edge - j = graph.leaf_neighbor(i) - graph.edges[i, j]["xyz"] = branch_xyz[idx::] - graph.edges[i, j]["radius"] = branch_radii[idx::] + j = fragments_graph.leaf_neighbor(i) + fragments_graph.edges[i, j]["xyz"] = branch_xyz[idx::] + fragments_graph.edges[i, j]["radius"] = branch_radii[idx::] for k in range(idx): try: - del graph.xyz_to_edge[tuple(branch_xyz[k])] + del fragments_graph.xyz_to_edge[tuple(branch_xyz[k])] except KeyError: pass - return graph + return fragments_graph # --- utils --- diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index b34e49e..2d0537d 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -66,7 +66,7 @@ def __init__( model_path, output_dir, config, - device=None, + device="cpu", is_multimodal=False, label_path=None, log_runtimes=True, diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index f8ef9c3..da4e68d 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -118,7 +118,7 @@ def run( FragmentsGraph generated from swc files. """ - from deep_neurographs.neurograph import FragmentsGraph + from deep_neurographs.fragments_graph import FragmentsGraph # Load fragments and extract irreducibles self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape)