Skip to content

Commit

Permalink
Refactor fragment thresholding (#260)
Browse files Browse the repository at this point in the history
* refactor: removed trim option, exact path length threshold, mark merges

* upds

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Oct 7, 2024
1 parent 3257fb3 commit 513d022
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GraphConfig:
min_size: float = 30.0
node_spacing: int = 1
proposals_per_leaf: int = 2
prune_depth: float = 16.0
prune_depth: float = 25.0
remove_doubles_bool: bool = False
search_radius: float = 20.0
smooth_bool: bool = True
Expand Down Expand Up @@ -93,7 +93,7 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-4
lr: float = 1e-3
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
Expand Down
12 changes: 7 additions & 5 deletions src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from deep_neurographs import geometry

DOT_THRESHOLD = -0.4
DOT_THRESHOLD = -0.3
RADIUS_SCALING_FACTOR = 1.5
TRIM_SEARCH_DIST = 15

Expand Down Expand Up @@ -61,9 +61,9 @@ def run(
kdtree = init_kdtree(neurograph, complex_bool)
radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0
if progress_bar:
iterable = tqdm(neurograph.leafs, desc="Proposals")
iterable = tqdm(neurograph.get_leafs(), desc="Proposals")
else:
iterable = neurograph.leafs
iterable = neurograph.get_leafs()

# Main
for leaf in iterable:
Expand Down Expand Up @@ -151,7 +151,7 @@ def get_candidates(
if max_proposals < 0 and len(candidates) == 1:
return candidates if neurograph.is_leaf(candidates[0]) else []
else:
return [] if max_proposals < 0 else candidates
return list() if max_proposals < 0 else candidates


def search_kdtree(neurograph, leaf, kdtree, radius, max_proposals):
Expand Down Expand Up @@ -312,6 +312,7 @@ def run_trimming(neurograph, proposals, radius):
elif neurograph.dist(i, j) > radius:
neurograph.remove_proposal(proposal)
n_endpoints_trimmed += 1 if trim_bool else 0
print("# Endpoints Trimmed:", n_endpoints_trimmed)
return neurograph


Expand Down Expand Up @@ -458,9 +459,10 @@ def compute_dot(branch_1, branch_2, idx_1, idx_2):
b2 = branch_2 - geometry.midpoint(branch_1[idx_1], branch_2[idx_2])

# Main
dot_5 = np.dot(tangent(b1, idx_1, 5), tangent(b2, idx_2, 5))
dot_10 = np.dot(tangent(b1, idx_1, 10), tangent(b2, idx_2, 10))
dot_20 = np.dot(tangent(b1, idx_1, 20), tangent(b2, idx_2, 20))
return min(dot_10, dot_20)
return min(dot_5, min(dot_10, dot_20))


def tangent(branch, idx, depth):
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def smooth_branch(xyz, s=None):
return xyz.astype(np.float32)


def fit_spline(xyz, k=3, s=None):
def fit_spline(xyz, k=2, s=None):
"""
Fits a cubic spline to an array containing xyz coordinates.
Expand Down
2 changes: 0 additions & 2 deletions src/deep_neurographs/graph_artifact_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ def delete(neurograph, nodes, swc_id):
i, j = tuple(nodes)
neurograph = remove_xyz_entries(neurograph, i, j)
neurograph.remove_nodes_from([i, j])
neurograph.leafs.remove(i)
neurograph.leafs.remove(j)
neurograph.swc_ids.remove(swc_id)
return neurograph

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def edge_skeletal(neurograph, computation_graph):
edge_skeletal_features[frozenset(edge)] = np.array(
[
np.mean(neurograph.edges[edge]["radius"]),
neurograph.edges[edge]["length"] / 1000,
min(neurograph.edges[edge]["length"], 500) / 500,
],
)
return edge_skeletal_features
Expand Down Expand Up @@ -227,7 +227,7 @@ def proposal_skeletal(neurograph, proposals, radius):
for proposal in proposals:
proposal_skeletal_features[proposal] = np.concatenate(
(
neurograph.proposal_length(proposal),
neurograph.proposal_length(proposal) / radius,
neurograph.n_nearby_leafs(proposal, radius),
neurograph.proposal_radii(proposal),
neurograph.proposal_directionals(proposal, 8),
Expand Down
49 changes: 32 additions & 17 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(self, img_bbox=None, node_spacing=1):
self.xyz_to_edge = dict()

# Nodes and Edges
self.leafs = set()
self.junctions = set()
self.proposals = set()
self.target_edges = set()
Expand Down Expand Up @@ -106,6 +105,22 @@ def set_proxy_soma_ids(self, k):
for i in gutil.largest_components(self, k):
self.soma_ids[self.nodes[i]["swc_id"]] = i

def get_leafs(self):
"""
Gets all leaf nodes in graph.
Parameters
----------
None
Returns
-------
list[int]
Leaf nodes in graph.
"""
return [i for i in self.nodes if self.is_leaf(i)]

# --- Edit Graph --
def add_component(self, irreducibles):
"""
Expand Down Expand Up @@ -170,10 +185,6 @@ def __add_nodes(self, irreducibles, node_type, node_ids):
xyz=irreducibles[node_type][i]["xyz"],
)
self.node_cnt += 1
if node_type == "leafs":
self.leafs.add(cur_id)
else:
self.junctions.add(cur_id)
node_ids[i] = cur_id
return node_ids

Expand Down Expand Up @@ -521,7 +532,7 @@ def get_kdtree(self, node_type=None):
"""
# Get xyz coordinates
if node_type == "leaf":
xyz_list = [self.nodes[i]["xyz"] for i in self.leafs]
xyz_list = [self.nodes[i]["xyz"] for i in self.get_leafs()]
elif node_type == "proposal":
xyz_list = list(self.xyz_to_proposal.keys())
else:
Expand Down Expand Up @@ -649,11 +660,12 @@ def proposal_directionals(self, proposal, window):

def merge_proposal(self, proposal):
i, j = tuple(proposal)
somas_check = not (self.is_soma(i) and self.is_soma(j))
degrees_check = not (self.degree[i] == 2 and self.degree[j] == 2)
if somas_check and degrees_check:
somas_check = not (self.is_soma(i) and self.is_soma(j))
if somas_check and self.check_proposal_degrees(i, j):
# Dense attributes
attrs = dict()
self.nodes[i]["radius"] = 7.0
self.nodes[j]["radius"] = 7.0
for k in ["xyz", "radius"]:
combine = np.vstack if k == "xyz" else np.array
self.nodes[i][k][-1] = 8.0
Expand All @@ -666,12 +678,12 @@ def merge_proposal(self, proposal):
e_j = (j, self.leaf_neighbor(j))
len_ij = self.edges[e_i]["length"] + self.edges[e_j]["length"]
attrs["length"] = len_ij
elif self.degree[i] == 2:
e_j = (j, self.leaf_neighbor(j))
attrs["length"] = self.edges[e_j]["length"]
else:
elif self.degree[i] == 1:
e_i = (i, self.leaf_neighbor(i))
attrs["length"] = self.edges[e_i]["length"]
else:
e_j = (j, self.leaf_neighbor(j))
attrs["length"] = self.edges[e_j]["length"]

swc_id_i = self.nodes[i]["swc_id"]
swc_id_j = self.nodes[j]["swc_id"]
Expand All @@ -681,12 +693,15 @@ def merge_proposal(self, proposal):
self.merged_ids.add((swc_id_i, swc_id_j))
self.upd_ids(swc_id, j if swc_id == swc_id_i else i)
self.__add_edge((i, j), attrs, swc_id)
if i in self.leafs:
self.leafs.remove(i)
if j in self.leafs:
self.leafs.remove(j)
self.proposals.remove(proposal)
else:
print("Skip! -- Failed Degree Check")

def check_proposal_degrees(self, i, j):
one_leaf = self.degree[i] == 1 or self.degree[j] == 1
branching = self.degree[i] > 2 or self.degree[j] > 2
return one_leaf and not branching

def upd_ids(self, swc_id, r):
"""
Updates the swc_id of all nodes connected to "r".
Expand Down
10 changes: 5 additions & 5 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
MIN_SIZE = 30
NODE_SPACING = 1
SMOOTH_BOOL = True
PRUNE_DEPTH = 16
PRUNE_DEPTH = 25


class GraphLoader:
Expand Down Expand Up @@ -165,7 +165,7 @@ def get_irreducibles(
min_size,
img_bbox=None,
progress_bar=True,
prune_depth=16.0,
prune_depth=PRUNE_DEPTH,
smooth_bool=True,
):
"""
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_component_irreducibles(
swc_dict,
min_size,
img_bbox=None,
prune_depth=16.0,
prune_depth=PRUNE_DEPTH,
smooth_bool=True,
):
"""
Expand Down Expand Up @@ -451,7 +451,7 @@ def prune_branch(graph, leaf, prune_depth):
"""
branch = [leaf]
node_dists = list()
for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=prune_depth):
for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=2*prune_depth):
# Visit edge
node_dists.append(compute_dist(graph, i, j))
if graph.degree(j) > 2:
Expand All @@ -462,7 +462,7 @@ def prune_branch(graph, leaf, prune_depth):
# Check whether to stop
if np.sum(node_dists) > prune_depth:
break
return list()
return branch[0:min(4, len(branch))]


def smooth_branch(swc_dict, attrs, edges, nbs, root, j):
Expand Down

0 comments on commit 513d022

Please sign in to comment.