Skip to content

Commit

Permalink
resolving merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Feb 21, 2024
2 parents 0f0de84 + 18f5d8c commit 92bed20
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 265 deletions.
3 changes: 2 additions & 1 deletion src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ def init_graphs(self, swc_paths):
self.graphs = dict()
for path in swc_paths:
# Construct Graph
swc_id, swc_dict = swc_utils.parse_local_swc(path)
swc_dict = swc_utils.parse_local_swc(path)
graph, xyz_to_node = swc_utils.to_graph(swc_dict, set_attrs=True)

# Store
swc_id = swc_dict["swc_id"]
if type(swc_dict["xyz"]) == np.ndarray:
swc_dict["xyz"] = utils.numpy_to_hashable(swc_dict["xyz"])
xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], swc_id))
Expand Down
109 changes: 38 additions & 71 deletions src/deep_neurographs/edit_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def prune_spurious_paths(graph, min_branch_length=16):
graph : networkx.graph
Graph to be pruned.
min_branch_length : int, optional
Upper bound on short branch length to be pruned. The default is 10.
Upper bound on short branch length to be pruned. The default is 16.
Returns
-------
Expand Down Expand Up @@ -49,86 +49,53 @@ def prune_spurious_paths(graph, min_branch_length=16):
return graph


def detect_short_connectors(graph, min_connector_length):
def prune_short_connectors(graph, connector_dist=8):
""" "
Detects shorts paths between branches (i.e. paths that connect branches).
Prunes shorts paths (i.e. connectors) between junctions nodes and the nbhd about the
junctions.
Parameters
----------
graph : networkx.graph
Graph to be inspected.
min_connector_length : int
Upper bound on short paths that connect branches.
connector_dist : int
Upper bound on the distance that defines a connector path to be pruned.
Returns
-------
remove_edges : list[tuple]
List of edges to be removed.
remove_nodes : list[int]
List of nodes to be removed.
graph : networkx.graph
Graph with connectors pruned.
pruned_centroids : list[np.ndarray]
List of xyz coordinates of the centroids of pruned connectors.
"""
leaf_nodes = [i for i in graph.nodes if graph.degree[i] == 1]
dfs_edges = list(nx.dfs_edges(graph, leaf_nodes[0]))
remove_nodes = []
remove_edges = []
flag_junction = False
path_length = 0
for (i, j) in dfs_edges:
# Check for junction
if graph.degree[i] > 2:
flag_junction = True
path_length = 1
cur_branch = [(i, j)]
elif flag_junction:
path_length += 1
cur_branch.append((i, j))

# Check whether to reset
if graph.degree[j] == 1:
flag_junction = False
cur_branch = list()
elif graph.degree[j] > 2 and flag_junction:
if path_length < min_connector_length:
remove_edges.extend(cur_branch)
remove_nodes.extend(graph.neighbors(cur_branch[0][0]))
remove_nodes.extend(graph.neighbors(j))
cur_branch = list()
return remove_edges, remove_nodes


def prune_short_connectors(list_of_graphs, min_connector_length=10):
"""
Prunes short connecting paths on each graph in "list_of_graphs".
Parameters
----------
list_of_graphs : list[networkx.graph]
List of graphs such that short connecting paths will be pruned on
each graph.
min_connector_length : int, optional
Upper bound on short paths that connect branches. The default is 10.
Returns
-------
upd : list[networkx.graph]
List of graphs with short connecting paths pruned.
"""
upd = []
for graph in list_of_graphs:
pruned_graph = prune_spurious_paths(graph)
if pruned_graph.number_of_nodes() > 3:
remove_edges, remove_nodes = detect_short_connectors(
pruned_graph, min_connector_length
)
graph.remove_edges_from(remove_edges)
graph.remove_nodes_from(remove_nodes)
for g in nx.connected_components(graph):
subgraph = graph.subgraph(g).copy()
if subgraph.number_of_nodes() > 10:
upd.append(subgraph)
return upd
junctions = [j for j in graph.nodes if graph.degree[j] > 2]
pruned_centroids = []
pruned_nodes = set()
while len(junctions):
# Search nbhd
j = junctions.pop()
junction_nbs = []
for _, i in nx.dfs_edges(graph, source=j, depth_limit=connector_dist):
if graph.degree[i] > 2 and i != j:
junction_nbs.append(i)

