From ae887d76a08b8afea3b64ae3aa2eb2ad04a8b6b1 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:15:01 -0800 Subject: [PATCH] refactor: remove immutable edge notion, upd coord conversions (#51) Co-authored-by: anna-grim --- src/deep_neurographs/feature_extraction.py | 42 ++-- src/deep_neurographs/geometry.py | 33 ++- src/deep_neurographs/graph_utils.py | 36 +++- src/deep_neurographs/intake.py | 17 +- src/deep_neurographs/neurograph.py | 199 ++++++++++--------- src/deep_neurographs/structural_inference.py | 47 +++-- src/deep_neurographs/swc_utils.py | 36 ++-- src/deep_neurographs/utils.py | 65 +++--- src/deep_neurographs/visualization.py | 6 +- 9 files changed, 265 insertions(+), 216 deletions(-) diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 118f660..e0b3e7e 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -97,7 +97,8 @@ def generate_img_chunks( def generate_img_chunks_via_multithreads( neurograph, img_path, labels_path, proposals=None ): - img = utils.open_tensorstore(img_path, "zarr") + driver = "n5" if ".n5" in img_path else "zarr" + img = utils.open_tensorstore(img_path, driver) labels = utils.open_tensorstore(labels_path, "neuroglancer_precomputed") with ThreadPoolExecutor() as executor: # Assign Threads @@ -155,9 +156,9 @@ def generate_img_chunks_via_superchunk( # Compute image coordinates i, j = tuple(edge) xyz = gutils.get_edge_attr(neurograph, edge, "xyz") - xyz_i = utils.world_to_img(neurograph, xyz[0]) - xyz_j = utils.world_to_img(neurograph, xyz[1]) - chunk, profile = get_img_chunk_features(img, labels, xyz_i, xyz_j) + coord_i = utils.to_img(xyz[0]) + coord_j = utils.to_img(xyz[1]) + chunk, profile = get_img_chunk_features(img, labels, coord_i, coord_j) chunk_features[edge] = chunk profile_features[edge] = profile return chunk_features, profile_features @@ -203,19 +204,18 @@ def generate_img_profiles(neurograph, path, proposals=None): ) -def generate_img_profiles_via_multithreads( - neurograph, img_path, proposals=None -): +def generate_img_profiles_via_multithreads(neurograph, path, proposals=None): profile_features = dict() - img = utils.open_tensorstore(img_path, "zarr") + driver = "n5" if ".n5" in path else "zarr" + img = utils.open_tensorstore(path, driver) with ThreadPoolExecutor() as executor: # Assign threads threads = [None] * len(proposals) for i, edge in enumerate(proposals): xyz_i, xyz_j = gutils.get_edge_attr(neurograph, edge, "xyz") - xyz_i = utils.world_to_img(neurograph, xyz_i) - xyz_j = utils.world_to_img(neurograph, xyz_j) - line = geometry.make_line(xyz_i, xyz_j, N_PROFILE_POINTS) + coord_i = utils.to_img(xyz_i) + coord_j = utils.to_img(xyz_j) + line = geometry.make_line(coord_i, coord_j, N_PROFILE_POINTS) threads[i] = executor.submit(geometry.get_profile, img, line, edge) # Store result @@ -249,16 +249,16 @@ def generate_img_profiles_via_superchunk(neurograph, path, proposals=None): """ features = dict() - origin = utils.apply_anisotropy(neurograph.origin, return_int=True) + driver = "n5" if ".n5" in path else "zarr" img = utils.get_superchunk( - path, "zarr", origin, neurograph.shape, from_center=False + path, driver, neurograph.origin, neurograph.shape, from_center=False ) img = utils.normalize_img(img) for edge in neurograph.mutable_edges: xyz_i, xyz_j = neurograph.get_edge_attr(edge, "xyz") - xyz_i = utils.world_to_img(neurograph, xyz_i) - xyz_j = 
utils.world_to_img(neurograph, xyz_j) - path = geometry.make_line(xyz_i, xyz_j, N_PROFILE_POINTS) + coord_i = utils.to_img(xyz_i) - neurograph.origin + coord_j = utils.to_img(xyz_j) - neurograph.origin + path = geometry.make_line(coord_i, coord_j, N_PROFILE_POINTS) features[edge] = geometry.get_profile(img, path, window=WINDOW) return features @@ -270,8 +270,8 @@ def generate_skel_features(neurograph, proposals=None): features[edge] = np.concatenate( ( neurograph.compute_length(edge), - neurograph.immutable_degree(i), - neurograph.immutable_degree(j), + neurograph.get_degree_temp(i), + neurograph.get_degree_temp(j), get_radii(neurograph, edge), get_avg_radii(neurograph, edge), get_avg_branch_lens(neurograph, edge), @@ -413,7 +413,7 @@ def get_feature_vectors(neurograph, features, shift=0): features = combine_features(features) features.keys() key = sample(list(features.keys()), 1)[0] - n_edges = neurograph.num_mutables() + n_edges = neurograph.n_mutables() n_features = len(features[key]) # Build @@ -429,7 +429,7 @@ def get_feature_vectors(neurograph, features, shift=0): def get_multimodal_features(neurograph, features, shift=0): idx_to_edge = dict() - n_edges = neurograph.num_mutables() + n_edges = neurograph.n_mutables() X = np.zeros(((n_edges, 2) + tuple(CHUNK_SIZE))) x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_POINTS)) y = np.zeros((n_edges)) @@ -445,7 +445,7 @@ def get_multimodal_features(neurograph, features, shift=0): def get_img_chunks(neurograph, features, shift=0): idx_to_edge = dict() - n_edges = neurograph.num_mutables() + n_edges = neurograph.n_mutables() X = np.zeros(((n_edges, 2) + tuple(CHUNK_SIZE))) y = np.zeros((n_edges)) for i, edge in enumerate(features["img_chunks"].keys()): diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 72ae741..1e504dd 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -22,7 +22,7 @@ def get_directional( elif branch.shape[0] <= d: xyz = deepcopy(branch) else: - xyz = deepcopy(branch[d : window + d, :]) + xyz = deepcopy(branch[d: window + d, :]) directionals.append(compute_tangent(xyz)) # Determine best @@ -38,6 +38,28 @@ def get_directional( def compute_svd(xyz): + """ + Compute singular value decomposition (svd) of an NxD array where N is the + number of points and D is the dimension of the space. + + Parameters + ---------- + xyz : numpy.ndarray + Array containing data points. + + Returns + ------- + numpy.ndarry + Unitary matrix having left singular vectors as columns. Of shape + (N, N) or (N, min(N, D)), depending on full_matrices. + numpy.ndarray + Singular values, sorted in non-increasing order. Of shape (K,), with + K = min(N, D). + numpy.ndarray + Unitary matrix having right singular vectors as rows. Of shape (D, D) + or (K, D) depending on full_matrices. 
+ + """ xyz = xyz - np.mean(xyz, axis=0) return svd(xyz) @@ -242,9 +264,9 @@ def align(neurograph, img, branch_1, branch_2, depth): best_d2 = None best_score = 0 for d1 in range(min(depth, len(branch_1) - 1)): - xyz_1 = neurograph.to_img(branch_1[d1]) + xyz_1 = neurograph.to_img(branch_1[d1], shift=True) for d2 in range(min(depth, len(branch_2) - 1)): - xyz_2 = neurograph.to_img(branch_2[d2]) + xyz_2 = neurograph.to_img(branch_2[d2], shift=True) line = make_line(xyz_1, xyz_2, 10) score = np.mean(get_profile(img, line, window=[3, 3, 3])) if score > best_score: @@ -481,6 +503,5 @@ def make_line(xyz_1, xyz_2, n_steps): return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps], dtype=int) -def normalize(x, norm="l2"): - zero_vec = np.zeros((3)) - return x / abs(dist(zero_vec, x, metric=norm)) +def normalize(vec, norm="l2"): + return vec / abs(dist(np.zeros((3)), vec, metric=norm)) diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index c0877ec..e093f2f 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -316,7 +316,7 @@ def init_edge_attrs(swc_dict, i): Edge attribute dictionary. """ - j = swc_dict["idx"][i] + j = swc_dict["idx"][i] return {"radius": [swc_dict["radius"][j]], "xyz": [swc_dict["xyz"][j]]} @@ -339,7 +339,7 @@ def upd_edge_attrs(swc_dict, attrs, i): Edge attribute dictionary. """ - j = swc_dict["idx"][i] + j = swc_dict["idx"][i] attrs["radius"].append(swc_dict["radius"][j]) attrs["xyz"].append(swc_dict["xyz"][j]) return attrs @@ -406,7 +406,7 @@ def set_node_attrs(swc_dict, nodes): """ attrs = dict() for i in nodes: - j = swc_dict["idx"][i] + j = swc_dict["idx"][i] attrs[i] = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} return attrs @@ -438,10 +438,38 @@ def upd_node_attrs(swc_dict, leafs, junctions, i): Updated dictionary if "i" was contained in "junctions.keys()". """ - j = swc_dict["idx"][i] + j = swc_dict["idx"][i] upd_attrs = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} if i in leafs: leafs[i] = upd_attrs else: junctions[i] = upd_attrs return leafs, junctions + + +# -- miscellaneous -- +def creates_cycle(graph, edge): + """ + Checks whether adding "edge" to "graph" creates a cycle. + + Paramaters + ---------- + graph : networkx.Graph + Graph to be checked for cycles. + edge : tuple + Edge to be added to "graph" + + Returns + ------- + bool + Indication of whether adding "edge" to graph creates a cycle. + + """ + graph.add_edges_from([edge]) + try: + nx.find_cycle(graph) + graph.remove_edges_from([edge]) + return True + except: + graph.remove_edges_from([edge]) + return False diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 792b479..a4286e4 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -48,14 +48,14 @@ def build_neurograph_from_local( ): # Process swc files assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!" 
- bbox = utils.get_bbox(img_patch_origin, img_patch_shape) + img_bbox = utils.get_img_bbox(img_patch_origin, img_patch_shape) paths = get_paths(swc_dir) if swc_dir else swc_paths - swc_dicts = process_local_paths(paths, min_size, bbox=bbox) + swc_dicts = process_local_paths(paths, min_size, img_bbox=img_bbox) # Build neurograph neurograph = build_neurograph( swc_dicts, - bbox=bbox, + img_bbox=img_bbox, img_path=img_path, swc_paths=paths, progress_bar=progress_bar, @@ -207,8 +207,6 @@ def download_gcs_zips(bucket_name, cloud_path, min_size): cnt, t1 = report_progress( i, len(zip_paths), chunk_size, cnt, t0, t1 ) - if len(swc_dicts) > 2000: - stop return swc_dicts @@ -223,7 +221,7 @@ def count_files_in_zips(bucket, zip_paths): # -- Build neurograph --- def build_neurograph( swc_dicts, - bbox=None, + img_bbox=None, img_path=None, swc_paths=None, progress_bar=True, @@ -249,13 +247,16 @@ def build_neurograph( print("(2) Combine irreducibles...") print("# nodes:", utils.reformat_number(n_nodes)) print("# edges:", utils.reformat_number(n_edges)) - neurograph = NeuroGraph(bbox=bbox, img_path=img_path, swc_paths=swc_paths) + + neurograph = NeuroGraph( + img_bbox=img_bbox, img_path=img_path, swc_paths=swc_paths + ) t0, t1 = utils.init_timers() chunk_size = max(int(n_components * 0.05), 1) cnt, i = 1, 0 while len(irreducibles): key, irreducible_set = irreducibles.popitem() - neurograph.add_immutables(irreducible_set, key) + neurograph.add_component(irreducible_set, key) if i > cnt * chunk_size and progress_bar: cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1) i += 1 diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 8774c14..2d54722 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -28,13 +28,12 @@ class NeuroGraph(nx.Graph): """ A class of graphs whose nodes correspond to irreducible nodes from the - predicted swc files. This type of graph has two sets of edges referred - to as "mutable" and "immutable". + predicted swc files. 
""" def __init__( - self, bbox=None, swc_paths=None, img_path=None, label_mask=None + self, img_bbox=None, swc_paths=None, img_path=None, label_mask=None ): super(NeuroGraph, self).__init__() # Initialize paths @@ -45,7 +44,6 @@ def __init__( # Initialize node and edge sets self.leafs = set() self.junctions = set() - self.immutable_edges = set() self.mutable_edges = set() self.target_edges = set() @@ -56,33 +54,38 @@ def __init__( self.kdtree = None # Initialize bounding box (if exists) - self.bbox = bbox + self.bbox = img_bbox if self.bbox: - self.origin = bbox["min"].astype(int) - self.shape = (bbox["max"] - bbox["min"]).astype(int) + self.origin = img_bbox["min"].astype(int) + self.shape = (img_bbox["max"] - img_bbox["min"]).astype(int) else: self.origin = np.array([0, 0, 0], dtype=int) self.shape = None - def init_immutable_graph(self, add_attrs=False): - immutable_graph = nx.Graph() - immutable_graph.add_nodes_from(self.nodes(data=add_attrs)) + def copy_graph(self, add_attrs=False): + graph = nx.Graph() + graph.add_nodes_from(self.nodes(data=add_attrs)) if add_attrs: - for edge in self.immutable_edges: + for edge in self.get_edges_temp(): i, j = tuple(edge) - immutable_graph.add_edge(i, j, **self.get_edge_data(i, j)) + graph.add_edge(i, j, **self.get_edge_data(i, j)) else: - immutable_graph.add_edges_from(self.immutable_edges) - return immutable_graph + graph.add_edges_from(self.get_edges_temp()) + return graph - def init_predicted_graph(self): - self.predicted_graph = self.init_immutable_graph() + def get_edges_temp(self): + edges = [] + for edge in self.edges: + edge = frozenset(edge) + if edge not in self.mutable_edges: + edges.append(edge) + return edges def init_densegraph(self): self.densegraph = DenseGraph(self.swc_paths) # --- Add nodes or edges --- - def add_immutables(self, irreducibles, swc_id): + def add_component(self, irreducibles, swc_id): # Nodes node_ids = dict() cur_id = len(self.nodes) + 1 @@ -96,7 +99,6 @@ def add_immutables(self, irreducibles, swc_id): # Add edges for edge, values in irreducibles["edges"].items(): i, j = edge - self.immutable_edges.add(frozenset((node_ids[i], node_ids[j]))) self.add_edge( node_ids[i], node_ids[j], @@ -166,10 +168,10 @@ def generate_proposals( node = j else: idxs = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0] - node = self.add_immutable_node((i, j), attrs, idxs[0]) + node = self.split_edge((i, j), attrs, idxs[0]) # Add edge - #self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz])) + self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz])) self.mutable_edges.add(frozenset((leaf, node))) # Check whether to optimization proposals @@ -234,23 +236,43 @@ def _get_best_edges(self, dists, xyz, n_proposals_per_leaf): else: return list(xyz.values()) - def add_immutable_node(self, edge, attrs, idx): + def split_edge(self, edge, attrs, idx): + """ + Splits "edge" into two distinct edges by making the subnode at "idx" a + new node in "self". + + Parameters + ---------- + edge : tuple + Edge to be split. + attrs : dict + Attributes of "edge". + idx : int + Index of subnode that will become a new node in "self". + + Returns + ------- + new_node : int + Node ID of node that was created. 
+ + """ # Remove old edge (i, j) = edge self.remove_edge(i, j) - self.immutable_edges.remove(frozenset(edge)) # Add new node and split edge - node_id = len(self.nodes) + 1 + new_node = len(self.nodes) + 1 self.add_node( - node_id, + new_node, xyz=tuple(attrs["xyz"][idx]), radius=attrs["radius"][idx], swc_id=attrs["swc_id"], ) - self.__add_edge((i, node_id), attrs, np.arange(0, idx + 1)) - self.__add_edge((node_id, j), attrs, np.arange(idx, len(attrs["xyz"]))) - return node_id + self.__add_edge((i, new_node), attrs, np.arange(0, idx + 1)) + self.__add_edge( + (new_node, j), attrs, np.arange(idx, len(attrs["xyz"])) + ) + return new_node def __add_edge(self, edge, attrs, idxs): self.add_edge( @@ -262,7 +284,6 @@ def __add_edge(self, edge, attrs, idxs): ) for xyz in attrs["xyz"][idxs]: self.xyz_to_edge[tuple(xyz)] = edge - self.immutable_edges.add(frozenset(edge)) def init_kdtree(self): """ @@ -302,9 +323,9 @@ def _query_kdtree(self, query, d): # --- Optimize Proposals --- def run_optimization(self): - origin = utils.apply_anisotropy(self.origin, return_int=True) + driver = "n5" if ".n5" in self.img_path else "zarr" img = utils.get_superchunk( - self.img_path, "zarr", origin, self.shape, from_center=False + self.img_path, driver, self.origin, self.shape, from_center=False ) for edge in self.mutable_edges: xyz_1, xyz_2 = geometry.optimize_alignment(self, img, edge) @@ -312,7 +333,7 @@ def run_optimization(self): def get_branch(self, xyz_or_node, key="xyz"): if type(xyz_or_node) == int: - nb = self.get_immutable_nbs(xyz_or_node)[0] + nb = self.get_nbs_temp(xyz_or_node)[0] return self.orient_edge((xyz_or_node, nb), xyz_or_node, key=key) else: edge = self.xyz_to_edge[tuple(xyz_or_node)] @@ -324,7 +345,7 @@ def get_branch(self, xyz_or_node, key="xyz"): def get_branches(self, i, key="xyz"): branches = [] - for j in self.get_immutable_nbs(i): + for j in self.get_nbs_temp(i): branches.append(self.orient_edge((i, j), i, key=key)) return branches @@ -339,10 +360,10 @@ def init_targets(self, target_neurograph): # Initializations msg = "Provide swc_dir/swc_paths to initialize target edges!" 
assert target_neurograph.swc_paths, msg + pred_graph = self.copy_graph() target_neurograph.init_densegraph() target_neurograph.init_kdtree() self.target_edges = set() - self.init_predicted_graph() # Add best simple edges remaining_proposals = [] @@ -352,9 +373,15 @@ def init_targets(self, target_neurograph): edge = proposals[idx] if self.is_simple(edge): add_bool = self.is_target( - target_neurograph, edge, dist=7, ratio=0.7, exclude=10 + target_neurograph, + pred_graph, + edge, + dist=7, + ratio=0.7, + exclude=10, ) if add_bool: + pred_graph.add_edges_from([edge]) self.target_edges.add(edge) continue remaining_proposals.append(edge) @@ -364,17 +391,17 @@ def init_targets(self, target_neurograph): for idx in np.argsort(dists): edge = remaining_proposals[idx] add_bool = self.is_target( - target_neurograph, edge, dist=8, ratio=0.5, exclude=10 + target_neurograph, + pred_graph, + edge, + dist=8, + ratio=0.5, + exclude=10, ) if add_bool: + pred_graph.add_edges_from([edge]) self.target_edges.add(edge) - # Print results - # target_ratio = len(self.target_edges) / len(self.mutable_edges) - # print("# target edges:", len(self.target_edges)) - # print("% target edges in mutable:", target_ratio) - # print("") - def filter_infeasible(self, target_neurograph): proposals = list() for edge in self.mutable_edges: @@ -404,11 +431,10 @@ def is_feasible(self, xyz_1, xyz_2): return False def is_target( - self, target_graph, edge, dist=5, ratio=0.5, strict=True, exclude=10 + self, target_graph, pred_graph, edge, dist=5, ratio=0.5, exclude=10 ): # Check for cycle - i, j = tuple(edge) - if self.creates_cycle((i, j)): + if gutils.creates_cycle(pred_graph, tuple(edge)): return False # Check projection distance @@ -427,7 +453,7 @@ def is_target( # --- Generate reconstructions post-inference def get_reconstruction(self, proposals, upd_self=False): - reconstruction = self.init_immutable_graph(add_attrs=True) + reconstruction = self.copy_graph(add_attrs=True) for edge in proposals: i, j = tuple(edge) r_i = self.nodes[i]["radius"] @@ -438,7 +464,7 @@ def get_reconstruction(self, proposals, upd_self=False): return reconstruction # --- Utils --- - def num_nodes(self): + def n_nodes(self): """ Computes number of nodes in the graph. @@ -454,7 +480,7 @@ def num_nodes(self): """ return self.number_of_nodes() - def num_edges(self): + def n_edges(self): """ Computes number of edges in the graph. @@ -470,23 +496,7 @@ def num_edges(self): """ return self.number_of_edges() - def num_immutables(self): - """ - Computes number of immutable edges in the graph. - - Parameters - ---------- - None - - Returns - ------- - int - Number of immutable edges in the graph. - - """ - return len(self.immutable_edges) - - def num_mutables(self): + def n_mutables(self): """ Computes number of mutable edges in the graph. 
@@ -502,28 +512,32 @@ def num_mutables(self): """ return len(self.mutable_edges) - def immutable_degree(self, i): - degree = 0 - for j in self.neighbors(i): - if frozenset((i, j)) in self.immutable_edges: - degree += 1 - return degree + def get_degree_temp(self, i): + return len(self.get_nbs_temp(i)) - def get_immutable_nbs(self, i): + def get_nbs_temp(self, i): nbs = [] for j in self.neighbors(i): - if frozenset((i, j)) in self.immutable_edges: + if frozenset((i, j)) not in self.mutable_edges: nbs.append(j) return nbs + def get_immutables_temp(self): + return [ + edge + for edge in self.edges + if frozenset(edge) not in self.mutable_edges + ] + def compute_length(self, edge, metric="l2"): xyz_1, xyz_2 = self.get_edge_attr(edge, "xyz") return get_dist(xyz_1, xyz_2, metric=metric) def path_length(self, metric="l2"): length = 0 - for edge in self.immutable_edges: - length += self.compute_length(edge, metric=metric) + for edge in self.edges: + if edge not in self.mutable_edges: + length += self.compute_length(edge, metric=metric) return length def get_projection(self, xyz): @@ -537,23 +551,21 @@ def is_nb(self, i, j): def is_contained(self, node_or_xyz, buffer=0): if self.bbox: - if type(node_or_xyz) == int: - node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) - return utils.is_contained(self.bbox, node_or_xyz, buffer=buffer) + img_coord = self.to_img(node_or_xyz) + return utils.is_contained(self.bbox, img_coord, buffer=buffer) else: return True - def is_leaf(self, i): - return True if self.immutable_degree(i) == 1 else False + def to_img(self, node_or_xyz, shift=False): + shift = self.origin if shift else np.zeros((3)) + if type(node_or_xyz) == int: + img_coord = utils.to_img(self.nodes[node_or_xyz]["xyz"]) + else: + img_coord = utils.to_img(node_or_xyz) + return img_coord - shift - def creates_cycle(self, edge): - self.predicted_graph.add_edges_from([edge]) - try: - nx.find_cycle(self.predicted_graph) - except: - return False - self.predicted_graph.remove_edges_from([edge]) - return True + def is_leaf(self, i): + return True if self.get_degree_temp(i) == 1 else False def get_edge_attr(self, edge, key): xyz_arr = gutils.get_edge_attr(self, edge, key) @@ -569,20 +581,11 @@ def is_simple(self, edge): i, j = tuple(edge) return True if self.is_leaf(i) and self.is_leaf(j) else False - def to_img(self, node_or_xyz): - if type(node_or_xyz) == int: - node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) - return utils.to_img(node_or_xyz, shift=self.origin) - - def to_world(self, node_or_xyz, shift=[0, 0, 0]): - if type(node_or_xyz) == int: - node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) - return utils.to_world(node_or_xyz, shift=-self.origin) - def to_patch_coords(self, edge, midpoint, chunk_size): patch_coords = [] for xyz in self.edges[edge]["xyz"]: - coord = utils.img_to_patch(self.to_img(xyz), midpoint, chunk_size) + img_coord = self.to_img(xyz) + coord = utils.img_to_patch(img_coord, midpoint, chunk_size) patch_coords.append(coord) return patch_coords diff --git a/src/deep_neurographs/structural_inference.py b/src/deep_neurographs/structural_inference.py index 3ad529e..be6e51d 100644 --- a/src/deep_neurographs/structural_inference.py +++ b/src/deep_neurographs/structural_inference.py @@ -9,10 +9,11 @@ """ import numpy as np +from deep_neurographs import graph_utils as gutils def get_reconstructions( - pred_neurographs, + neurographs, blocks, block_to_idxs, idx_to_edge, @@ -23,7 +24,7 @@ def get_reconstructions( ): edge_preds = dict() for block_id in blocks: - # Get positive edge 
predictions + # Get positive predictions edge_probs = get_edge_probs( idx_to_edge, y_pred, @@ -34,7 +35,7 @@ def get_reconstructions( # Refine predictions wrt structure if structure_aware: edge_preds[block_id] = get_structure_aware_prediction( - pred_neurographs[block_id], + neurographs[block_id], edge_probs, high_threshold=high_threshold, low_threshold=low_threshold, @@ -45,7 +46,7 @@ def get_reconstructions( def get_reconstruction( - pred_neurograph, + neurograph, y_pred, idx_to_edge, high_threshold=0.9, @@ -56,7 +57,7 @@ def get_reconstruction( edge_probs = get_edge_probs(idx_to_edge, y_pred, low_threshold) if structure_aware: return get_structure_aware_prediction( - pred_neurograph, + neurograph, edge_probs, high_threshold=high_threshold, low_threshold=low_threshold, @@ -75,29 +76,27 @@ def get_edge_probs(idx_to_edge, y_pred, threshold, valid_idxs=[]): def get_structure_aware_prediction( - pred_neurograph, edge_probs, high_threshold=0.8, low_threshold=0.6 + neurograph, probs, high_threshold=0.8, low_threshold=0.6 ): # Initializations - edge_preds = list(edge_probs.keys()) - pred_neurograph.init_predicted_graph() + proposals = list(probs.keys()) + pred_graph = neurograph.copy_graph() # Add best simple edges - remaining_edge_preds = [] - viable_edge_preds = [] - dists = [pred_neurograph.compute_length(edge) for edge in edge_preds] + positive_predictions = [] + remaining_proposals = [] + dists = [neurograph.compute_length(edge) for edge in proposals] for idx in np.argsort(dists): - edge = edge_preds[idx] - if ( - pred_neurograph.is_simple(edge) - and edge_probs[edge] > high_threshold - ): - if not pred_neurograph.creates_cycle(tuple(edge)): - viable_edge_preds.append(edge) + edge = proposals[idx] + if neurograph.is_simple(edge) and probs[edge] > high_threshold: + if not gutils.creates_cycle(pred_graph, tuple(edge)): + pred_graph.add_edges_from([edge]) + positive_predictions.append(edge) else: - remaining_edge_preds.append(edge) + remaining_proposals.append(edge) - # Add remaining valid edges - for edge in remaining_edge_preds: - if not pred_neurograph.creates_cycle(tuple(edge)): - viable_edge_preds.append(edge) - return viable_edge_preds + # Add remaining viable edges + for edge in remaining_proposals: + if not gutils.creates_cycle(pred_graph, tuple(edge)): + positive_predictions.append(edge) + return positive_predictions diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 3cd6c51..807b2ed 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -22,10 +22,12 @@ # -- io utils -- -def process_local_paths(paths, min_size, bbox=None): +def process_local_paths(paths, min_size, img_bbox=None): swc_dicts = dict() for path in paths: - swc_id, swc_dict = parse_local_swc(path, bbox=bbox, min_size=min_size) + swc_id, swc_dict = parse_local_swc( + path, img_bbox=img_bbox, min_size=min_size + ) if len(swc_dict["id"]) > min_size: swc_dicts[swc_id] = swc_dict return swc_dicts @@ -48,11 +50,13 @@ def process_gsc_zip(bucket, zip_path, min_size=0): return swc_dicts -def parse_local_swc(path, bbox=None, min_size=0): +def parse_local_swc(path, img_bbox=None, min_size=0): contents = read_from_local(path) parse_bool = len(contents) > min_size - if parse_bool: - swc_dict = parse(contents, bbox=bbox) if bbox else fast_parse(contents) + if parse_bool and img_bbox: + swc_dict = parse(contents, img_bbox) + elif parse_bool: + swc_dict = fast_parse(contents) else: swc_dict = {"id": [-1]} return utils.get_swc_id(path), swc_dict @@ -65,7 +69,7 @@ 
def parse_gcs_zip(zip_file, path, min_size=0): return utils.get_swc_id(path), swc_dict -def parse(contents, bbox=None): +def parse(contents, img_bbox): """ Parses an swc file to extract the contents which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -88,8 +92,9 @@ def parse(contents, bbox=None): for line in contents: parts = line.split() xyz = read_xyz(parts[2:5], offset=offset) - if bbox: - if not utils.is_contained(bbox, xyz): + if img_bbox: + img_coord = utils.to_img(np.array(xyz)) + if not utils.is_contained(img_bbox, img_coord, buffer=8): break swc_dict["id"].append(int(parts[0])) swc_dict["radius"].append(float(parts[-2])) @@ -98,10 +103,12 @@ def parse(contents, bbox=None): if swc_dict["id"][-1] < min_id: min_id = swc_dict["id"][-1] - # Reindex from zero - for i in range(len(swc_dict["id"])): - swc_dict["id"][i] -= min_id - swc_dict["pid"][i] -= min_id + # Reindex from zero and reformat + if len(swc_dict["id"]) > 0: + swc_dict["id"] = np.array(swc_dict["id"], dtype=int) - min_id + swc_dict["pid"] = np.array(swc_dict["pid"], dtype=int) - min_id + swc_dict["radius"] = np.array(swc_dict["radius"]) + swc_dict["xyz"] = np.array(swc_dict["xyz"]) return swc_dict if len(swc_dict["id"]) > 1 else {"id": [-1]} @@ -148,7 +155,7 @@ def fast_parse(contents): def reindex(arr, idxs): return arr[idxs] - + def get_contents(swc_contents): offset = [0, 0, 0] @@ -340,7 +347,8 @@ def __add_attributes(swc_dict, graph): attrs = dict() for idx, node_id in enumerate(swc_dict["id"]): attrs[node_id] = { - "xyz": swc_dict["xyz"][idx], "radius": swc_dict["radius"][idx] + "xyz": swc_dict["xyz"][idx], + "radius": swc_dict["radius"][idx], } nx.set_node_attributes(graph, attrs) return graph diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index c9d1d8d..d57ac6d 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -283,7 +283,7 @@ def open_tensorstore(path, driver): """ assert driver in SUPPORTED_DRIVERS, "Error! Driver is not supported!" 
- ts_arr = ts.open( + arr = ts.open( { "driver": driver, "kvstore": { @@ -300,29 +300,29 @@ def open_tensorstore(path, driver): } ).result() if driver == "neuroglancer_precomputed": - return ts_arr[ts.d["channel"][0]] - elif driver == "zarr": - ts_arr = ts_arr[0, 0, :, :, :] - ts_arr = ts_arr[ts.d[0].transpose[2]] - ts_arr = ts_arr[ts.d[0].transpose[1]] - return ts_arr + return arr[ts.d["channel"][0]] + return arr +""" def read_img_chunk(img, xyz, shape): start, end = get_start_end(xyz, shape, from_center=from_center) return img[ start[2]: end[2], start[1]: end[1], start[0]: end[0] ].transpose(2, 1, 0) +""" def get_chunk(arr, xyz, shape, from_center=True): start, end = get_start_end(xyz, shape, from_center=from_center) - return deepcopy(arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]]) + return deepcopy( + arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]] + ) def read_tensorstore(arr, xyz, shape, from_center=True): chunk = get_chunk(arr, xyz, shape, from_center=from_center) - return np.swapaxes(chunk.read().result(), 0, 2) + return chunk.read().result() def get_start_end(xyz, shape, from_center=True): @@ -335,7 +335,7 @@ def get_start_end(xyz, shape, from_center=True): return start, end -def get_superchunks(img_path, label_path, xyz, shape, from_center=True): +def get_superchunks(img_path, labels_path, xyz, shape, from_center=True): with concurrent.futures.ThreadPoolExecutor() as executor: img_job = executor.submit( get_superchunk, @@ -345,26 +345,23 @@ def get_superchunks(img_path, label_path, xyz, shape, from_center=True): shape, from_center=from_center, ) - label_job = executor.submit( + labels_job = executor.submit( get_superchunk, - label_path, + labels_path, "neuroglancer_precomputed", xyz, shape, from_center=from_center, ) img = img_job.result().astype(np.int16) - label = label_job.result().astype(np.int64) - return img, label + labels = labels_job.result().astype(np.int64) + assert img.shape == labels.shape, "img.shape != labels.shape" + return img, labels def get_superchunk(path, driver, xyz, shape, from_center=True): - ts_arr = open_tensorstore(path, driver) - if from_center: - return read_tensorstore(ts_arr, xyz, shape) - else: - xyz = [xyz[i] + shape[i] // 2 for i in range(3)] - return read_tensorstore(ts_arr, xyz, shape) + arr = open_tensorstore(path, driver) + return read_tensorstore(arr, xyz, shape, from_center=from_center) def read_json(path): @@ -431,13 +428,8 @@ def write_txt(path, contents): f.write(contents) f.close() -# --- coordinate conversions --- -def world_to_img(neurograph, node_or_xyz): - if type(node_or_xyz) == int: - node_or_xyz = deepcopy(neurograph.nodes[node_or_xyz]["xyz"]) - return to_img(node_or_xyz, shift=neurograph.origin) - +# --- coordinate conversions --- def img_to_patch(xyz, patch_centroid, patch_dims): half_patch_dims = [patch_dims[i] // 2 for i in range(3)] return np.round(xyz - patch_centroid + half_patch_dims).astype(int) @@ -452,13 +444,13 @@ def to_world(xyz, shift=[0, 0, 0]): return tuple([xyz[i] * ANISOTROPY[i] - shift[i] for i in range(3)]) -def to_img(xyz, shift=np.array([0, 0, 0])): - return apply_anisotropy(xyz - shift, return_int=True) +def to_img(xyz): + return (xyz / ANISOTROPY).astype(int) def apply_anisotropy(xyz, return_int=False): if return_int: - return (xyz / ANISOTROPY).astype(int) + return else: return xyz / ANISOTROPY @@ -470,24 +462,21 @@ def get_avg_std(data, weights=None): return avg, math.sqrt(var) -def is_contained(bbox, xyz, buffer=0): - xyz = apply_anisotropy(xyz - bbox["min"]) - shape = bbox["max"] - 
bbox["min"] - if any(xyz < buffer) or any(xyz >= shape - buffer): - return False - else: - return True +def is_contained(bbox, xyz, buffer=5): + above = any(xyz > bbox["max"] - buffer) + below = any(xyz < bbox["min"] + buffer) + return False if above or below else True # --- miscellaneous --- -def get_bbox(origin, shape): +def get_img_bbox(origin, shape): """ Origin is assumed to be top, front, left corner. """ if origin and shape: origin = np.array(origin) - shape = np.array(shape) + shape = np.array(shape) # for i in [2, 1, 0]]) return {"min": origin, "max": origin + shape} else: return None diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index f28c3b9..ed3febe 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -45,7 +45,7 @@ def visualize_connected_components( def visualize_immutables(graph, title="Initial Segmentation"): - data = plot_edges(graph, graph.immutable_edges) + data = plot_edges(graph, graph.get_immutables_temp()) data.append(plot_nodes(graph)) plot(data, title) @@ -61,11 +61,11 @@ def visualize_targets(graph, target_graph=None, title="Target Edges"): def visualize_subset(graph, edges, line_width=5, target_graph=None, title=""): - data = plot_edges(graph, graph.immutable_edges, color="black") + data = plot_edges(graph, graph.get_immutables_temp(), color="black") data.extend(plot_edges(graph, edges, line_width=line_width)) data.append(plot_nodes(graph)) if target_graph: - edges = target_graph.immutable_edges + edges = target_graph.get_immutables_temp() data.extend(plot_edges(target_graph, edges, color="blue")) plot(data, title)