Skip to content

Commit

Permalink
minor upds
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 19, 2024
1 parent e88ccdd commit 7b79c26
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from deep_neurographs.geometry import dist as get_dist
from deep_neurographs.utils import util

ALIGNED_THRESHOLD = 4
ALIGNED_THRESHOLD = 4.5
MIN_INTERSECTION = 10


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def build_graph(self, fragments_pointer):
# Save valid labels and current graph
swcs_path = os.path.join(self.output_dir, "processed-swcs.zip")
labels_path = os.path.join(self.output_dir, "valid_labels.txt")
n_saved = self.graph.to_zipped_swcs(swcs_path)
n_saved = self.graph.to_zipped_swcs(swcs_path, min_size=100)
self.graph.save_labels(labels_path)
self.report(f"# SWCs Saved: {n_saved}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):

# --- Utils ---
def reformat_edge_key(key):
print(key)
return tuple([rm_non_alphanumeric(s) for s in key.split(",")])


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def component_path_length(self, root):
return path_length

# --- write graph to swcs ---
def to_zipped_swcs(self, zip_path, color=None, min_size=100):
def to_zipped_swcs(self, zip_path, color=None, min_size=0):
with zipfile.ZipFile(zip_path, "w") as zip_writer:
cnt = 0
for nodes in nx.connected_components(self):
Expand Down
93 changes: 46 additions & 47 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def schedule_processes(self, swc_dicts):
a connected component of the graph corresponding to "swc_dicts".
"""
with ProcessPoolExecutor() as executor:
with ProcessPoolExecutor(max_workers=1) as executor:
# Assign Processes
i = 0
processes = [None] * len(swc_dicts)
Expand Down Expand Up @@ -195,7 +195,7 @@ def get_irreducibles(self, swc_dict):
# Build dense graph
swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"]))))
graph, _ = swc_util.to_graph(swc_dict, set_attrs=True)
self.clip_branches(graph)
self.clip_branches(graph, swc_dict["swc_id"])
self.prune_branches(graph)

# Extract irreducibles
Expand All @@ -210,7 +210,7 @@ def get_irreducibles(self, swc_dict):
irreducibles.append(result)
return irreducibles

def clip_branches(self, graph):
def clip_branches(self, graph, swc_id):
"""
Deletes all nodes from "graph" that are not contained in "img_bbox".
Expand All @@ -232,6 +232,49 @@ def clip_branches(self, graph):
delete_nodes.add(i)
graph.remove_nodes_from(delete_nodes)

def prune_branches(self, graph):
"""
Prunes all short branches from "graph". A short branch is a path
between a leaf and branching node where the path length is less than
"self.prune_depth".
Parameters
----------
graph : networkx.Graph
Graph to be pruned.
Returns
-------
networkx.Graph
Graph with short branches pruned.
"""
first_pass = True
deleted_nodes = list()
n_passes = 0
while len(deleted_nodes) > 0 or first_pass:
# Visit leafs
n_passes += 1
deleted_nodes = list()
for leaf in get_leafs(graph):
branch = [leaf]
length = 0
for (i, j) in nx.dfs_edges(graph, source=leaf):
# Visit edge
length += compute_dist(graph, i, j)
if graph.degree(j) == 2:
branch.append(j)
elif graph.degree(j) > 2:
deleted_nodes.extend(branch)
graph.remove_nodes_from(branch)
break

# Check whether to stop
if length > self.prune_depth or first_pass:
graph.remove_nodes_from(branch[0:min(3, len(branch))])
break
first_pass = False

def get_component_irreducibles(self, graph, swc_dict):
"""
Gets the irreducible components of "graph".
Expand Down Expand Up @@ -301,50 +344,6 @@ def get_component_irreducibles(self, graph, swc_dict):
else:
return False

def prune_branches(self, graph):
"""
Prunes all short branches from "graph". A short branch is a path
between a leaf and branching node where the path length is less than
"self.prune_depth".
Parameters
----------
graph : networkx.Graph
Graph to be pruned.
Returns
-------
networkx.Graph
Graph with short branches pruned.
"""
first_pass = True
deleted_nodes = list()
n_passes = 0
while len(deleted_nodes) > 0 or first_pass:
# Visit leafs
n_passes += 1
deleted_nodes = list()
for leaf in get_leafs(graph):
branch = [leaf]
length = 0
for (i, j) in nx.dfs_edges(graph, source=leaf):
# Visit edge
length += compute_dist(graph, i, j)
if graph.degree(j) == 2:
branch.append(j)
elif graph.degree(j) > 2:
deleted_nodes.extend(branch)
graph.remove_nodes_from(branch)
break

# Check whether to stop
if length > self.prune_depth or first_pass:
graph.remove_nodes_from(branch[0:min(3, len(branch))])
break

first_pass = False


# --- Utils ---
def get_irreducible_nodes(graph):
Expand Down
4 changes: 2 additions & 2 deletions src/deep_neurographs/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,8 @@ def __add_attributes(swc_dict, graph):
"""
attrs = dict()
for idx, node_id in enumerate(swc_dict["id"]):
attrs[node_id] = {
for idx, node in enumerate(swc_dict["id"]):
attrs[node] = {
"xyz": swc_dict["xyz"][idx],
"radius": swc_dict["radius"][idx],
}
Expand Down

0 comments on commit 7b79c26

Please sign in to comment.