# Store nodes to be pruned
print("# junction nbs:", len(junction_nbs))
for nb in junction_nbs:
connector = list(nx.shortest_path(graph, source=j, target=nb))
nbhd = set(nx.dfs_tree(graph, source=nb, depth_limit=5))
centroid = connector[len(connector) // 2]
pruned_nodes.update(nbhd.union(set(connector)))
pruned_centroids.append(graph.nodes[centroid]["xyz"])

if len(junction_nbs) > 0:
nbhd = set(nx.dfs_tree(graph, source=j, depth_limit=8))
pruned_nodes.update(nbhd)
break

graph.remove_nodes_from(list(pruned_nodes))
return graph, pruned_centroids


def break_crossovers(list_of_graphs, depth=10):
Expand Down Expand Up @@ -184,7 +151,7 @@ def detect_crossovers(graph, depth):
"""
cnt = 0
prune_nodes = []
junctions = [j for j in graph.nodes() if graph.degree(j) > 2]
junctions = [j for j in graph.nodes if graph.degree(j) > 2]
for j in junctions:
# Explore node
upd = False
Expand Down
91 changes: 17 additions & 74 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,30 @@ def generate_features(
vector and the numerical vector.
"""
# Initialize proposals
if proposals is None:
proposals = neurograph.get_proposals()

# Generate features
features = {
"skel": generate_skel_features(neurograph, proposals=proposals)
"skel": generate_skel_features(neurograph, proposals)
}
if model_type in ["ConvNet", "MultiModalNet"]:
features["img_chunks"], features["img_profile"] = generate_img_chunks(
neurograph,
proposals,
img_path,
labels_path,
model_type=model_type,
proposals=proposals,
)
if model_type in ["AdaBoost", "RandomForest", "FeedForwardNet"]:
features["img_profile"] = generate_img_profiles(
neurograph, img_path, proposals=proposals
neurograph, proposals, img_path
)
return features


# -- Edge feature extraction --
def generate_img_chunks(
neurograph, img_path, labels_path, model_type=None, proposals=None
):
def generate_img_chunks(neurograph, proposals, img_path, labels_path):
"""
Generates an image chunk for each edge proposal such that the centroid of
the image chunk is the midpoint of the edge proposal. Image chunks contain
Expand All @@ -115,65 +117,6 @@ def generate_img_chunks(
Dictionary such that each pair is the edge id and image chunk.
"""
if neurograph.bbox:
return generate_img_chunks_via_superchunk(
neurograph, img_path, labels_path, proposals=proposals
)
else:
return generate_img_chunks_via_multithreads(
neurograph, img_path, labels_path, proposals=proposals
)


def generate_img_chunks_via_superchunk(
neurograph, img_path, labels_path, proposals=None
):
"""
Generates an image chunk for each edge proposal such that the centroid of
the image chunk is the midpoint of the edge proposal. Image chunks contain
two channels: raw image and predicted segmentation.
Parameters
----------
neurograph : NeuroGraph
NeuroGraph generated from a directory of swcs generated from a
predicted segmentation.
img_path : str
Path to raw image.
labels_path : str
Path to predicted segmentation.
proposals : list[frozenset], optional
List of edge proposals for which features will be generated. The
default is None.
Returns
-------
features : dict
Dictionary such that each pair is the edge id and image chunk.
"""
chunk_features = dict()
profile_features = dict()
img, labels = utils.get_superchunks(
img_path,
labels_path,
neurograph.origin,
neurograph.shape,
from_center=False,
)
for edge in neurograph.proposals:
xyz_0, xyz_1 = neurograph.proposal_xyz(edge)
coord_0 = utils.to_img(xyz_0) - neurograph.origin
coord_1 = utils.to_img(xyz_1) - neurograph.origin
chunks, profile = get_img_chunks(img, labels, coord_0, coord_1)
chunk_features[edge] = chunks
profile_features[edge] = profile
return chunk_features, profile_features


def generate_img_chunks_via_multithreads(
neurograph, img_path, labels_path, proposals=None
):
driver = "n5" if ".n5" in img_path else "zarr"
img = utils.open_tensorstore(img_path, driver)
labels = utils.open_tensorstore(labels_path, "neuroglancer_precomputed")
Expand Down Expand Up @@ -227,18 +170,18 @@ def get_img_chunks(img, labels, coord_0, coord_1, thread_id=None):
return chunk, profile


def generate_img_profiles(neurograph, path, proposals=None):
if neurograph.bbox:
def generate_img_profiles(neurograph, proposals, path):
if False: #neurograph.bbox:
return generate_img_profiles_via_superchunk(
neurograph, path, proposals=proposals
neurograph, proposals, path
)
else:
return generate_img_profiles_via_multithreads(
neurograph, path, proposals=proposals
neurograph, proposals, path
)


def generate_img_profiles_via_multithreads(neurograph, path, proposals=None):
def generate_img_profiles_via_multithreads(neurograph, proposals, path):
profile_features = dict()
driver = "n5" if ".n5" in path else "zarr"
img = utils.open_tensorstore(path, driver)
Expand All @@ -259,7 +202,7 @@ def generate_img_profiles_via_multithreads(neurograph, path, proposals=None):
return profile_features


def generate_img_profiles_via_superchunk(neurograph, path, proposals=None):
def generate_img_profiles_via_superchunk(neurograph, proposals, path):
"""
Generates an image intensity profile along each edge proposal by reading
a single superchunk from cloud that contains all proposals.
Expand Down Expand Up @@ -297,9 +240,9 @@ def generate_img_profiles_via_superchunk(neurograph, path, proposals=None):
return features


def generate_skel_features(neurograph, proposals=None):
def generate_skel_features(neurograph, proposals):
features = dict()
for edge in neurograph.proposals:
for edge in proposals:
i, j = tuple(edge)
features[edge] = np.concatenate(
(
Expand Down
6 changes: 3 additions & 3 deletions src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_directional(
elif branch.shape[0] <= d:
xyz = deepcopy(branch)
else:
xyz = deepcopy(branch[d: window + d, :])
xyz = deepcopy(branch[d : window + d, :])
directionals.append(compute_tangent(xyz))

# Determine best
Expand Down Expand Up @@ -114,8 +114,8 @@ def smooth_branch(xyz, s=None):
xyz : numpy.ndarray
Array of xyz coordinates to be smoothed.
s : float
A parameter that controls the smoothness of the spline, where
"s" \in [0, N]. Note that the larger "s", the smoother the spline.
A parameter that controls the smoothness of the spline, where
"s" in [0, N]. Note that the larger "s", the smoother the spline.
Returns
-------
Expand Down
Loading

0 comments on commit 92bed20

Please sign in to comment.