From 2a2a8a8ac29969c431f594bf188d46ccf5f0b28d Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:57:47 -0700 Subject: [PATCH 1/6] refactor: extract fixed image patch in features (#261) Co-authored-by: anna-grim --- src/deep_neurographs/generate_proposals.py | 9 +++++---- src/deep_neurographs/groundtruth_generation.py | 5 ++--- src/deep_neurographs/inference.py | 11 +++++++++-- .../machine_learning/feature_generation.py | 4 ++-- .../feature_generation_graphs.py | 2 +- src/deep_neurographs/neurograph.py | 6 ++---- src/deep_neurographs/utils/graph_util.py | 2 +- src/deep_neurographs/utils/img_util.py | 18 +++++++++++++----- 8 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index f845bb9..dac1465 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -109,8 +109,8 @@ def run( if trim_endpoints_bool: radius /= RADIUS_SCALING_FACTOR long_range, in_range = separate_proposals(neurograph, radius) - neurograph = run_trimming(neurograph, long_range, radius) - neurograph = run_trimming(neurograph, in_range, radius) + neurograph = run_trimming(neurograph, long_range, radius, progress_bar) + neurograph = run_trimming(neurograph, in_range, radius, progress_bar) def init_kdtree(neurograph, complex_bool): @@ -297,7 +297,7 @@ def separate_proposals(neurograph, radius): # --- Trim Endpoints --- -def run_trimming(neurograph, proposals, radius): +def run_trimming(neurograph, proposals, radius, progress_bar): n_endpoints_trimmed = 0 long_radius = radius * RADIUS_SCALING_FACTOR for proposal in deepcopy(proposals): @@ -312,7 +312,8 @@ def run_trimming(neurograph, proposals, radius): elif neurograph.dist(i, j) > radius: neurograph.remove_proposal(proposal) n_endpoints_trimmed += 1 if trim_bool else 0 - print("# Endpoints Trimmed:", n_endpoints_trimmed) + if progress_bar: + print("# Endpoints Trimmed:", n_endpoints_trimmed) return neurograph diff --git a/src/deep_neurographs/groundtruth_generation.py b/src/deep_neurographs/groundtruth_generation.py index dbba151..a469ad5 100644 --- a/src/deep_neurographs/groundtruth_generation.py +++ b/src/deep_neurographs/groundtruth_generation.py @@ -16,7 +16,6 @@ from deep_neurographs import geometry from deep_neurographs.geometry import dist as get_dist -from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import util ALIGNED_THRESHOLD = 4 @@ -132,8 +131,8 @@ def is_component_aligned(target_graph, pred_graph, nodes, kdtree): Returns ------- bool - Indication of whether connected component "nodes" is aligned to a connected - component in "target_graph". + Indication of whether connected component "nodes" is aligned to a + connected component in "target_graph". """ # Compute distances diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 9b2dcfb..e9b31bf 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -152,7 +152,9 @@ def run(self, fragments_pointer): t, unit = util.time_writer(time() - t0) print(f"Total Runtime: {round(t, 4)} {unit}\n") - def run_schedule(self, fragments_pointer, search_radius_schedule): + def run_schedule( + self, fragments_pointer, search_radius_schedule, save_all_rounds=False + ): t0 = time() self.report_experiment() self.build_graph(fragments_pointer) @@ -161,7 +163,12 @@ def run_schedule(self, fragments_pointer, search_radius_schedule): round_id += 1 self.generate_proposals(search_radius) self.run_inference() + if save_all_rounds: + self.save_results(round_id=round_id) + + if not save_all_rounds: self.save_results(round_id=round_id) + t, unit = util.time_writer(time() - t0) print(f"Total Runtime: {round(t, 4)} {unit}\n") @@ -263,7 +270,7 @@ def run_inference(self): ) self.accepted_proposals.extend(accepts) print("# Accepted:", util.reformat_number(len(accepts))) - print("% Accepted:", len(accepts) / n_proposals) + print("% Accepted:", round(len(accepts) / n_proposals, 4)) t, unit = util.time_writer(time() - t0) print(f"Module Runtime: {round(t, 4)} {unit}\n") diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 689188a..13f2de2 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -28,7 +28,7 @@ ) from deep_neurographs.utils import img_util, util -CHUNK_SIZE = [64, 64, 64] +CHUNK_SIZE = [48, 48, 48] N_BRANCH_PTS = 50 N_PROFILE_PTS = 16 # 10 N_SKEL_FEATURES = 22 @@ -192,7 +192,7 @@ def get_profile_specs(xyz_1, xyz_2, downsample_factor): voxel_2 = img_util.to_voxels(xyz_2, downsample_factor=downsample_factor) # Store local coordinates - bbox = img_util.get_minimal_bbox(np.vstack([voxel_1, voxel_2]), buffer=1) + bbox = img_util.get_fixed_bbox(np.vstack([voxel_1, voxel_2]), CHUNK_SIZE) start = [voxel_1[i] - bbox["min"][i] for i in range(3)] end = [voxel_2[i] - bbox["min"][i] for i in range(3)] specs = { diff --git a/src/deep_neurographs/machine_learning/feature_generation_graphs.py b/src/deep_neurographs/machine_learning/feature_generation_graphs.py index 4fa7b6e..e750825 100644 --- a/src/deep_neurographs/machine_learning/feature_generation_graphs.py +++ b/src/deep_neurographs/machine_learning/feature_generation_graphs.py @@ -230,10 +230,10 @@ def proposal_skeletal(neurograph, proposals, radius): neurograph.proposal_length(proposal) / radius, neurograph.n_nearby_leafs(proposal, radius), neurograph.proposal_radii(proposal), - neurograph.proposal_directionals(proposal, 8), neurograph.proposal_directionals(proposal, 16), neurograph.proposal_directionals(proposal, 32), neurograph.proposal_directionals(proposal, 64), + neurograph.proposal_directionals(proposal, 128), ), axis=None, ) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 1e76f42..f28ade5 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -660,7 +660,7 @@ def proposal_directionals(self, proposal, window): def merge_proposal(self, proposal): i, j = tuple(proposal) - somas_check = not (self.is_soma(i) and self.is_soma(j)) + somas_check = not (self.is_soma(i) and self.is_soma(j)) if somas_check and self.check_proposal_degrees(i, j): # Dense attributes attrs = dict() @@ -668,8 +668,6 @@ def merge_proposal(self, proposal): self.nodes[j]["radius"] = 7.0 for k in ["xyz", "radius"]: combine = np.vstack if k == "xyz" else np.array - self.nodes[i][k][-1] = 8.0 - self.nodes[j][k][0] = 8.0 attrs[k] = combine([self.nodes[i][k], self.nodes[j][k]]) # Sparse attributes @@ -701,7 +699,7 @@ def check_proposal_degrees(self, i, j): one_leaf = self.degree[i] == 1 or self.degree[j] == 1 branching = self.degree[i] > 2 or self.degree[j] > 2 return one_leaf and not branching - + def upd_ids(self, swc_id, r): """ Updates the swc_id of all nodes connected to "r". diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index a8fbdcf..18101df 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -462,7 +462,7 @@ def prune_branch(graph, leaf, prune_depth): # Check whether to stop if np.sum(node_dists) > prune_depth: break - return branch[0:min(4, len(branch))] + return branch[0:min(5, len(branch))] def smooth_branch(swc_dict, attrs, edges, nbs, root, j): diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index da2bed2..41d596d 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -142,8 +142,7 @@ def read_tensorstore_with_bbox(img, bbox): try: shape = [bbox["max"][i] - bbox["min"][i] for i in range(3)] return read_tensorstore(img, bbox["min"], shape, from_center=False) - except Exception as e: - print(type(e), e) + except Exception: return np.zeros(shape) @@ -186,7 +185,7 @@ def read_intensities(img, voxels): Image intensities. """ - return [img[tuple(voxel)] for voxel in voxels] + return [img[voxel] for voxel in map(tuple, voxels)] def get_start_end(voxel, shape, from_center=True): @@ -405,7 +404,7 @@ def get_bbox(origin, shape): return None -def get_minimal_bbox(voxels, buffer=0): +def get_minimal_bbox(voxels): """ Gets the min and max coordinates of a bounding box that contains "voxels". @@ -425,7 +424,16 @@ def get_minimal_bbox(voxels, buffer=0): """ bbox = { "min": np.floor(np.min(voxels, axis=0) - 1).astype(int), - "max": np.ceil(np.max(voxels, axis=0) + buffer + 1).astype(int), + "max": np.ceil(np.max(voxels, axis=0) + 1).astype(int), + } + return bbox + + +def get_fixed_bbox(voxels, shape): + centroid = np.round(np.mean(voxels, axis=0)).astype(int) + bbox = { + "min": [centroid[i] - shape[i] // 2 for i in range(3)], + "max": [centroid[i] + shape[i] // 2 for i in range(3)], } return bbox From c7723405e40f0e3582c1cd23c2d49fe25dfa9ace Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Tue, 8 Oct 2024 13:10:10 -0700 Subject: [PATCH 2/6] refactor: feature generation (#262) Co-authored-by: anna-grim --- src/deep_neurographs/config.py | 1 + src/deep_neurographs/generate_proposals.py | 8 +- src/deep_neurographs/geometry.py | 353 +++----- src/deep_neurographs/inference.py | 52 +- .../machine_learning/archived/features.txt | 29 + .../machine_learning/feature_generation.py | 847 ++++++++++++------ .../feature_generation_graphs.py | 471 ---------- src/deep_neurographs/neurograph.py | 36 +- src/deep_neurographs/utils/graph_util.py | 6 +- src/deep_neurographs/utils/img_util.py | 69 +- src/deep_neurographs/utils/swc_util.py | 12 +- 11 files changed, 866 insertions(+), 1018 deletions(-) create mode 100644 src/deep_neurographs/machine_learning/archived/features.txt delete mode 100644 src/deep_neurographs/machine_learning/feature_generation_graphs.py diff --git a/src/deep_neurographs/config.py b/src/deep_neurographs/config.py index 1ae5a6f..f900247 100644 --- a/src/deep_neurographs/config.py +++ b/src/deep_neurographs/config.py @@ -97,6 +97,7 @@ class MLConfig: threshold: float = 0.6 model_type: str = "GraphNeuralNet" n_epochs: int = 1000 + use_img_embedding: bool = False validation_split: float = 0.15 weight_decay: float = 1e-3 diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index dac1465..9267b8a 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -319,8 +319,8 @@ def run_trimming(neurograph, proposals, radius, progress_bar): def trim_endpoints(neurograph, i, j, radius): # Initializations - branch_i = neurograph.get_branch(i) - branch_j = neurograph.get_branch(j) + branch_i = neurograph.branch(i) + branch_j = neurograph.branch(j) # Check both orderings idx_i, idx_j = trim_endpoints_ordered(branch_i, branch_j) @@ -394,8 +394,8 @@ def trim_to_idx(neurograph, i, idx): """ # Update node - branch_xyz = neurograph.get_branch(i, key="xyz") - branch_radii = neurograph.get_branch(i, key="radius") + branch_xyz = neurograph.branch(i, key="xyz") + branch_radii = neurograph.branch(i, key="radius") neurograph.nodes[i]["xyz"] = branch_xyz[idx] neurograph.nodes[i]["radius"] = branch_radii[idx] diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index c3e2a3a..b13bd16 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -7,16 +7,15 @@ """ import numpy as np -import tensorstore as ts from scipy.interpolate import UnivariateSpline from scipy.linalg import svd from scipy.spatial import distance -from deep_neurographs.utils import util +from deep_neurographs.utils import img_util -# Directional Vectors -def get_directional(branches, i, origin, depth): +# --- Directionals --- +def get_directional(branches, origin, depth): """ Computes the directional vector of a branch or bifurcation in a neurograph relative to a specified origin. @@ -25,8 +24,6 @@ def get_directional(branches, i, origin, depth): ---------- neurograph : Neurograph The neurograph object containing the branches. - i : int - The index of the branch or bifurcation in the neurograph. origin : numpy.ndarray The origin point xyz relative to which the directional vector is computed. @@ -41,41 +38,15 @@ def get_directional(branches, i, origin, depth): specified origin. """ - branches = shift_branches(branches, origin) + branches = [shift_path(b, origin) for b in branches] if len(branches) == 1: - return tangent(get_subarray(branches[0], depth)) + return tangent(truncate_path(branches[0], depth)) else: - branch_1 = get_subarray(branches[0], depth) - branch_2 = get_subarray(branches[1], depth) + branch_1 = truncate_path(branches[0], depth) + branch_2 = truncate_path(branches[1], depth) return tangent(np.concatenate((branch_1, branch_2))) -def get_subarray(arr, depth): - """ - Extracts a sub-array of a specified window size from a given input array. - - Parameters - ---------- - branch : numpy.ndarray - Array from which the sub-branch will be extracted. - depth : int - Size of the window in microns to extract from "arr". - - Returns - ------- - numpy.ndarray - A sub-array of the specified window size. If the input array is - smaller than the window size, the entire branch array is returned. - - """ - length = 0 - for i in range(1, arr.shape[0]): - length += dist(arr[i - 1], arr[i]) - if length > depth: - return arr[0:i, :] - return arr - - def compute_svd(xyz): """ Compute singular value decomposition (svd) of an NxD array where N is the @@ -130,47 +101,140 @@ def tangent(xyz_arr): return tangent_vec / np.linalg.norm(tangent_vec) -def normal(xyz): +def midpoint(xyz_1, xyz_2): """ - Computes the normal vector of a plane defined by an array of xyz - coordinates using Singular Value Decomposition (SVD). + Computes the midpoint between "xyz_1" and "xyz_2". Parameters ---------- - xyz : numpy.ndarray - An array of xyz coordinates that normal vector is to be computed of. + xyz_1 : numpy.ndarray + n-dimensional coordinate. + xyz_2 : numpy.ndarray + n-dimensional coordinate. Returns ------- numpy.ndarray - The normal vector of the array "xyz". + Midpoint of "xyz_1" and "xyz_2". """ - U, S, VT = compute_svd(xyz) - return VT[-1] / np.linalg.norm(VT[-1]) + return np.mean([xyz_1, xyz_2], axis=0) -def midpoint(xyz_1, xyz_2): +# --- Path utils --- +def sample_path(xyz_path, n_points): """ - Computes the midpoint between "xyz_1" and "xyz_2". + Uniformly samples points from a curve represented as an array. Parameters ---------- - xyz_1 : numpy.ndarray - n-dimensional coordinate. - xyz_2 : numpy.ndarray - n-dimensional coordinate. + xyz_arr : np.ndarray + xyz coordinates that form a continuous path. + n_points : int + Number of points to be sampled. Returns ------- numpy.ndarray - Midpoint of "xyz_1" and "xyz_2". + Resampled points along curve. """ - return np.mean([xyz_1, xyz_2], axis=0) + k = 1 if len(xyz_path) <= 3 else 3 + t = np.linspace(0, 1, n_points) + spline_x, spline_y, spline_z = fit_spline(xyz_path, k=k, s=0) + xyz_path = np.column_stack((spline_x(t), spline_y(t), spline_z(t))) + return xyz_path.astype(int) + + +def truncate_path(xyz_path, depth): + """ + Extracts a sub-path of a specified depth from a given input path. + + Parameters + ---------- + xyz_path : array-like + xyz coordinates that form a continuous path. + depth : int + Path length in microns to extract from input path. + + Returns + ------- + numpy.ndarray + Sub-path of a specified depth from a given input path. + + """ + length = 0 + for i in range(1, len(xyz_path)): + length += dist(xyz_path[i - 1], xyz_path[i]) + if length > depth: + return xyz_path[0:i] + return xyz_path + + +def shift_path(xyz_path, offset): + """ + Shifts "voxels" by subtracting the min coordinate in "bbox". + + Parameters + ---------- + voxels : numpy.ndarray + Voxel coordinates to be shifted. + offset : dict + Coordinates of a bounding box that contains "voxels". + + Returns + ------- + numpy.ndarray + Voxels shifted by min coordinate in "bbox". + + """ + return [xyz - offset for xyz in xyz_path] + + +def fill_path(img, path, val=-1): + """ + Fills a given path in a 3D image array with a specified value. + + Parameters + ---------- + img : numpy.ndarray + The 3D image array to fill the path in. + path : iterable + A list or iterable containing 3D coordinates (x, y, z) representing + the path. + val : int, optional + The value to fill the path with. Default is -1. + + Returns + ------- + numpy.ndarray + The modified image array with the path filled with the specified value. + + """ + for xyz in path: + x, y, z = tuple(np.floor(xyz).astype(int)) + img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val + return img + + +def path_length(path): + """ + Computes the path length of "path". + + Parameters + ---------- + path : list + xyz coordinates that form a path. + + Returns + ------- + float + Path length of "path". + + """ + return np.sum([dist(path[i], path[i - 1]) for i in range(1, len(path))]) -# Smoothing def smooth_branch(xyz, s=None): """ Smooths a Nx3 array of points by fitting a cubic spline. The points are @@ -227,97 +291,51 @@ def fit_spline(xyz, k=2, s=None): return spline_x, spline_y, spline_z -def sample_curve(xyz_arr, n_pts): +# --- kd-tree utils --- +def query_ball(kdtree, xyz, radius): """ - Uniformly samples points from a curve represented as an array. + Queries a KD-tree for points within a given radius from a target point. Parameters ---------- - xyz_arr : numpy.ndarray - Array of xyz coordinates representing points along a curve. - n_pts : int - Number of points to be sampled. + kdtree : scipy.spatial.cKDTree + The KD-tree data structure containing the points to query. + xyz : numpy.ndarray + The target 3D coordinate (x, y, z) around which to search for points. + radius : float + The radius within which to search for points. Returns ------- numpy.ndarray - Resampled points along curve. - - """ - k = 1 if xyz_arr.shape[0] <= 3 else 3 - t = np.linspace(0, 1, n_pts) - spline_x, spline_y, spline_z = fit_spline(xyz_arr, k=k, s=0) - xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t))) - return xyz.astype(int) - - -# Image feature extraction -def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): - """ - Computes the maximum intensity profile along a list of 3D coordinates - in a given image. - - Parameters - ---------- - img : numpy.ndarray - The image volume or TensorStore object from which to extract intensity - profiles. - xyz_arr : numpy.ndarray - Array of 3D coordinates xyz representing points in the image volume. - process_id : int or None, optional - An optional identifier for the process. Default is None. - window : numpy.ndarray, optional - The size of the window around each coordinate for profile extraction. - Default is [5, 5, 5]. - - Returns - ------- - list, tuple - If "process_id" is provided, returns a tuple containing the process_id - and the intensity profile list. If "process_id" is not provided, - returns only the intensity profile list. + An array containing the points within the specified radius from the + target coordinate. """ - profile = [] - for xyz in xyz_arr: - if type(img) is ts.TensorStore: - profile.append(np.max(util.read_tensorstore(img, xyz, window))) - else: - profile.append(np.max(util.get_chunk(img, xyz, window))) - - if process_id: - return process_id, profile - else: - return profile + idxs = kdtree.query_ball_point(xyz, radius, return_sorted=True) + return kdtree.data[idxs] -def fill_path(img, path, val=-1): +def kdtree_query(kdtree, xyz): """ - Fills a given path in a 3D image array with a specified value. + Gets the xyz coordinates of the nearest neighbor of "xyz" from "kdtree". Parameters ---------- - img : numpy.ndarray - The 3D image array to fill the path in. - path : iterable - A list or iterable containing 3D coordinates (x, y, z) representing - the path. - val : int, optional - The value to fill the path with. Default is -1. + xyz : tuple + xyz coordinate to be queried. Returns ------- - numpy.ndarray - The modified image array with the path filled with the specified value. + tuple + xyz coordinate of the nearest neighbor of "xyz". """ - for xyz in path: - x, y, z = tuple(np.floor(xyz).astype(int)) - img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val - return img + _, idx = kdtree.query(xyz) + return tuple(kdtree.data[idx]) -# Proposal optimization +# --- Proposal optimization --- def optimize_alignment(neurograph, img, edge, depth=15): """ Optimizes alignment of edge proposal between two branches by finding @@ -399,8 +417,8 @@ def optimize_complex_alignment(neurograph, img, edge, depth=15): """ i, j = tuple(edge) - branch = neurograph.get_branches(i if neurograph.is_leaf(i) else j)[0] - branches = neurograph.get_branches(j if neurograph.is_leaf(i) else i) + branch = neurograph.branches(i if neurograph.is_leaf(i) else j)[0] + branches = neurograph.branches(j if neurograph.is_leaf(i) else i) d1, e1, val_1 = align(neurograph, img, branch, branches[0], depth) d2, e2, val_2 = align(neurograph, img, branch, branches[1], depth) pair_1 = (branch[d1], branches[0][e1]) @@ -451,7 +469,7 @@ def align(neurograph, img, branch_1, branch_2, depth): for d2 in range(min(depth, len(branch_2) - 1)): xyz_2 = neurograph.to_voxels(branch_2[d2], shift=True) line = make_line(xyz_1, xyz_2, 10) - score = np.mean(get_profile(img, line, window=[3, 3, 3])) + score = np.mean(img_util.get_profile(img, line)) if score > best_score: best_score = score best_d1 = d1 @@ -459,7 +477,7 @@ def align(neurograph, img, branch_1, branch_2, depth): return best_d1, best_d2, best_score -# Miscellaneous +# --- Miscellaneous --- def dist(v_1, v_2, metric="l2"): """ Computes distance between "v_1" and "v_2". @@ -483,24 +501,6 @@ def dist(v_1, v_2, metric="l2"): return distance.euclidean(v_1, v_2) -def path_length(path): - """ - Computes the path length of "path". - - Parameters - ---------- - path : list - xyz coordinates that form a path. - - Returns - ------- - float - Path length of "path". - - """ - return np.sum([dist(path[i], path[i - 1]) for i in range(1, len(path))]) - - def make_line(xyz_1, xyz_2, n_steps): """ Generates a series of points representing a straight line between two 3D @@ -577,70 +577,3 @@ def nearest_neighbor(xyz_arr, xyz): min_dist = d idx = i return idx, min_dist - - -def shift_branches(branches, shift): - """ - Shifts the coordinates of branches in a list of arrays by a specified - shift vector. - - Parameters - ---------- - branches : list - A list containing arrays of 3D coordinates representing branches. - shift : numpy.ndarray - The shift vector (dx, dy, dz) by which to shift the coordinates. - - Returns - ------- - list - A list containing arrays of shifted 3D coordinates representing the - branches. - - """ - for i, branch in enumerate(branches): - branches[i] = branch - shift - return branches - - -def query_ball(kdtree, xyz, radius): - """ - Queries a KD-tree for points within a given radius from a target point. - - Parameters - ---------- - kdtree : scipy.spatial.cKDTree - The KD-tree data structure containing the points to query. - xyz : numpy.ndarray - The target 3D coordinate (x, y, z) around which to search for points. - radius : float - The radius within which to search for points. - - Returns - ------- - numpy.ndarray - An array containing the points within the specified radius from the - target coordinate. - - """ - idxs = kdtree.query_ball_point(xyz, radius, return_sorted=True) - return kdtree.data[idxs] - - -def kdtree_query(kdtree, xyz): - """ - Gets the xyz coordinates of the nearest neighbor of "xyz" from "kdtree". - - Parameters - ---------- - xyz : tuple - xyz coordinate to be queried. - - Returns - ------- - tuple - xyz coordinate of the nearest neighbor of "xyz". - - """ - _, idx = kdtree.query(xyz) - return tuple(kdtree.data[idx]) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index e9b31bf..60d5842 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -99,7 +99,6 @@ def __init__( self.accepted_proposals = list() self.sample_id = sample_id self.segmentation_id = segmentation_id - self.img_path = img_path self.model_path = model_path # Extract config settings @@ -108,7 +107,7 @@ def __init__( # Inference engine self.inference_engine = InferenceEngine( - self.img_path, + img_path, self.model_path, self.ml_config.model_type, self.graph_config.search_radius, @@ -153,15 +152,15 @@ def run(self, fragments_pointer): print(f"Total Runtime: {round(t, 4)} {unit}\n") def run_schedule( - self, fragments_pointer, search_radius_schedule, save_all_rounds=False + self, fragments_pointer, radius_schedule, save_all_rounds=False ): t0 = time() self.report_experiment() self.build_graph(fragments_pointer) - for round_id, search_radius in enumerate(search_radius_schedule): - print(f"--- Round {round_id + 1}: Radius = {search_radius} ---") + for round_id, radius in enumerate(radius_schedule): + print(f"--- Round {round_id + 1}: Radius = {radius} ---") round_id += 1 - self.generate_proposals(search_radius) + self.generate_proposals(radius) self.run_inference() if save_all_rounds: self.save_results(round_id=round_id) @@ -213,7 +212,7 @@ def build_graph(self, fragments_pointer): print(f"Module Runtime: {round(t, 4)} {unit}\n") self.print_graph_overview() - def generate_proposals(self, search_radius=None): + def generate_proposals(self, radius=None): """ Generates proposals for the fragment graph based on the specified configuration. @@ -229,13 +228,13 @@ def generate_proposals(self, search_radius=None): """ # Initializations print("(2) Generate Proposals") - if search_radius is None: - search_radius = self.graph_config.search_radius + if radius is None: + radius = self.graph_config.radius # Main t0 = time() self.graph.generate_proposals( - search_radius, + radius, complex_bool=self.graph_config.complex_bool, long_range_bool=self.graph_config.long_range_bool, proposals_per_leaf=self.graph_config.proposals_per_leaf, @@ -392,7 +391,7 @@ def __init__( img_path, model_path, model_type, - search_radius, + radius, batch_size=BATCH_SIZE, confidence_threshold=CONFIDENCE_THRESHOLD, device=None, @@ -410,7 +409,7 @@ def __init__( Path to machine learning model parameters. model_type : str Type of machine learning model used to perform inference. - search_radius : float + radius : float Search radius used to generate proposals. batch_size : int, optional Number of proposals to generate features and classify per batch. @@ -429,16 +428,25 @@ def __init__( """ # Set class attributes self.batch_size = batch_size - self.downsample_factor = downsample_factor self.device = "cpu" if device is None else device self.is_gnn = True if "Graph" in model_type else False self.model_type = model_type - self.search_radius = search_radius + self.radius = radius self.threshold = confidence_threshold - # Load image and model + # Load image driver = "n5" if ".n5" in img_path else "zarr" - self.img = img_util.open_tensorstore(img_path, driver=driver) + img = img_util.open_tensorstore(img_path, driver=driver) + + # Features + feature_factory = feature_generation.Factory() + self.feature_generator = feature_factory.create( + model_type, + img, + downsample_factor + ) + + # Model self.model = ml_util.load_model(model_path) if self.is_gnn: self.model = self.model.to(self.device) @@ -532,17 +540,7 @@ def get_batch_dataset(self, neurograph, batch): ... """ - # Generate features - features = feature_generation.run( - neurograph, - self.img, - self.model_type, - batch, - self.search_radius, - downsample_factor=self.downsample_factor, - ) - - # Initialize dataset + features = self.feature_generator.run(neurograph, batch, self.radius) computation_graph = batch["graph"] if type(batch) is dict else None dataset = ml_util.init_dataset( neurograph, diff --git a/src/deep_neurographs/machine_learning/archived/features.txt b/src/deep_neurographs/machine_learning/archived/features.txt new file mode 100644 index 0000000..c9d95fd --- /dev/null +++ b/src/deep_neurographs/machine_learning/archived/features.txt @@ -0,0 +1,29 @@ +""" +Created on Sat May 9 11:00:00 2024 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Archived routines for feature generation. + +""" + +def compute_curvature(neurograph, edge): + kappa = curvature(neurograph.edges[edge]["xyz"]) + n_pts = len(kappa) + if n_pts <= N_BRANCH_PTS: + sampled_kappa = np.zeros((N_BRANCH_PTS)) + sampled_kappa[0:n_pts] = kappa + else: + idxs = np.linspace(0, n_pts - 1, N_BRANCH_PTS).astype(int) + sampled_kappa = kappa[idxs] + return np.array(sampled_kappa) + + +def curvature(xyz_list): + a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1) + b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1) + c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1) + s = 0.5 * (a + b + c) + delta = np.sqrt(s * (s - a) * (s - b) * (s - c)) + return 4 * delta / (a * b * c) diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 13f2de2..6f00b92 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -6,11 +6,10 @@ Generates features for training a model and performing inference. -Conventions: (1) "xyz" refers to a real world coordinate such as those from - an swc file. +Conventions: + (1) "xyz" refers to a real world coordinate such as those from an swc file - (2) "voxel" refers to an voxel coordinate in a whole exaspim - image. + (2) "voxel" refers to an voxel coordinate in a whole exaspim image. """ @@ -23,278 +22,582 @@ import tensorstore as ts from deep_neurographs import geometry -from deep_neurographs.machine_learning.feature_generation_graphs import ( - generate_gnn_features, -) from deep_neurographs.utils import img_util, util CHUNK_SIZE = [48, 48, 48] N_BRANCH_PTS = 50 -N_PROFILE_PTS = 16 # 10 +N_PROFILE_PTS = 16 N_SKEL_FEATURES = 22 -def run( - neurograph, - img, - model_type, - proposals_dict, - radius, - downsample_factor=1, - labels=None, -): +class Factory: """ - Generates feature vectors that are used by a machine learning model to - classify proposals. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - model_type : str - Type of machine learning model used to classify proposals. - proposals_dict : dict - Dictionary that contains the items (1) "proposals" which are the - proposals from "neurograph" that features will be generated and - (2) "graph" which is the computation graph used by the gnn. - radius : float - Search radius used to generate proposals. - downsample_factor : int, optional - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. The default is 0. - labels : tensorstore.TensorStore, optional - Segmentation mask stored in a GCS bucket. The default is None. - - Returns - ------- - dict - Feature vectors. + Class that generates feature generator instances based on the specified + model type. """ - # Init leaf kd-tree (if applicable) - if neurograph.leaf_kdtree is None: - neurograph.init_kdtree(node_type="leaf") - - # Feature generation by type of machine learning model - if model_type == "GraphNeuralNet": - return generate_gnn_features( - neurograph, img, proposals_dict, radius, downsample_factor - ) - else: - return generate_features( - neurograph, img, proposals_dict, radius, downsample_factor - ) + @staticmethod + def create(model_type, *args): + if model_type == "GraphNeuralNet": + return GraphFeatureGenerator(*args) + else: + return FeatureGenerator(*args) -def generate_features( - neurograph, img, proposals_dict, radius, downsample_factor -): +class FeatureGenerator: """ - Generates feature vectors that are used by a general machine learning model - (e.g. random forest or feed forward neural network). - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - proposals_dict : dict - Dictionary containing the computation graph used by gnn and proposals - to be classified. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Feature vectors. + An abstract base class that generates features vectors to be classified + and/or learned by a machine learning model. """ - features = defaultdict(bool) - features["proposals"] = { - "skel": proposal_skeletal( - neurograph, proposals_dict["proposals"], radius - ), - "profiles": proposal_profiles( - neurograph, img, proposals_dict["proposals"], downsample_factor - ), - } - return features - -def proposal_profiles(neurograph, img, proposals, downsample_factor): + def __init__(self, img, downsample_factor): + """ + Instantiates a FeatureGenerator object that is used to generate + features in a machine learning pipeline. + + Parameters + ---------- + img : np.ndarray + Raw image from which features will be generated. + downsample_factor : int + Downsampling factor that represents which level in the image + pyramid the voxel coordinates must index into. + + Returns + ------- + None + + """ + self.img = img + self.downsample_factor = downsample_factor + + def run(self, neurograph, proposals_dict, radius): + """ + Generates feature vectors for each proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals_dict : dict + Dictionary containing the computation graph used by gnn and + proposals to be classified. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Feature vectors. + + """ + # Initialiations + features = defaultdict(bool) + proposals = proposals_dict["proposals"] + if neurograph.leaf_kdtree is None: + neurograph.init_kdtree(node_type="leaf") + + # Main + features["proposals"] = { + "skel": self.proposal_skeletal(neurograph, proposals, radius), + "profiles": self.proposal_profiles(neurograph, proposals), + } + return features + + def proposal_skeletal(neurograph, proposals, radius): + """ + Generates features from skeleton (i.e. graph) which are graph or + geometry-based features. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals : list + Proposals for which features will be generated + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Features generated from skeleton. + + """ + features = dict() + for proposal in proposals: + i, j = tuple(proposal) + features[proposal] = np.concatenate( + ( + neurograph.proposal_length(proposal), + neurograph.degree[i], + neurograph.degree[j], + len(neurograph.nodes[i]["proposals"]), + len(neurograph.nodes[j]["proposals"]), + neurograph.n_nearby_leafs(proposal, radius), + neurograph.proposal_radii(proposal), + neurograph.proposal_avg_radii(proposal), + neurograph.proposal_directionals(proposal, 16), + neurograph.proposal_directionals(proposal, 32), + neurograph.proposal_directionals(proposal, 64), + neurograph.proposal_directionals(proposal, 128), + ), + axis=None, + ) + return features + + def proposal_profiles(self, neurograph, proposals): + """ + Generates an image intensity profile along the proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals : list[frozenset] + List of proposals for which features will be generated. + + Returns + ------- + dict + Dictonary such that each pair is the proposal id and image + intensity profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = len(proposals) * [None] + for i, proposal in enumerate(proposals): + xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) + xyz_path = geometry.make_line(xyz_1, xyz_2, N_PROFILE_PTS) + spec = self.get_profile_spec(xyz_path) + threads[i] = executor.submit( + img_util.get_profile, self.img, spec, proposal + ) + + # Store results + profiles = dict() + for thread in as_completed(threads): + profiles.update(thread.result()) + return profiles + + def get_profile_spec(self, xyz_path): + """ + Gets image bounding box and voxel coordinates needed to compute an + image intensity profile. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates that represent an image profile path. + + Returns + ------- + dict + Specifications needed to compute profile for a given proposal. + + """ + voxels = self.transform_path(xyz_path) + bbox = img_util.get_fixed_bbox(voxels, CHUNK_SIZE) + profile_path = geometry.shift_path(voxels, bbox["min"]) + return {"bbox": bbox, "profile_path": profile_path} + + def transform_path(self, xyz_path): + """ + Transforms "xyz_path" by converting the xyz coordinates to voxels and + resampling "N_PROFILE_PTS" from voxel coordinates. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates that represent an image profile path. + + Returns + ------- + numpy.ndarray + Voxel coordinates that represent an image profile path. + + """ + voxels = np.zeros((len(xyz_path), 3), dtype=int) + for i, xyz in enumerate(xyz_path): + voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) + return voxels + + +class GraphFeatureGenerator(FeatureGenerator): """ - Generates an image intensity profile along each proposal by reading from - an image on the cloud. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.TensorStore - Image stored in a GCS bucket. - proposals : list[frozenset] - List of proposals for which features will be generated. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictonary such that each pair is the proposal id and image intensity - profile. + Class that generates features vectors that are used by a graph neural + network to classify proposals. """ - with ThreadPoolExecutor() as executor: - threads = [] + def __init__( + self, img, downsample_factor, labels=None, use_img_embedding=False + ): + """ + Initializes object that generates features that are used by a gnn. + + Parameters + ---------- + img : tensorstore.Tensorstore + Image stored in a GCS bucket. + downsample_factor : int, optional + Downsampling factor that accounts for which level in the image + pyramid the voxel coordinates must index into. The default is 0. + labels : tensorstore.TensorStore, optional + Segmentation mask stored in a GCS bucket. The default is None. + use_img_embedding : bool, optional + ... + + Returns + ------- + None + + """ + # Initialize instance attributes + super().__init__(img, downsample_factor) + self.labels = labels + self.use_img_embedding = use_img_embedding + + def run(self, neurograph, proposals_dict, radius): + """ + Generates feature vectors for a graph given a set of proposals. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals_dict : dict + Dictionary that contains the items (1) "proposals" which are the + proposals from "neurograph" that features will be generated and + (2) "graph" which is the computation graph used by the gnn. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary that contains different types of feature vectors for + nodes, edges, and proposals. + + """ + # Initializations + computation_graph = proposals_dict["graph"] + proposals = proposals_dict["proposals"] + if neurograph.leaf_kdtree is None: + neurograph.init_kdtree(node_type="leaf") + + # Main + features = { + "nodes": self.run_on_nodes(neurograph, computation_graph), + "edge": self.run_on_edges(neurograph, computation_graph), + "proposals": self.run_on_proposals(neurograph, proposals, radius) + } + return features + + def run_on_nodes(self, neurograph, computation_graph): + """ + Generates feature vectors for every node in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary whose keys are feature types (i.e. skeletal) and values + are a dictionary that maps a node id to the corresponding feature + vector. + + """ + return {"skel": self.node_skeletal(neurograph, computation_graph)} + + def run_on_edges(self, neurograph, computation_graph): + """ + Generates feature vectors for every edge in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary whose keys are feature types (i.e. skeletal) and values + are a dictionary that maps an edge id to the corresponding feature + vector. + + """ + return {"skel": self.edge_skeletal(neurograph, computation_graph)} + + def run_on_proposals(self, neurograph, proposals, radius): + """ + Generates feature vectors for every proposal in "neurograph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + proposals : list[frozenset] + List of proposals for which features will be generated. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary whose keys are feature types (i.e. skeletal, profiles, + chunks) and values are a dictionary that maps a proposal id to a + feature vector. + + """ + # Skeleton features + features = { + "skel": self.proposal_skeletal(neurograph, proposals, radius) + } + + # Image features + if self.use_img_embedding: + chunks, profiles = self.proposal_chunks_profiles( + neurograph, proposals + ) + features.update({"chunks": chunks}) + features.update({"profiles": profiles}) + else: + profiles = self.proposal_profiles(neurograph, proposals) + features.update({"profiles": profiles}) + return features + + # -- Skeletal Features -- + def node_skeletal(self, neurograph, computation_graph): + """ + Generates skeleton-based features for nodes in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps a node id to a feature vector. + + """ + node_skeletal_features = dict() + for i in computation_graph.nodes: + node_skeletal_features[i] = np.concatenate( + ( + neurograph.degree[i], + neurograph.nodes[i]["radius"], + len(neurograph.nodes[i]["proposals"]), + ), + axis=None, + ) + return node_skeletal_features + + def edge_skeletal(self, neurograph, computation_graph): + """ + Generates skeleton-based features for edges in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps an edge id to a feature vector. + + """ + edge_skeletal_features = dict() + for edge in neurograph.edges: + edge_skeletal_features[frozenset(edge)] = np.array( + [ + np.mean(neurograph.edges[edge]["radius"]), + min(neurograph.edges[edge]["length"], 500) / 500, + ], + ) + return edge_skeletal_features + + def proposal_skeletal(self, neurograph, proposals, radius): + """ + Generates skeleton-based features for "proposals". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + proposals : list[frozenset] + List of proposals for which features will be generated. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary that maps a node id to a feature vector. + + """ + proposal_skeletal_features = dict() for proposal in proposals: - xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) - specs = get_profile_specs(xyz_1, xyz_2, downsample_factor) - threads.append(executor.submit(get_profile, img, specs, proposal)) - - profiles = dict() - for thread in as_completed(threads): - profiles.update(thread.result()) - return profiles - - -def get_profile_specs(xyz_1, xyz_2, downsample_factor): - """ - Gets image bounding box and voxel coordinates needed to compute an image - profile. - - Parameters - ---------- - xyz_1 : numpy.ndarray - xyz coordinate of starting point of profile. - xyz_2 : numpy.ndarray - xyz coordinate of ending point of profile. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. + proposal_skeletal_features[proposal] = np.concatenate( + ( + neurograph.proposal_length(proposal) / radius, + neurograph.n_nearby_leafs(proposal, radius), + neurograph.proposal_radii(proposal), + neurograph.proposal_directionals(proposal, 16), + neurograph.proposal_directionals(proposal, 32), + neurograph.proposal_directionals(proposal, 64), + neurograph.proposal_directionals(proposal, 128), + ), + axis=None, + ) + return proposal_skeletal_features + + # --- Image features --- + def node_profiles(self, neurograph, computation_graph): + """ + Generates image profiles for nodes in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps a node id to an image profile. + + """ + # Get profile specifications + specs = dict() + for i in computation_graph.nodes: + if neurograph.is_leaf(i): + profile_path = self.get_leaf_profile_path(neurograph, i) + else: + profile_path = self.get_branching_profile_path(neurograph, i) + specs[i] = self.get_img_specs(profile_path) + + # Generate profiles + with ThreadPoolExecutor() as executor: + threads = [] + for i, spec in specs.items(): + threads.append( + executor.submit(img_util.get_profile, self.img, spec, i) + ) - Returns - ------- - dict - Specifications needed to compute an image profile for a given - proposal. + node_profile_features = dict() + for thread in as_completed(threads): + node_profile_features.update(thread.result()) + return node_profile_features + + def proposal_chunks_profiles(self, neurograph, img, proposals): + """ + Generates an image intensity profile along each proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + img : tensorstore.TensorStore + Image stored in a GCS bucket. + proposals : list[frozenset] + List of proposals for which features will be generated. + downsample_factor : int + Downsampling factor that represents which level in the image + pyramid the voxel coordinates must index into. + + Returns + ------- + dict + Dictonary such that each pair is the proposal id and profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = [] + for proposal in proposals: + xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) + xyz_path = geometry.make_line(xyz_1, xyz_2, N_PROFILE_PTS) + specs = self.get_profile_spec(xyz_path) + threads.append( + executor.submit( + img_util.get_chunk_profile, img, specs, proposal) + ) - """ - # Compute voxel coordinates - voxel_1 = img_util.to_voxels(xyz_1, downsample_factor=downsample_factor) - voxel_2 = img_util.to_voxels(xyz_2, downsample_factor=downsample_factor) - - # Store local coordinates - bbox = img_util.get_fixed_bbox(np.vstack([voxel_1, voxel_2]), CHUNK_SIZE) - start = [voxel_1[i] - bbox["min"][i] for i in range(3)] - end = [voxel_2[i] - bbox["min"][i] for i in range(3)] - specs = { - "bbox": bbox, - "profile_path": geometry.make_line(start, end, N_PROFILE_PTS), - } - return specs + # Store results + profiles = dict() + chunks = dict() + for thread in as_completed(threads): + proposal, chunk, profile = thread.result() + chunks[proposal] = chunk + profiles[proposal] = profile + return chunks, profiles -def get_profile(img, specs, profile_id): +# --- Profile utils --- +def get_leaf_profile_path(neurograph, i): """ - Gets the image profile for a given proposal. + Gets path that profile will be computed over for the leaf node "i". Parameters ---------- - img : tensorstore.TensorStore - Image that profiles are generated from. - specs : dict - Dictionary that contains the image bounding box and coordinates of the - image profile path. - profile_id : frozenset - ... + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + i : int + Leaf node in "neurograph". Returns ------- - dict - Dictionary that maps an id (e.g. node, edge, or proposal) to its image - profile. + list + Voxel coordinates that profile is generated from. """ - profile = img_util.read_profile(img, specs) - avg, std = util.get_avg_std(profile) - profile.extend([avg, std]) - return {profile_id: profile} + j = neurograph.leaf_neighbor(i) + xyz_path = neurograph.oriented_edge((i, j), i) + return geometry.truncate_path(xyz_path) -def proposal_skeletal(neurograph, proposals, radius): +def get_branching_profile_path(neurograph, i): """ - Generates features from skeleton (i.e. graph) which are graph or - geometry type features. + Gets path that profile will be computed over for the branching node "i". Parameters ---------- neurograph : NeuroGraph - Graph that "proposals" belong to. - proposals : list - Proposals for which features will be generated - radius : float - Search radius used to generate proposals. + NeuroGraph generated from a predicted segmentation. + i : int + branching node in "neurograph". Returns ------- - dict - Features generated from skeleton. + list + Voxel coordinates that profile is generated from. """ - features = dict() - for proposal in proposals: - i, j = tuple(proposal) - features[proposal] = np.concatenate( - ( - neurograph.proposal_length(proposal), - neurograph.degree[i], - neurograph.degree[j], - len(neurograph.nodes[i]["proposals"]), - len(neurograph.nodes[j]["proposals"]), - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_avg_radii(proposal), - neurograph.proposal_directionals(proposal, 8), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - ), - axis=None, - ) - return features - - -# --- part 2: edge feature generation -- -def compute_curvature(neurograph, edge): - kappa = curvature(neurograph.edges[edge]["xyz"]) - n_pts = len(kappa) - if n_pts <= N_BRANCH_PTS: - sampled_kappa = np.zeros((N_BRANCH_PTS)) - sampled_kappa[0:n_pts] = kappa - else: - idxs = np.linspace(0, n_pts - 1, N_BRANCH_PTS).astype(int) - sampled_kappa = kappa[idxs] - return np.array(sampled_kappa) + j_1, j_2 = tuple(neurograph.neighbors(i)) + voxels_1 = geometry.truncate_path(neurograph.oriented_edge((i, j_1), i)) + voxles_2 = geometry.truncate_path(neurograph.oriented_edge((i, j_2), i)) + return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) -def curvature(xyz_list): - a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1) - b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1) - c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1) - s = 0.5 * (a + b + c) - delta = np.sqrt(s * (s - a) * (s - b) * (s - c)) - return 4 * delta / (a * b * c) - - -# -- Build Feature Matrix -- +# --- Build feature matrix --- def get_matrix(neurographs, features, sample_ids=None): if sample_ids: return stack_feature_matrices(neurographs, features, sample_ids) @@ -349,7 +652,21 @@ def get_feature_matrix(neurograph, features, shift=0): return X, y, idx_transforms -# -- util -- +def combine_features(features): + combined = dict() + for edge in features["skel"].keys(): + combined[edge] = None + for key in features.keys(): + if combined[edge] is None: + combined[edge] = deepcopy(features[key][edge]) + else: + combined[edge] = np.concatenate( + (combined[edge], features[key][edge]) + ) + return combined + + +# --- Utils --- def count_features(): """ Counts number of features based on the "model_type". @@ -366,63 +683,43 @@ def count_features(): return N_SKEL_FEATURES + N_PROFILE_PTS + 2 -def combine_features(features): - combined = dict() - for edge in features["skel"].keys(): - combined[edge] = None - for key in features.keys(): - if combined[edge] is None: - combined[edge] = deepcopy(features[key][edge]) - else: - combined[edge] = np.concatenate( - (combined[edge], features[key][edge]) - ) - return combined +def n_node_features(): + """ + Returns the number of features for different node types. + Parameters + ---------- + None -def generate_chunks(neurograph, proposals, img, labels): + Returns + ------- + dict + A dictionary containing the number of features for each node type + + """ + return {"branch": 2, "proposal": 34} + + +def n_edge_features(): """ - Generates an image chunk for each proposal such that the centroid of the - image chunk is the midpoint of the proposal. Image chunks contain two - channels: raw image and predicted segmentation. + Returns the number of features for different edge types. Parameters ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.TensorStore - Image stored in a GCS bucket. - labels : tensorstore.TensorStore - Predicted segmentation mask stored in a GCS bucket. - proposals : list[frozenset], optional - List of proposals for which features will be generated. The - default is None. + None Returns ------- dict - Dictonary such that each pair is the proposal id and image chunk. + A dictionary containing the number of features for each edge type """ - with ThreadPoolExecutor() as executor: - # Assign Threads - threads = [None] * len(proposals) - for t, proposal in enumerate(proposals): - xyz_0, xyz_1 = neurograph.proposal_xyz(proposal) - voxel_1 = util.to_voxels(xyz_0) - voxel_2 = util.to_voxels(xyz_1) - threads[t] = executor.submit( - get_chunk, img, labels, voxel_1, voxel_2, proposal - ) - - # Save result - chunks = dict() - profiles = dict() - for thread in as_completed(threads): - proposal, chunk, profile = thread.result() - chunks[proposal] = chunk - profiles[proposal] = profile - return chunks, profiles + n_edge_features_dict = { + ("proposal", "edge", "proposal"): 3, + ("branch", "edge", "branch"): 3, + ("branch", "edge", "proposal"): 3 + } + return n_edge_features_dict def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None): diff --git a/src/deep_neurographs/machine_learning/feature_generation_graphs.py b/src/deep_neurographs/machine_learning/feature_generation_graphs.py deleted file mode 100644 index e750825..0000000 --- a/src/deep_neurographs/machine_learning/feature_generation_graphs.py +++ /dev/null @@ -1,471 +0,0 @@ -""" -Created on Sat May 9 11:00:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Generates features for training and performing inference with a heterogenous -graph neural network. - -""" -from concurrent.futures import ThreadPoolExecutor, as_completed - -import numpy as np - -from deep_neurographs import geometry -from deep_neurographs.machine_learning import feature_generation as feats -from deep_neurographs.utils import img_util - -N_PROFILE_PTS = 16 -NODE_PROFILE_DEPTH = 16 -WINDOW = [5, 5, 5] - - -def generate_gnn_features( - neurograph, img, proposals_dict, radius, downsample_factor -): - """ - Generates node and edge features for graph neural network. - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - img : str - Image stored on a GCS bucket. - proposals_dict : dict - Dictionary containing the computation graph used by gnn and proposals - to be classified. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary that contains different types of feature vectors for - nodes, edges, and proposals. - - """ - computation_graph = proposals_dict["graph"] - proposals = proposals_dict["proposals"] - features = { - "nodes": run_on_nodes( - neurograph, computation_graph, img, downsample_factor - ), - "edge": run_on_edges(neurograph, computation_graph), - "proposals": run_on_proposals( - neurograph, img, proposals, radius, downsample_factor - ), - } - return features - - -def run_on_nodes(neurograph, computation_graph, img, downsample_factor): - """ - Generates feature vectors for every node in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - img : str - Image stored in a GCS bucket. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal) and values are - a dictionary that maps a node id to the corresponding feature vector. - - """ - return {"skel": node_skeletal(neurograph, computation_graph)} - - -def run_on_edges(neurograph, computation_graph): - """ - Generates feature vectors for every edge in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal) and values are - a dictionary that maps an edge id to the corresponding feature vector. - - """ - return {"skel": edge_skeletal(neurograph, computation_graph)} - - -def run_on_proposals(neurograph, img, proposals, radius, downsample_factor): - """ - Generates feature vectors for every proposal in "neurograph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - img : str - Image stored in a GCS bucket. - proposals : list[frozenset] - List of proposals for which features will be generated. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal and profiles) - and values are a dictionary that maps a proposal id to the - corresponding feature vector. - - """ - proposal_features = { - "skel": proposal_skeletal(neurograph, proposals, radius), - "profiles": feats.proposal_profiles( - neurograph, img, proposals, downsample_factor - ), - } - return proposal_features - - -# -- Skeletal Features -- -def node_skeletal(neurograph, computation_graph): - """ - Generates skeleton-based features for nodes in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding feature vector. - - """ - node_skeletal_features = dict() - for i in computation_graph.nodes: - node_skeletal_features[i] = np.concatenate( - ( - neurograph.degree[i], - neurograph.nodes[i]["radius"], - len(neurograph.nodes[i]["proposals"]), - ), - axis=None, - ) - return node_skeletal_features - - -def edge_skeletal(neurograph, computation_graph): - """ - Generates skeleton-based features for edges in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary that maps an edge id to the corresponding feature vector. - - """ - edge_skeletal_features = dict() - for edge in neurograph.edges: - edge_skeletal_features[frozenset(edge)] = np.array( - [ - np.mean(neurograph.edges[edge]["radius"]), - min(neurograph.edges[edge]["length"], 500) / 500, - ], - ) - return edge_skeletal_features - - -def proposal_skeletal(neurograph, proposals, radius): - """ - Generates skeleton-based features for "proposals". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - proposals : list[frozenset] - List of proposals for which features will be generated. - radius : float - Search radius used to generate proposals. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding feature vector. - - """ - proposal_skeletal_features = dict() - for proposal in proposals: - proposal_skeletal_features[proposal] = np.concatenate( - ( - neurograph.proposal_length(proposal) / radius, - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - neurograph.proposal_directionals(proposal, 128), - ), - axis=None, - ) - return proposal_skeletal_features - - -# -- image features -- -def node_profiles(neurograph, computation_graph, img, downsample_factor): - """ - Generates image profiles for nodes in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - img : str - Image to be read from. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding image profile. - - """ - # Get specifications to compute profiles - specs = dict() - for i in computation_graph.nodes: - if neurograph.degree[i] == 1: - profile_path = get_leaf_profile_path(neurograph, i) - else: - profile_path = get_branching_profile_path(neurograph, i) - specs[i] = get_node_profile_specs(profile_path, downsample_factor) - - # Generate profiles - with ThreadPoolExecutor() as executor: - threads = [] - for i, specs_i in specs.items(): - threads.append(executor.submit(feats.get_profile, img, specs_i, i)) - - node_profile_features = dict() - for thread in as_completed(threads): - node_profile_features.update(thread.result()) - return node_profile_features - - -def get_leaf_profile_path(neurograph, i): - """ - Gets path that profile will be computed over for the leaf node "i". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - i : int - Leaf node in "neurograph". - - Returns - ------- - list - Voxel coordinates that profile is generated from. - - """ - j = neurograph.leaf_neighbor(i) - return get_profile_path(neurograph.oriented_edge((i, j), i, key="xyz")) - - -def get_branching_profile_path(neurograph, i): - """ - Gets path that profile will be computed over for the branching node "i". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - i : int - branching node in "neurograph". - - Returns - ------- - list - Voxel coordinates that profile is generated from. - - """ - nbs = list(neurograph.neighbors(i)) - voxels_1 = get_profile_path(neurograph.oriented_edge((i, nbs[0]), i)) - voxles_2 = get_profile_path(neurograph.oriented_edge((i, nbs[1]), i)) - return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) - - -def get_profile_path(xyz_path): - """ - Gets a sub-path from "xyz_path" that has a path length of at most - "NODE_PROFILE_DEPTH" microns. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that correspond to some edge in a neurograph from - which the profile path is extracted from. - - Returns - ------- - numpy.ndarray - xyz coordinates that an image profile will be generated from. - - """ - # Check for degeneracy - if xyz_path.shape[0] == 1: - xyz_path = np.vstack([xyz_path, xyz_path - 0.01]) - - # Truncate path - length = 0 - for i in range(1, xyz_path.shape[0]): - length += geometry.dist(xyz_path[i - 1], xyz_path[i]) - if length >= NODE_PROFILE_DEPTH: - break - return xyz_path[0:i, :] - - -def get_node_profile_specs(xyz_path, downsample_factor): - """ - Gets image bounding box and voxel coordinates needed to compute an image - profile. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Specifications needed to compute image profile for a given proposal. - - """ - voxels = transform_path(xyz_path, downsample_factor) - bbox = img_util.get_minimal_bbox(voxels, buffer=1) - return {"bbox": bbox, "profile_path": shift_path(voxels, bbox)} - - -def transform_path(xyz_path, downsample_factor): - """ - Transforms "xyz_path" by converting the xyz coordinates to voxels and - resampling "N_PROFILE_PTS" from voxel coordinates. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - numpy.ndarray - Voxel coordinates that represent an image profile path. - - """ - # Main - voxels = list() - for xyz in xyz_path: - voxels.append( - img_util.to_voxels(xyz, downsample_factor=downsample_factor) - ) - - # Finish - voxels = np.array(voxels) - if voxels.shape[0] < 5: - voxels = check_degenerate(voxels) - return geometry.sample_curve(voxels, N_PROFILE_PTS) - - -def shift_path(voxels, bbox): - """ - Shifts "voxels" by subtracting the min coordinate in "bbox". - - Parameters - ---------- - voxels : numpy.ndarray - Voxel coordinates to be shifted. - bbox : dict - Coordinates of a bounding box that contains "voxels". - - Returns - ------- - numpy.ndarray - Voxels shifted by min coordinate in "bbox". - - """ - return [voxel - bbox["min"] for voxel in voxels] - - -def check_degenerate(voxels): - """ - Checks whether "voxels" contains at least two unique points. If False, the - unique voxel coordinate is perturbed and added to "voxels". - - Parameters - ---------- - voxels : numpy.ndarray - Voxel coordinates to be checked. - - Returns - ------- - numpy.ndarray - Voxel coordinates that form a non-degenerate path. - - """ - if np.unique(voxels, axis=0).shape[0] == 1: - voxels = np.vstack( - [voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)] - ) - return voxels - - -def n_node_features(): - return {"branch": 2, "proposal": 34} - - -def n_edge_features(): - n_edge_features_dict = { - ("proposal", "edge", "proposal"): 3, - ("branch", "edge", "branch"): 3, - ("branch", "edge", "proposal"): 3 - } - return n_edge_features_dict diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index f28ade5..d30e241 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -236,8 +236,8 @@ def absorb_reducibles(self): # Concatenate attributes len_1 = self.edges[i, nbs[0]]["length"] len_2 = self.edges[i, nbs[1]]["length"] - xyz = self.get_branches(i, key="xyz") - radius = self.get_branches(i, key="radius") + xyz = self.branches(i, key="xyz") + radius = self.branches(i, key="radius") attrs = { "length": len_1 + len_2, "radius": concatenate([np.flip(radius[0]), radius[1]]), @@ -615,8 +615,8 @@ def proposal_radii(self, proposal): def proposal_avg_radii(self, proposal): i, j = tuple(proposal) - radii_i = self.get_branches(i, ignore_reducibles=True, key="radius") - radii_j = self.get_branches(j, ignore_reducibles=True, key="radius") + radii_i = self.branches(i, key="radius") + radii_j = self.branches(j, key="radius") return np.array([avg_radius(radii_i), avg_radius(radii_j)]) def proposal_xyz(self, proposal): @@ -637,15 +637,17 @@ def proposal_xyz(self, proposal): i, j = tuple(proposal) return np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]]) - def proposal_directionals(self, proposal, window): - # Compute tangent vectors + def proposal_directionals(self, proposal, depth): + # Extract branches i, j = tuple(proposal) - direction = geometry.tangent(self.proposal_xyz(proposal)) + branches_i = [geometry.truncate_path(b, depth) for b in self.branches(i)] + branches_j = [geometry.truncate_path(b, depth) for b in self.branches(j)] origin = self.proposal_midpoint(proposal) - branches_i = self.get_branches(i, ignore_reducibles=True) - branches_j = self.get_branches(j, ignore_reducibles=True) - direction_i = geometry.get_directional(branches_i, i, origin, window) - direction_j = geometry.get_directional(branches_j, j, origin, window) + + # Compute tangent vectors + direction_i = geometry.get_directional(branches_i, origin, depth) + direction_j = geometry.get_directional(branches_j, origin, depth) + direction = geometry.tangent(self.proposal_xyz(proposal)) # Compute features inner_product_1 = abs(np.dot(direction, direction_i)) @@ -789,7 +791,7 @@ def dist(self, i, j): """ return get_dist(self.nodes[i]["xyz"], self.nodes[j]["xyz"]) - def get_branches(self, i, ignore_reducibles=False, key="xyz"): + def branches(self, i, ignore_reducibles=True, key="xyz"): branches = list() for j in self.neighbors(i): branch = self.oriented_edge((i, j), i, key=key) @@ -807,7 +809,7 @@ def get_branches(self, i, ignore_reducibles=False, key="xyz"): branches.append(branch) return branches - def get_branch(self, leaf, key="xyz"): + def branch(self, leaf, key="xyz"): """ Gets the xyz coordinates or radii contained in the edge emanating from "leaf". @@ -825,7 +827,7 @@ def get_branch(self, leaf, key="xyz"): """ assert self.is_leaf(leaf) - return self.get_branches(leaf, key=key)[0] + return self.branches(leaf, key=key)[0] def get_other_nb(self, i, j): """ @@ -1077,3 +1079,9 @@ def avg_radius(radii_list): end = max(min(16, len(radii) - 1), 1) avg += np.mean(radii[0:end]) / len(radii_list) return avg + + +def directional_origin(branch_1, branch_2): + origin_1 = np.mean(branch_1, axis=0) + origin_2 = np.mean(branch_2, axis=0) + return np.mean(np.vstack(origin_1, origin_2), axis=0) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 18101df..ebcb5cb 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -121,7 +121,7 @@ def run( """ # Load fragments and extract irreducibles - self.set_img_bbox(img_patch_origin, img_patch_shape) + self.init_img_bbox(img_patch_origin, img_patch_shape) swc_dicts = self.reader.load(fragments_pointer) irreducibles = get_irreducibles( swc_dicts, @@ -139,7 +139,7 @@ def run( neurograph.add_component(irreducible_set) return neurograph - def set_img_bbox(self, img_patch_origin, img_patch_shape): + def init_img_bbox(self, img_patch_origin, img_patch_shape): """ Sets the bounding box of an image patch as a class attriubte. @@ -156,7 +156,7 @@ def set_img_bbox(self, img_patch_origin, img_patch_shape): None """ - self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) + self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape) # --- Graph structure extraction --- diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 41d596d..833d73b 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -146,7 +146,7 @@ def read_tensorstore_with_bbox(img, bbox): return np.zeros(shape) -def read_profile(img, specs): +def read_profile(img, spec): """ Reads an intensity profile from an image (i.e. image profile). @@ -154,7 +154,7 @@ def read_profile(img, specs): ---------- img : tensorstore.TensorStore Image to be read. - specs : dict + spec : dict Dictionary that stores the bounding box of chunk to be read and the voxel coordinates of the profile path. @@ -164,8 +164,8 @@ def read_profile(img, specs): Image profile. """ - img_chunk = normalize(read_tensorstore_with_bbox(img, specs["bbox"])) - return read_intensities(img_chunk, specs["profile_path"]) + img_chunk = normalize(read_tensorstore_with_bbox(img, spec["bbox"])) + return read_intensities(img_chunk, spec["profile_path"]) def read_intensities(img, voxels): @@ -283,6 +283,57 @@ def get_labels_mip(img, axis=0): return (255 * mip).astype(np.uint8) +def get_chunk_profile(img, specs, profile_id): + """ + Gets the image profile for a given proposal. + + Parameters + ---------- + img : tensorstore.TensorStore + Image that profiles are generated from. + specs : dict + Dictionary that contains the image bounding box and coordinates of the + image profile path. + profile_id : frozenset + ... + + Returns + ------- + dict + Dictionary that maps an id (e.g. node, edge, or proposal) to its image + profile. + + """ + pass + + +def get_profile(img, spec, profile_id): + """ + Gets the image profile for a given proposal. + + Parameters + ---------- + img : tensorstore.TensorStore + Image that profiles are generated from. + spec : dict + Dictionary that contains the image bounding box and coordinates of the + image profile path. + profile_id : frozenset + Identifier of profile. + + Returns + ------- + dict + Dictionary that maps an id (e.g. node, edge, or proposal) to its image + profile. + + """ + profile = read_profile(img, spec) + avg, std = util.get_avg_std(profile) + profile.extend([avg, std]) + return {profile_id: profile} + + # --- coordinate conversions --- def img_to_patch(voxel, patch_centroid, patch_shape): """ @@ -331,7 +382,7 @@ def patch_to_img(voxel, patch_centroid, patch_dims): return np.round(voxel + patch_centroid - half_patch_dims).astype(int) -def to_world(voxel, anisotropy=ANISOTROPY, shift=[0, 0, 0]): +def to_world(voxel, shift=[0, 0, 0]): """ Converts coordinates from voxels to world. @@ -348,10 +399,10 @@ def to_world(voxel, anisotropy=ANISOTROPY, shift=[0, 0, 0]): Converted coordinates. """ - return tuple([voxel[i] * anisotropy[i] - shift[i] for i in range(3)]) + return tuple([voxel[i] * ANISOTROPY[i] - shift[i] for i in range(3)]) -def to_voxels(xyz, anisotropy=ANISOTROPY, downsample_factor=0): +def to_voxels(xyz, downsample_factor=0): """ Converts coordinates from world to voxel. @@ -373,12 +424,12 @@ def to_voxels(xyz, anisotropy=ANISOTROPY, downsample_factor=0): """ downsample_factor = 1.0 / 2 ** downsample_factor - voxel = downsample_factor * (xyz / np.array(anisotropy)) + voxel = downsample_factor * (xyz / np.array(ANISOTROPY)) return np.round(voxel).astype(int) # -- utils -- -def get_bbox(origin, shape): +def init_bbox(origin, shape): """ Gets the min and max coordinates of a bounding box based on "origin" and "shape". diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index 1621dee..8c54492 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -553,6 +553,7 @@ def set_radius(graph, i): ------- float Radius of node "i". + """ try: radius = graph[i]["radius"] @@ -567,15 +568,16 @@ def make_simple_entry(node, parent, xyz, radius=8): Parameters ---------- - graph : networkx.Graph - Graph that "i" and "parent" belong to. node : int Node that entry corresponds to. parent : int Parent of node "i". - anisotropy : list[float] - Image to real-world coordinates scaling factors for (x, y, z) that is - applied to swc files. + xyz : numpy.ndarray + ... + + Returns + ------- + ... """ x, y, z = tuple(xyz) From e3946bcdbce42cdfcb39b2de61df4947c1afc6da Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:31:33 -0700 Subject: [PATCH 3/6] Feat multimodal gnn (#263) * refactor: feature generation * refactor: simplified feature generation --------- Co-authored-by: anna-grim --- src/deep_neurographs/inference.py | 34 +- .../machine_learning/feature_generation.py | 456 ++++++++---------- src/deep_neurographs/utils/img_util.py | 11 +- src/deep_neurographs/utils/ml_util.py | 4 +- 4 files changed, 231 insertions(+), 274 deletions(-) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 60d5842..8b124a0 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -20,12 +20,12 @@ from tqdm import tqdm from deep_neurographs.graph_artifact_removal import remove_doubles -from deep_neurographs.machine_learning import feature_generation from deep_neurographs.utils import gnn_util from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, ml_util, util +from deep_neurographs.utils import ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader +from deep_neurographs.machine_learning.feature_generation import FeatureGenerator BATCH_SIZE = 2000 CONFIDENCE_THRESHOLD = 0.7 @@ -65,6 +65,8 @@ def __init__( output_dir, config, device=None, + label_path=None, + use_img_embedding=False, ): """ Initializes an object that executes the full GraphTrace inference @@ -79,7 +81,7 @@ def __init__( Identifier for the predicted segmentation to be processed by the inference pipeline. img_path : str - Path to the raw image of whole brain stored on a GCS bucket. + Path to the raw image assumed to be stored in a GCS bucket. model_path : str Path to machine learning model parameters. output_dir : str @@ -89,6 +91,10 @@ def __init__( for the inference pipeline. device : str, optional ... + label_path : str, optional + Path to the segmentation assumed to be stored on a GCS bucket. + use_img_embedding : bool, optional + ... Returns ------- @@ -114,6 +120,7 @@ def __init__( confidence_threshold=self.ml_config.threshold, device=device, downsample_factor=self.ml_config.downsample_factor, + label_path=label_path ) # Set output directory @@ -396,6 +403,8 @@ def __init__( confidence_threshold=CONFIDENCE_THRESHOLD, device=None, downsample_factor=1, + label_path=None, + use_img_embedding=False ): """ Initializes an inference engine by loading images and setting class @@ -430,20 +439,15 @@ def __init__( self.batch_size = batch_size self.device = "cpu" if device is None else device self.is_gnn = True if "Graph" in model_type else False - self.model_type = model_type self.radius = radius self.threshold = confidence_threshold - # Load image - driver = "n5" if ".n5" in img_path else "zarr" - img = img_util.open_tensorstore(img_path, driver=driver) - # Features - feature_factory = feature_generation.Factory() - self.feature_generator = feature_factory.create( - model_type, - img, - downsample_factor + self.feature_generator = FeatureGenerator( + img_path, + downsample_factor, + label_path=label_path, + use_img_embedding=use_img_embedding ) # Model @@ -545,7 +549,7 @@ def get_batch_dataset(self, neurograph, batch): dataset = ml_util.init_dataset( neurograph, features, - self.model_type, + self.is_gnn, computation_graph=computation_graph, ) return dataset @@ -568,7 +572,7 @@ def predict(self, dataset): """ # Get predictions - if self.model_type == "GraphNeuralNet": + if self.is_gnn: with torch.no_grad(): # Get inputs n = len(dataset.data["proposal"]["y"]) diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 6f00b92..fefa250 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -24,244 +24,86 @@ from deep_neurographs import geometry from deep_neurographs.utils import img_util, util -CHUNK_SIZE = [48, 48, 48] +CHUNK_SHAPE = [96, 96, 96] N_BRANCH_PTS = 50 N_PROFILE_PTS = 16 N_SKEL_FEATURES = 22 -class Factory: - """ - Class that generates feature generator instances based on the specified - model type. - - """ - @staticmethod - def create(model_type, *args): - if model_type == "GraphNeuralNet": - return GraphFeatureGenerator(*args) - else: - return FeatureGenerator(*args) - - class FeatureGenerator: """ - An abstract base class that generates features vectors to be classified - and/or learned by a machine learning model. + Class that generates features vectors that are used by a graph neural + network to classify proposals. """ - - def __init__(self, img, downsample_factor): + def __init__( + self, + img_path, + downsample_factor, + label_path=None, + use_img_embedding=False, + ): """ - Instantiates a FeatureGenerator object that is used to generate - features in a machine learning pipeline. + Initializes object that generates features for a graph. Parameters ---------- - img : np.ndarray - Raw image from which features will be generated. + img : tensorstore.Tensorstore + Raw image assumed to be stored in a GCS bucket. downsample_factor : int - Downsampling factor that represents which level in the image + Downsampling factor that accounts for which level in the image pyramid the voxel coordinates must index into. + labels : tensorstore.TensorStore, optional + Segmentation assumed to be stored in a GCS bucket. The default is + None. + use_img_embedding : bool, optional + ... Returns ------- None """ - self.img = img + # Initialize instance attributes self.downsample_factor = downsample_factor + self.use_img_embedding = use_img_embedding - def run(self, neurograph, proposals_dict, radius): - """ - Generates feature vectors for each proposal. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - proposals_dict : dict - Dictionary containing the computation graph used by gnn and - proposals to be classified. - radius : float - Search radius used to generate proposals. - - Returns - ------- - dict - Feature vectors. - - """ - # Initialiations - features = defaultdict(bool) - proposals = proposals_dict["proposals"] - if neurograph.leaf_kdtree is None: - neurograph.init_kdtree(node_type="leaf") - - # Main - features["proposals"] = { - "skel": self.proposal_skeletal(neurograph, proposals, radius), - "profiles": self.proposal_profiles(neurograph, proposals), - } - return features - - def proposal_skeletal(neurograph, proposals, radius): - """ - Generates features from skeleton (i.e. graph) which are graph or - geometry-based features. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - proposals : list - Proposals for which features will be generated - radius : float - Search radius used to generate proposals. - - Returns - ------- - dict - Features generated from skeleton. - - """ - features = dict() - for proposal in proposals: - i, j = tuple(proposal) - features[proposal] = np.concatenate( - ( - neurograph.proposal_length(proposal), - neurograph.degree[i], - neurograph.degree[j], - len(neurograph.nodes[i]["proposals"]), - len(neurograph.nodes[j]["proposals"]), - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_avg_radii(proposal), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - neurograph.proposal_directionals(proposal, 128), - ), - axis=None, - ) - return features - - def proposal_profiles(self, neurograph, proposals): - """ - Generates an image intensity profile along the proposal. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - proposals : list[frozenset] - List of proposals for which features will be generated. - - Returns - ------- - dict - Dictonary such that each pair is the proposal id and image - intensity profile. - - """ - with ThreadPoolExecutor() as executor: - # Assign threads - threads = len(proposals) * [None] - for i, proposal in enumerate(proposals): - xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) - xyz_path = geometry.make_line(xyz_1, xyz_2, N_PROFILE_PTS) - spec = self.get_profile_spec(xyz_path) - threads[i] = executor.submit( - img_util.get_profile, self.img, spec, proposal - ) - - # Store results - profiles = dict() - for thread in as_completed(threads): - profiles.update(thread.result()) - return profiles - - def get_profile_spec(self, xyz_path): - """ - Gets image bounding box and voxel coordinates needed to compute an - image intensity profile. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - - Returns - ------- - dict - Specifications needed to compute profile for a given proposal. - - """ - voxels = self.transform_path(xyz_path) - bbox = img_util.get_fixed_bbox(voxels, CHUNK_SIZE) - profile_path = geometry.shift_path(voxels, bbox["min"]) - return {"bbox": bbox, "profile_path": profile_path} - - def transform_path(self, xyz_path): - """ - Transforms "xyz_path" by converting the xyz coordinates to voxels and - resampling "N_PROFILE_PTS" from voxel coordinates. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - - Returns - ------- - numpy.ndarray - Voxel coordinates that represent an image profile path. - - """ - voxels = np.zeros((len(xyz_path), 3), dtype=int) - for i, xyz in enumerate(xyz_path): - voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) - return voxels + # Initialize image-based attributes + driver = "n5" if ".n5" in img_path else "zarr" + self.img = img_util.open_tensorstore(img_path, driver=driver) + if label_path: + self.labels = img_util.open_tensorstore(label_path) + else: + self.labels = None + # Set chunk shapes + self.img_chunk_shape = self.set_img_chunk_shape() + self.label_chunk_shape = CHUNK_SHAPE -class GraphFeatureGenerator(FeatureGenerator): - """ - Class that generates features vectors that are used by a graph neural - network to classify proposals. + # Validate embedding requirements + if self.use_img_embedding and not label_path: + raise("Must provide labels to generate image embeddings") - """ - def __init__( - self, img, downsample_factor, labels=None, use_img_embedding=False - ): + def set_img_chunk_shape(self): """ - Initializes object that generates features that are used by a gnn. + Sets the shape of chunks extracted from raw image. Parameters ---------- - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - downsample_factor : int, optional - Downsampling factor that accounts for which level in the image - pyramid the voxel coordinates must index into. The default is 0. - labels : tensorstore.TensorStore, optional - Segmentation mask stored in a GCS bucket. The default is None. - use_img_embedding : bool, optional - ... + None Returns ------- - None + list + Shape of chunks extracted from raw image. """ - # Initialize instance attributes - super().__init__(img, downsample_factor) - self.labels = labels - self.use_img_embedding = use_img_embedding + return [s // 2 ** self.downsample_factor for s in CHUNK_SHAPE] def run(self, neurograph, proposals_dict, radius): """ - Generates feature vectors for a graph given a set of proposals. + Generates feature vectors for nodes, edges, and + proposals in a graph. Parameters ---------- @@ -365,11 +207,8 @@ def run_on_proposals(self, neurograph, proposals, radius): # Image features if self.use_img_embedding: - chunks, profiles = self.proposal_chunks_profiles( - neurograph, proposals - ) + chunks = self.proposal_chunks(neurograph, proposals) features.update({"chunks": chunks}) - features.update({"profiles": profiles}) else: profiles = self.proposal_profiles(neurograph, proposals) features.update({"profiles": profiles}) @@ -485,74 +324,198 @@ def node_profiles(self, neurograph, computation_graph): Dictionary that maps a node id to an image profile. """ - # Get profile specifications - specs = dict() - for i in computation_graph.nodes: - if neurograph.is_leaf(i): - profile_path = self.get_leaf_profile_path(neurograph, i) - else: - profile_path = self.get_branching_profile_path(neurograph, i) - specs[i] = self.get_img_specs(profile_path) - - # Generate profiles with ThreadPoolExecutor() as executor: - threads = [] - for i, spec in specs.items(): - threads.append( - executor.submit(img_util.get_profile, self.img, spec, i) + # Assign threads + threads = computation_graph.number_of_nodes() * [None] + for idx, i in enumerate(computation_graph.nodes): + # Get profile path + if neurograph.is_leaf(i): + xyz_path = self.get_leaf_path(neurograph, i) + else: + xyz_path = self.get_branching_path(neurograph, i) + + # Assign + threads[idx] = executor.submit( + img_util.get_profile, self.img, self.get_spec(xyz_path), i ) + # Store results node_profile_features = dict() for thread in as_completed(threads): node_profile_features.update(thread.result()) return node_profile_features - def proposal_chunks_profiles(self, neurograph, img, proposals): + def proposal_profiles(self, neurograph, proposals): """ - Generates an image intensity profile along each proposal. + Generates an image intensity profile along the proposal. Parameters ---------- neurograph : NeuroGraph Graph that "proposals" belong to. - img : tensorstore.TensorStore - Image stored in a GCS bucket. proposals : list[frozenset] List of proposals for which features will be generated. - downsample_factor : int - Downsampling factor that represents which level in the image - pyramid the voxel coordinates must index into. Returns ------- dict - Dictonary such that each pair is the proposal id and profile. + Dictonary such that each pair is the proposal id and image + intensity profile. """ with ThreadPoolExecutor() as executor: # Assign threads - threads = [] - for proposal in proposals: - xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) + threads = len(proposals) * [None] + for i, p in enumerate(proposals): + xyz_1, xyz_2 = neurograph.proposal_xyz(p) xyz_path = geometry.make_line(xyz_1, xyz_2, N_PROFILE_PTS) - specs = self.get_profile_spec(xyz_path) - threads.append( - executor.submit( - img_util.get_chunk_profile, img, specs, proposal) - ) + threads[i] = executor.submit(self.get_profile, xyz_path, p) # Store results profiles = dict() + for thread in as_completed(threads): + profiles.update(thread.result()) + return profiles + + def get_profile(self, xyz_path, profile_id): + """ + Gets the image intensity profile given xyz coordinates that form a + path. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + profile_id : hashable + Identifier of profile. + + Returns + ------- + dict + Dictionary that maps an id (e.g. node, edge, or proposal) to its + profile. + + """ + profile = img_util.read_profile(self.img, self.get_spec(xyz_path)) + profile.extend(list(util.get_avg_std(profile))) + return {profile_id: profile} + + def proposal_chunks(self, neurograph, proposals): + """ + Generates an image intensity profile along each proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals : list[frozenset] + List of proposals for which features will be generated. + + Returns + ------- + dict + Dictonary such that each pair is the proposal id and profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = neurograph.n_proposals() * [None] + for i, p in enumerate(proposals): + xyz_1, xyz_2 = neurograph.proposal_xyz(p) + xyz_path = np.vstack([xyz_1, xyz_2]) + threads[i] = executor.submit(self.get_chunk, xyz_path, p) + + # Store results chunks = dict() for thread in as_completed(threads): - proposal, chunk, profile = thread.result() - chunks[proposal] = chunk - profiles[proposal] = profile - return chunks, profiles + chunks.update(thread.result()) + return chunks + + def get_spec(self, xyz_path): + """ + Gets image bounding box and voxel coordinates needed to compute an + image intensity profile or extract image chunk for cnn embedding. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + + Returns + ------- + dict + Specifications needed to compute a profile. + + """ + voxels = self.transform_path(xyz_path) + bbox = self.get_bbox(voxels) + profile_path = geometry.shift_path(voxels, bbox["min"]) + return {"bbox": bbox, "profile_path": profile_path} + + def transform_path(self, xyz_path): + """ + Converts "xyz_path" from world to voxel coordinates. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + + Returns + ------- + numpy.ndarray + Voxel coordinates of given path. + + """ + voxels = np.zeros((len(xyz_path), 3), dtype=int) + for i, xyz in enumerate(xyz_path): + voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) + return voxels + + def get_bbox(self, voxels): + center = np.round(np.mean(voxels, axis=0)).astype(int) + bbox = { + "min": [c - s // 2 for c, s in zip(center, self.img_chunk_shape)], + "max": [c + s // 2 for c, s in zip(center, self.img_chunk_shape)], + } + return bbox + + def get_chunk(self, xyz_path, proposal): + # Compute chunk centroids + center = np.round(np.mean(xyz_path, axis=0)).astype(int) + img_center = img_util.to_voxels(center, self.downsample_factor) + label_center = img_util.to_voxels(center) + + # Read chunks + img_chunk = img_util.read_tensorstore( + self.img, img_center, self.img_chunk_shape + ) + label_chunk = img_util.read_tensorstore( + self.labels, label_center, self.label_chunk_shape + ) + + # Process results + img_chunk = img_util.normalize(img_chunk) + label_chunk = self.relabel(label_chunk, xyz_path) + return np.stack([img_chunk, label_chunk], axis=0) + + def relabel(self, label_chunk, xyz_path): + # Initializations + voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] + label_1 = label_chunk[voxels[0]] + label_2 = label_chunk[voxels[1]] + line = geometry.make_line(voxels[0], voxels[1], N_PROFILE_PTS) + assert label_1 > 0 and label_2 > 0, "At least one label in background" + + # Relabel + relabel_chunk = np.zeros(label_chunk.shape) + relabel_chunk[label_chunk == label_1] = 1 + relabel_chunk[label_chunk == label_2] = 2 + return geometry.fill_path(relabel_chunk, line, val=-1) # --- Profile utils --- -def get_leaf_profile_path(neurograph, i): +def get_leaf_path(neurograph, i): """ Gets path that profile will be computed over for the leaf node "i". @@ -574,7 +537,7 @@ def get_leaf_profile_path(neurograph, i): return geometry.truncate_path(xyz_path) -def get_branching_profile_path(neurograph, i): +def get_branching_path(neurograph, i): """ Gets path that profile will be computed over for the branching node "i". @@ -596,7 +559,6 @@ def get_branching_profile_path(neurograph, i): voxles_2 = geometry.truncate_path(neurograph.oriented_edge((i, j_2), i)) return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) - # --- Build feature matrix --- def get_matrix(neurographs, features, sample_ids=None): if sample_ids: @@ -726,16 +688,16 @@ def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None): # Extract chunks midpoint = geometry.get_midpoint(voxel_1, voxel_2).astype(int) if type(img) is ts.TensorStore: - chunk = util.read_tensorstore(img, midpoint, CHUNK_SIZE) - labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SIZE) + chunk = util.read_tensorstore(img, midpoint, CHUNK_SHAPE) + labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SHAPE) else: - chunk = img_util.read_chunk(img, midpoint, CHUNK_SIZE) - labels_chunk = img_util.read_chunk(labels, midpoint, CHUNK_SIZE) + chunk = img_util.read_chunk(img, midpoint, CHUNK_SHAPE) + labels_chunk = img_util.read_chunk(labels, midpoint, CHUNK_SHAPE) # Coordinate transform chunk = util.normalize(chunk) - patch_voxel_1 = util.voxels_to_patch(voxel_1, midpoint, CHUNK_SIZE) - patch_voxel_2 = util.voxels_to_patch(voxel_2, midpoint, CHUNK_SIZE) + patch_voxel_1 = util.voxels_to_patch(voxel_1, midpoint, CHUNK_SHAPE) + patch_voxel_2 = util.voxels_to_patch(voxel_2, midpoint, CHUNK_SHAPE) # Generate features path = geometry.make_line(patch_voxel_1, patch_voxel_2, N_PROFILE_PTS) diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 833d73b..8d56fa6 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -122,7 +122,7 @@ def read_tensorstore(img, voxel, shape, from_center=True): return read(img, voxel, shape, from_center=from_center).read().result() -def read_tensorstore_with_bbox(img, bbox): +def read_tensorstore_with_bbox(img, bbox, normalize=True): """ Reads a chunk from a subarray that is determined by "bbox". @@ -480,15 +480,6 @@ def get_minimal_bbox(voxels): return bbox -def get_fixed_bbox(voxels, shape): - centroid = np.round(np.mean(voxels, axis=0)).astype(int) - bbox = { - "min": [centroid[i] - shape[i] // 2 for i in range(3)], - "max": [centroid[i] + shape[i] // 2 for i in range(3)], - } - return bbox - - def get_chunk_labels(path, xyz, shape, from_center=True): """ Gets the labels of segments contained in chunk centered at "xyz". diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index e8a74f8..431ca49 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -62,7 +62,7 @@ def save_model(path, model, model_type): # --- dataset utils --- def init_dataset( - neurograph, features, model_type, computation_graph=None, sample_ids=None + neurograph, features, is_gnn=True, computation_graph=None, sample_ids=None ): """ Initializes a dataset given features generated from some set of proposals @@ -89,7 +89,7 @@ def init_dataset( Dataset that stores features. """ - if model_type == "GraphNeuralNet": + if is_gnn: assert computation_graph is not None, "Must input computation graph!" dataset = heterograph_datasets.init( neurograph, features, computation_graph From 0d8c6ef8811ab985857124db29a8c4c1b73264b5 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Wed, 9 Oct 2024 12:47:16 -0700 Subject: [PATCH 4/6] Feat multimodal gnn (#264) * refactor: feature generation * refactor: simplified feature generation * refactor: chunk extraction is functional * refactor: heterognn simplified --------- Co-authored-by: anna-grim --- src/deep_neurographs/inference.py | 7 +- .../machine_learning/feature_generation.py | 227 +++++++------- .../machine_learning/heterograph_models.py | 283 +++++------------- src/deep_neurographs/neurograph.py | 28 +- src/deep_neurographs/utils/img_util.py | 102 +------ 5 files changed, 204 insertions(+), 443 deletions(-) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 8b124a0..79ce5f8 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -20,12 +20,14 @@ from tqdm import tqdm from deep_neurographs.graph_artifact_removal import remove_doubles +from deep_neurographs.machine_learning.feature_generation import ( + FeatureGenerator, +) from deep_neurographs.utils import gnn_util from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader -from deep_neurographs.machine_learning.feature_generation import FeatureGenerator BATCH_SIZE = 2000 CONFIDENCE_THRESHOLD = 0.7 @@ -120,7 +122,8 @@ def __init__( confidence_threshold=self.ml_config.threshold, device=device, downsample_factor=self.ml_config.downsample_factor, - label_path=label_path + label_path=label_path, + use_img_embedding=use_img_embedding, ) # Set output directory diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index fefa250..c79be30 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -13,22 +13,16 @@ """ -from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy from random import sample import numpy as np -import tensorstore as ts +from scipy.ndimage import zoom from deep_neurographs import geometry from deep_neurographs.utils import img_util, util -CHUNK_SHAPE = [96, 96, 96] -N_BRANCH_PTS = 50 -N_PROFILE_PTS = 16 -N_SKEL_FEATURES = 22 - class FeatureGenerator: """ @@ -36,6 +30,10 @@ class FeatureGenerator: network to classify proposals. """ + # Class attributes + chunk_shape = [96, 96, 96] + n_profile_points = 16 + def __init__( self, img_path, @@ -48,14 +46,14 @@ def __init__( Parameters ---------- - img : tensorstore.Tensorstore - Raw image assumed to be stored in a GCS bucket. + img_path : str + Path to the raw image assumed to be stored in a GCS bucket. downsample_factor : int Downsampling factor that accounts for which level in the image pyramid the voxel coordinates must index into. - labels : tensorstore.TensorStore, optional - Segmentation assumed to be stored in a GCS bucket. The default is - None. + label_path : str, optional + Path to the segmentation assumed to be stored on a GCS bucket. The + default is None. use_img_embedding : bool, optional ... @@ -74,31 +72,40 @@ def __init__( if label_path: self.labels = img_util.open_tensorstore(label_path) else: - self.labels = None + self.labels = None # Set chunk shapes - self.img_chunk_shape = self.set_img_chunk_shape() - self.label_chunk_shape = CHUNK_SHAPE + self.img_chunk_shape = self.set_chunk_shape(downsample_factor) + self.label_chunk_shape = self.set_chunk_shape(0) # Validate embedding requirements if self.use_img_embedding and not label_path: raise("Must provide labels to generate image embeddings") - def set_img_chunk_shape(self): + @classmethod + def set_chunk_shape(cls, downsample_factor): """ - Sets the shape of chunks extracted from raw image. + Adjusts the chunk shape by downsampling each dimension by a specified + factor. Parameters ---------- - None + downsample_factor : int + The factor by which to downsample each dimension of the current + chunk shape. Returns ------- list - Shape of chunks extracted from raw image. + Adjusted chunk shape with each dimension reduced by the downsample + factor. """ - return [s // 2 ** self.downsample_factor for s in CHUNK_SHAPE] + return [s // 2 ** downsample_factor for s in cls.chunk_shape] + + @classmethod + def get_n_profile_points(cls): + return cls.n_profile_points def run(self, neurograph, proposals_dict, radius): """ @@ -365,11 +372,12 @@ def proposal_profiles(self, neurograph, proposals): """ with ThreadPoolExecutor() as executor: # Assign threads - threads = len(proposals) * [None] - for i, p in enumerate(proposals): + threads = list() + for p in proposals: + n_points = self.get_n_profile_points() xyz_1, xyz_2 = neurograph.proposal_xyz(p) - xyz_path = geometry.make_line(xyz_1, xyz_2, N_PROFILE_PTS) - threads[i] = executor.submit(self.get_profile, xyz_path, p) + xyz_path = geometry.make_line(xyz_1, xyz_2, n_points) + threads.append(executor.submit(self.get_profile, xyz_path, p)) # Store results profiles = dict() @@ -377,32 +385,9 @@ def proposal_profiles(self, neurograph, proposals): profiles.update(thread.result()) return profiles - def get_profile(self, xyz_path, profile_id): - """ - Gets the image intensity profile given xyz coordinates that form a - path. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates of a profile path. - profile_id : hashable - Identifier of profile. - - Returns - ------- - dict - Dictionary that maps an id (e.g. node, edge, or proposal) to its - profile. - - """ - profile = img_util.read_profile(self.img, self.get_spec(xyz_path)) - profile.extend(list(util.get_avg_std(profile))) - return {profile_id: profile} - def proposal_chunks(self, neurograph, proposals): """ - Generates an image intensity profile along each proposal. + Generates an image intensity profile along the proposal. Parameters ---------- @@ -414,16 +399,19 @@ def proposal_chunks(self, neurograph, proposals): Returns ------- dict - Dictonary such that each pair is the proposal id and profile. + Dictonary such that each pair is the proposal id and image + intensity profile. """ with ThreadPoolExecutor() as executor: # Assign threads - threads = neurograph.n_proposals() * [None] - for i, p in enumerate(proposals): - xyz_1, xyz_2 = neurograph.proposal_xyz(p) - xyz_path = np.vstack([xyz_1, xyz_2]) - threads[i] = executor.submit(self.get_chunk, xyz_path, p) + threads = list() + for p in proposals: + labels = neurograph.proposal_labels(p) + xyz_path = np.vstack(neurograph.proposal_xyz(p)) + threads.append( + executor.submit(self.get_chunk, labels, xyz_path, p) + ) # Store results chunks = dict() @@ -431,6 +419,29 @@ def proposal_chunks(self, neurograph, proposals): chunks.update(thread.result()) return chunks + def get_profile(self, xyz_path, profile_id): + """ + Gets the image intensity profile given xyz coordinates that form a + path. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + profile_id : hashable + Identifier of profile. + + Returns + ------- + dict + Dictionary that maps an id (e.g. node, edge, or proposal) to its + profile. + + """ + profile = img_util.read_profile(self.img, self.get_spec(xyz_path)) + profile.extend(list(util.get_avg_std(profile))) + return {profile_id: profile} + def get_spec(self, xyz_path): """ Gets image bounding box and voxel coordinates needed to compute an @@ -472,45 +483,51 @@ def transform_path(self, xyz_path): voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) return voxels - def get_bbox(self, voxels): + def get_bbox(self, voxels, is_img=True): center = np.round(np.mean(voxels, axis=0)).astype(int) + shape = self.img_chunk_shape if is_img else self.label_chunk_shape bbox = { - "min": [c - s // 2 for c, s in zip(center, self.img_chunk_shape)], - "max": [c + s // 2 for c, s in zip(center, self.img_chunk_shape)], + "min": [c - s // 2 for c, s in zip(center, shape)], + "max": [c + s // 2 for c, s in zip(center, shape)], } return bbox - def get_chunk(self, xyz_path, proposal): - # Compute chunk centroids - center = np.round(np.mean(xyz_path, axis=0)).astype(int) - img_center = img_util.to_voxels(center, self.downsample_factor) - label_center = img_util.to_voxels(center) + def get_chunk(self, labels, xyz_path, proposal): + # Read image chunk + center = np.mean(xyz_path, axis=0) + img_chunk = self.read_img_chunk(center) + + # Read label chunk + voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] + label_chunk = self.read_label_chunk(voxels, labels) + return {proposal: np.stack([img_chunk, label_chunk], axis=0)} - # Read chunks + def read_img_chunk(self, xyz_centroid): + center = img_util.to_voxels(xyz_centroid, self.downsample_factor) img_chunk = img_util.read_tensorstore( - self.img, img_center, self.img_chunk_shape - ) - label_chunk = img_util.read_tensorstore( - self.labels, label_center, self.label_chunk_shape + self.img, center, self.img_chunk_shape ) + return img_util.normalize(img_chunk) - # Process results - img_chunk = img_util.normalize(img_chunk) - label_chunk = self.relabel(label_chunk, xyz_path) - return np.stack([img_chunk, label_chunk], axis=0) + def read_label_chunk(self, voxels, labels): + bbox = self.get_bbox(voxels, is_img=False) + label_chunk = img_util.read_tensorstore_with_bbox(self.labels, bbox) + voxels = geometry.shift_path(voxels, bbox["min"]) + return self.relabel(label_chunk, voxels, labels) - def relabel(self, label_chunk, xyz_path): + def relabel(self, label_chunk, voxels, labels): # Initializations - voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] - label_1 = label_chunk[voxels[0]] - label_2 = label_chunk[voxels[1]] - line = geometry.make_line(voxels[0], voxels[1], N_PROFILE_PTS) - assert label_1 > 0 and label_2 > 0, "At least one label in background" + n_points = self.get_n_profile_points() + scaling_factor = 2 ** self.downsample_factor + label_chunk = zoom(label_chunk, 1.0 / scaling_factor, order=0) + for i, voxel in enumerate(voxels): + voxels[i] = [v // scaling_factor for v in voxel] - # Relabel + # Main relabel_chunk = np.zeros(label_chunk.shape) - relabel_chunk[label_chunk == label_1] = 1 - relabel_chunk[label_chunk == label_2] = 2 + relabel_chunk[label_chunk == labels[0]] = 1 + relabel_chunk[label_chunk == labels[1]] = 2 + line = geometry.make_line(voxels[0], voxels[-1], n_points) return geometry.fill_path(relabel_chunk, line, val=-1) @@ -559,6 +576,7 @@ def get_branching_path(neurograph, i): voxles_2 = geometry.truncate_path(neurograph.oriented_edge((i, j_2), i)) return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) + # --- Build feature matrix --- def get_matrix(neurographs, features, sample_ids=None): if sample_ids: @@ -619,9 +637,9 @@ def combine_features(features): for edge in features["skel"].keys(): combined[edge] = None for key in features.keys(): - if combined[edge] is None: + if combined[edge] is None and key != "chunks": combined[edge] = deepcopy(features[key][edge]) - else: + elif key != "chunks": combined[edge] = np.concatenate( (combined[edge], features[key][edge]) ) @@ -629,22 +647,6 @@ def combine_features(features): # --- Utils --- -def count_features(): - """ - Counts number of features based on the "model_type". - - Parameters - ---------- - None - - Returns - ------- - int - Number of features. - """ - return N_SKEL_FEATURES + N_PROFILE_PTS + 2 - - def n_node_features(): """ Returns the number of features for different node types. @@ -682,32 +684,3 @@ def n_edge_features(): ("branch", "edge", "proposal"): 3 } return n_edge_features_dict - - -def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None): - # Extract chunks - midpoint = geometry.get_midpoint(voxel_1, voxel_2).astype(int) - if type(img) is ts.TensorStore: - chunk = util.read_tensorstore(img, midpoint, CHUNK_SHAPE) - labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SHAPE) - else: - chunk = img_util.read_chunk(img, midpoint, CHUNK_SHAPE) - labels_chunk = img_util.read_chunk(labels, midpoint, CHUNK_SHAPE) - - # Coordinate transform - chunk = util.normalize(chunk) - patch_voxel_1 = util.voxels_to_patch(voxel_1, midpoint, CHUNK_SHAPE) - patch_voxel_2 = util.voxels_to_patch(voxel_2, midpoint, CHUNK_SHAPE) - - # Generate features - path = geometry.make_line(patch_voxel_1, patch_voxel_2, N_PROFILE_PTS) - profile = geometry.get_profile(chunk, path) - labels_chunk[labels_chunk > 0] = 1 - labels_chunk = geometry.fill_path(labels_chunk, path, val=2) - chunk = np.stack([chunk, labels_chunk], axis=0) - - # Output - if thread_id: - return thread_id, chunk, profile - else: - return chunk, profile diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 8a35c21..0b5139b 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -14,114 +14,108 @@ from torch import nn from torch.nn import Dropout, LeakyReLU from torch_geometric.nn import GATv2Conv as GATConv -from torch_geometric.nn import HEATConv, HeteroConv, Linear +from torch_geometric.nn import HeteroConv, Linear -from deep_neurographs import machine_learning as ml -CONV_TYPES = ["GATConv", "GCNConv"] -DROPOUT = 0.3 -HEADS_1 = 1 -HEADS_2 = 1 - - -class HeteroGNN(torch.nn.Module): +class HeteroGNN(torch.nn.Module): # change to HGAT """ - Heterogeneous graph neural network that utilizes edge features. + Heterogeneous graph attention network that classifies proposals. """ def __init__( self, + node_dict, + edge_dict, device=None, - scale_hidden=2, - dropout=DROPOUT, - heads_1=HEADS_1, - heads_2=HEADS_2, + dropout=0.3, + heads_1=1, + heads_2=1, + scale_hidden_dim=2, ): """ - Constructs a heterogeneous graph neural network. + Constructs a heterogeneous graph attention network. + + Parameters + ---------- + ... + + Returns + ------- + None """ super().__init__() + # Instance attributes + self.device = device + self.dropout = dropout + # Feature vector sizes - node_dict = ml.feature_generation_graphs.n_node_features() - edge_dict = ml.feature_generation_graphs.n_edge_features() - hidden = scale_hidden * np.max(list(node_dict.values())) + hidden_dim = scale_hidden_dim * np.max(list(node_dict.values())) + output_dim = heads_1 * heads_2 * hidden_dim # Linear layers - output_dim = heads_1 * heads_2 * hidden - self.input_nodes = nn.ModuleDict() - self.input_edges = dict() - for key, d in node_dict.items(): - self.input_nodes[key] = nn.Linear(d, hidden, device=device) - for key, d in edge_dict.items(): - self.input_edges[key] = nn.Linear(d, hidden, device=device) + self.input_nodes = self.init_linear_layer(hidden_dim, node_dict) + self.input_edges = self.init_linear_layer(hidden_dim, edge_dict) self.output = Linear(output_dim, 1).to(device) - # Convolutional layers - self.conv1 = HeteroConv( - { - ("proposal", "edge", "proposal"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=hidden, - heads=heads_1, - ), - ("branch", "edge", "branch"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=hidden, - heads=heads_1, - ), - ("branch", "edge", "proposal"): GATConv( - (hidden, hidden), - hidden, - add_self_loops=False, - edge_dim=hidden, - heads=heads_1, - ), - }, - aggr="sum", - ) - edge_dim = hidden - hidden = heads_1 * hidden + # Message passing layers + self.conv1 = self.init_gat_layer(hidden_dim, hidden_dim, heads_1) # change name + edge_dim = hidden_dim + hidden_dim = heads_1 * hidden_dim + + self.conv2 = self.init_gat_layer(hidden_dim, edge_dim, heads_2) # change name + + # Nonlinear activation + self.dropout = Dropout(dropout) # change name + self.leaky_relu = LeakyReLU() - self.conv2 = HeteroConv( + # Initialize weights + self.init_weights() + + # --- Initialize architecture --- + def init_linear_layer(self, hidden_dim, my_dict): + linear_layer = dict() + for key, dim in my_dict.items(): + linear_layer[key] = nn.Linear(dim, hidden_dim, device=self.device) + return linear_layer + + def init_gat_layer(self, hidden_dim, edge_dim, heads): + gat_layers = HeteroConv( { - ("proposal", "edge", "proposal"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=edge_dim, - heads=heads_2, + ("proposal", "edge", "proposal"): self.init_gat_layer_same( + hidden_dim, edge_dim, heads ), - ("branch", "edge", "branch"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=edge_dim, - heads=heads_2, + ("branch", "edge", "branch"): self.init_gat_layer_same( + hidden_dim, edge_dim, heads ), - ("branch", "edge", "proposal"): GATConv( - (hidden, hidden), - hidden, - add_self_loops=False, - edge_dim=edge_dim, - heads=heads_2, + ("branch", "edge", "proposal"): self.init_gat_layer_mixed( + hidden_dim, edge_dim, heads ), }, aggr="sum", ) - hidden = heads_2 * hidden - - # Nonlinear activation - self.dropout = Dropout(dropout) - self.leaky_relu = LeakyReLU() - - # Initialize weights - self.init_weights() + return gat_layers + + def init_gat_layer_same(self, hidden_dim, edge_dim, heads): + gat_layer = GATConv( + -1, + hidden_dim, + dropout=self.dropout, + edge_dim=edge_dim, + heads=heads, + ) + return gat_layer + + def init_gat_layer_mixed(self, hidden_dim, edge_dim, heads): + gat_layer = GATConv( + (hidden_dim, hidden_dim), + hidden_dim, + add_self_loops=False, + edge_dim=edge_dim, + heads=heads, + ) + return gat_layer def init_weights(self): """ @@ -185,128 +179,5 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): return x_dict -class HEATGNN(torch.nn.Module): - """ - Heterogeneous graph neural network. - - """ - - def __init__( - self, - hidden, - metadata, - node_dict, - edge_dict, - dropout=DROPOUT, - heads_1=HEADS_1, - heads_2=HEADS_2, - ): - """ - Constructs a heterogeneous graph neural network. - - """ - super().__init__() - # Linear layers - self.input_nodes = nn.ModuleDict( - {key: nn.Linear(d, hidden) for key, d in node_dict.items()} - ) - self.input_edges = { - key: nn.Linear(d, hidden) for key, d in edge_dict.items() - } - self.output = Linear(heads_1 * heads_2 * hidden) - - # Convolutional layers - self.conv1 = HEATConv( - hidden, - hidden, - heads=heads_1, - dropout=dropout, - metadata=metadata, - ) - """ - x in_channels (int) – Size of each input sample, or -1 to - derive the size from the first input(s) to the forward method. - x out_channels (int) – Size of each output sample. - x num_node_types (int) – The number of node types. - x num_edge_types (int) – The number of edge types. - edge_type_emb_dim (int) – The embedding size of edge types. - edge_dim (int) – Edge feature dimensionality. - edge_attr_emb_dim (int) – The embedding size of edge features. - heads (int, optional) – Number of multi-head-attentions. (default: 1) - """ - hidden = heads_1 * hidden - - self.conv2 = HEATConv( - hidden, - hidden, - heads=heads_2, - dropout=dropout, - metadata=metadata, - ) - hidden = heads_2 * hidden - - # Nonlinear activation - self.dropout = Dropout(dropout) - self.leaky_relu = LeakyReLU() - - # Initialize weights - self.init_weights() - - def init_weights(self): - """ - Initializes linear and convolutional layers. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - layers = [self.input_nodes, self.conv1, self.conv2, self.output] - for layer in layers: - for param in layer.parameters(): - if len(param.shape) > 1: - init.kaiming_normal_(param) - else: - init.zeros_(param) - - def activation(self, x_dict): - """ - Applies nonlinear activation - - Parameters - ---------- - x_dict : dict - Dictionary that maps node/edge types to feature matrices. - - Returns - ------- - dict - Feature matrices with activation applied. - - """ - x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} - x_dict = {key: self.dropout(x) for key, x in x_dict.items()} - return x_dict - - def forward(self, x_dict, edge_index_dict, edge_attr_dict, metadata): - # Input - Nodes - x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} - x_dict = self.activation(x_dict) - - # Input - Edges - edge_attr_dict = { - key: f(edge_attr_dict[key]) for key, f in self.input_edges.items() - } - edge_attr_dict = self.activation(edge_attr_dict) - - # Convolutional layers - x_dict = self.conv1(x_dict, edge_index_dict, metadata) - x_dict = self.conv2(x_dict, edge_index_dict, metadata) - - # Output - x_dict = self.output(x_dict["proposal"]) - return x_dict +class MultiModalHGAT(HeteroGNN): + pass diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d30e241..db5e66d 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -6,6 +6,8 @@ Implementation of subclass of Networkx.Graph called "FragmentsGraph". +NOTE: SAVE LABEL UPDATES --- THERE IS A BUG IN FEATURE GENERATION + """ import os import zipfile @@ -637,6 +639,24 @@ def proposal_xyz(self, proposal): i, j = tuple(proposal) return np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]]) + def proposal_labels(self, proposal): + """ + Gets the xyz coordinates of the nodes that comprise "proposal". + + Parameters + ---------- + proposal : frozenset + Pair of nodes that form a proposal. + + Returns + ------- + numpy.ndarray + xyz coordinates of nodes that comprise "proposal". + + """ + i, j = tuple(proposal) + return [int(self.nodes[i]["swc_id"]), int(self.nodes[j]["swc_id"])] + def proposal_directionals(self, proposal, depth): # Extract branches i, j = tuple(proposal) @@ -915,14 +935,6 @@ def leaf_neighbor(self, i): assert self.is_leaf(i) return list(self.neighbors(i))[0] - def to_patch_coords(self, edge, midpoint, chunk_size): - patch_coords = list() - for xyz in self.edges[edge]["xyz"]: - coord = self.to_voxels(xyz) - local_coord = util.voxels_to_patch(coord, midpoint, chunk_size) - patch_coords.append(local_coord) - return patch_coords - def xyz_to_swc(self, xyz, return_node=False): if tuple(xyz) in self.xyz_to_edge.keys(): edge = self.xyz_to_edge[tuple(xyz)] diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 8d56fa6..cd0f2be 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -10,7 +10,6 @@ from copy import deepcopy -import fastremap import numpy as np import tensorstore as ts from skimage.color import label2rgb @@ -143,6 +142,7 @@ def read_tensorstore_with_bbox(img, bbox, normalize=True): shape = [bbox["max"][i] - bbox["min"][i] for i in range(3)] return read_tensorstore(img, bbox["min"], shape, from_center=False) except Exception: + print(f"Unable to read from image with bbox {bbox}") return np.zeros(shape) @@ -214,7 +214,7 @@ def get_start_end(voxel, shape, from_center=True): end = [voxel[i] + shape[i] // 2 for i in range(3)] else: start = voxel - end = [voxel[i] + shape[i] + 1 for i in range(3)] + end = [voxel[i] + shape[i] for i in range(3)] return start, end @@ -283,30 +283,6 @@ def get_labels_mip(img, axis=0): return (255 * mip).astype(np.uint8) -def get_chunk_profile(img, specs, profile_id): - """ - Gets the image profile for a given proposal. - - Parameters - ---------- - img : tensorstore.TensorStore - Image that profiles are generated from. - specs : dict - Dictionary that contains the image bounding box and coordinates of the - image profile path. - profile_id : frozenset - ... - - Returns - ------- - dict - Dictionary that maps an id (e.g. node, edge, or proposal) to its image - profile. - - """ - pass - - def get_profile(img, spec, profile_id): """ Gets the image profile for a given proposal. @@ -335,53 +311,6 @@ def get_profile(img, spec, profile_id): # --- coordinate conversions --- -def img_to_patch(voxel, patch_centroid, patch_shape): - """ - Converts coordinates from global to local image coordinates. - - Parameters - ---------- - voxel : numpy.ndarray - Voxel coordinate to be converted. - patch_centroid : numpy.ndarray - Centroid of image patch. - patch_shape : numpy.ndarray - Shape of image patch. - - Returns - ------- - tuple - Converted coordinates. - - """ - half_patch_shape = [patch_shape[i] // 2 for i in range(3)] - patch_voxel = voxel - patch_centroid + half_patch_shape - return tuple(patch_voxel.astype(int)) - - -def patch_to_img(voxel, patch_centroid, patch_dims): - """ - Converts coordinates from local to global image coordinates. - - Parameters - ---------- - coord : numpy.ndarray - Coordinates to be converted. - patch_centroid : numpy.ndarray - Centroid of image patch. - patch_shape : numpy.ndarray - Shape of image patch. - - Returns - ------- - tuple - Converted coordinates. - - """ - half_patch_dims = [patch_dims[i] // 2 for i in range(3)] - return np.round(voxel + patch_centroid - half_patch_dims).astype(int) - - def to_world(voxel, shift=[0, 0, 0]): """ Converts coordinates from voxels to world. @@ -480,33 +409,6 @@ def get_minimal_bbox(voxels): return bbox -def get_chunk_labels(path, xyz, shape, from_center=True): - """ - Gets the labels of segments contained in chunk centered at "xyz". - - Parameters - ---------- - path : str - Path to segmentation stored in a GCS bucket. - xyz : numpy.ndarray - Center point of chunk to be read. - shape : tuple - Shape of chunk to be read. - from_center : bool, optional - Indication of whether "xyz" is the center point or upper, left, front - corner of chunk to be read. The default is True. - - Returns - ------- - set - Labels of segments contained in chunk read from GCS bucket. - - """ - img = open_tensorstore(path) - img = read_tensorstore(img, xyz, shape, from_center=from_center) - return set(fastremap.unique(img).astype(int)) - - def find_img_path(bucket_name, img_root, dataset_name): """ Find the path of a specific dataset in a GCS bucket. From 13a1a7f2e96bb79065c0ff3480c23d2960c06018 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:45:00 -0700 Subject: [PATCH 5/6] Feat multimodal gnn (#265) * refactor: feature generation * refactor: simplified feature generation * refactor: chunk extraction is functional * refactor: heterognn simplified * refactor with issue --------- Co-authored-by: anna-grim --- src/deep_neurographs/config.py | 2 +- .../groundtruth_generation.py | 2 +- src/deep_neurographs/inference.py | 4 +- .../machine_learning/datasets.py | 13 +- .../machine_learning/feature_generation.py | 184 +++++++----------- .../machine_learning/heterograph_datasets.py | 66 ++++--- .../machine_learning/heterograph_models.py | 91 ++++----- src/deep_neurographs/neurograph.py | 26 +-- src/deep_neurographs/train.py | 33 ++-- src/deep_neurographs/utils/graph_util.py | 42 ++-- src/deep_neurographs/utils/img_util.py | 31 +-- src/deep_neurographs/utils/ml_util.py | 11 +- 12 files changed, 223 insertions(+), 282 deletions(-) diff --git a/src/deep_neurographs/config.py b/src/deep_neurographs/config.py index f900247..f043183 100644 --- a/src/deep_neurographs/config.py +++ b/src/deep_neurographs/config.py @@ -93,7 +93,7 @@ class MLConfig: batch_size: int = 2000 downsample_factor: int = 1 high_threshold: float = 0.9 - lr: float = 1e-3 + lr: float = 1e-4 threshold: float = 0.6 model_type: str = "GraphNeuralNet" n_epochs: int = 1000 diff --git a/src/deep_neurographs/groundtruth_generation.py b/src/deep_neurographs/groundtruth_generation.py index a469ad5..5907a43 100644 --- a/src/deep_neurographs/groundtruth_generation.py +++ b/src/deep_neurographs/groundtruth_generation.py @@ -204,7 +204,7 @@ def is_valid(target_graph, pred_graph, kdtree, target_id, edge): def proj_branch(target_graph, pred_graph, kdtree, target_id, i): # Compute projections hits = defaultdict(list) - for branch in pred_graph.get_branches(i): + for branch in pred_graph.branches(i): for xyz in branch: hat_xyz = geometry.kdtree_query(kdtree, xyz) swc_id = target_graph.xyz_to_swc(hat_xyz) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 79ce5f8..41b0fd0 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -547,7 +547,9 @@ def get_batch_dataset(self, neurograph, batch): ... """ + t0 = time() features = self.feature_generator.run(neurograph, batch, self.radius) + print("Feature Generation:", time() - t0) computation_graph = batch["graph"] if type(batch) is dict else None dataset = ml_util.init_dataset( neurograph, @@ -590,7 +592,7 @@ def predict(self, dataset): preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1]) # Reformat prediction - idxs = dataset.idxs_proposals["idx_to_edge"] + idxs = dataset.idxs_proposals["idx_to_id"] return {idxs[i]: p for i, p in enumerate(preds)} diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index 5af5349..978dfda 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -82,7 +82,7 @@ def __init__(self, proposals, x_proposals, y_proposals, idxs_proposals): """ # Conversion idxs self.block_to_idxs = idxs_proposals["block_to_idxs"] - self.idxs_proposals = init_idxs(idxs_proposals) + self.idxs_proposals = init_idx_mapping(idxs_proposals) self.proposals = proposals # Features @@ -293,7 +293,7 @@ def reformat(arr): return np.expand_dims(arr, axis=1).astype(np.float32) -def init_idxs(idxs): +def init_idx_mapping(idx_to_id): """ Adds dictionary item called "edge_to_index" which maps a branch/proposal in a neurograph to an idx that represents it's position in the feature @@ -310,7 +310,8 @@ def init_idxs(idxs): Updated dictionary. """ - idxs["edge_to_idx"] = dict() - for idx, edge in idxs["idx_to_edge"].items(): - idxs["edge_to_idx"][edge] = idx - return idxs + idx_mapping = { + "idx_to_id": idx_to_id, + "id_to_idx": {v: k for k, v in idx_to_id.items()} + } + return idx_mapping diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index c79be30..8d1d4f9 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -31,7 +31,7 @@ class FeatureGenerator: """ # Class attributes - chunk_shape = [96, 96, 96] + patch_shape = [96, 96, 96] n_profile_points = 16 def __init__( @@ -75,15 +75,15 @@ def __init__( self.labels = None # Set chunk shapes - self.img_chunk_shape = self.set_chunk_shape(downsample_factor) - self.label_chunk_shape = self.set_chunk_shape(0) + self.img_patch_shape = self.set_patch_shape(downsample_factor) + self.label_patch_shape = self.set_patch_shape(0) # Validate embedding requirements if self.use_img_embedding and not label_path: raise("Must provide labels to generate image embeddings") @classmethod - def set_chunk_shape(cls, downsample_factor): + def set_patch_shape(cls, downsample_factor): """ Adjusts the chunk shape by downsampling each dimension by a specified factor. @@ -101,7 +101,7 @@ def set_chunk_shape(cls, downsample_factor): factor. """ - return [s // 2 ** downsample_factor for s in cls.chunk_shape] + return [s // 2 ** downsample_factor for s in cls.patch_shape] @classmethod def get_n_profile_points(cls): @@ -139,9 +139,13 @@ def run(self, neurograph, proposals_dict, radius): # Main features = { "nodes": self.run_on_nodes(neurograph, computation_graph), - "edge": self.run_on_edges(neurograph, computation_graph), + "branches": self.run_on_branches(neurograph, computation_graph), "proposals": self.run_on_proposals(neurograph, proposals, radius) } + + # Generate image patches (if applicable) + if self.use_img_embedding: + features["patches"] = self.proposal_patches(neurograph, proposals) return features def run_on_nodes(self, neurograph, computation_graph): @@ -158,14 +162,12 @@ def run_on_nodes(self, neurograph, computation_graph): Returns ------- dict - Dictionary whose keys are feature types (i.e. skeletal) and values - are a dictionary that maps a node id to the corresponding feature - vector. + Dictionary that maps a node id to a feature vector. """ - return {"skel": self.node_skeletal(neurograph, computation_graph)} + return self.node_skeletal(neurograph, computation_graph) - def run_on_edges(self, neurograph, computation_graph): + def run_on_branches(self, neurograph, computation_graph): """ Generates feature vectors for every edge in "computation_graph". @@ -179,12 +181,10 @@ def run_on_edges(self, neurograph, computation_graph): Returns ------- dict - Dictionary whose keys are feature types (i.e. skeletal) and values - are a dictionary that maps an edge id to the corresponding feature - vector. + Dictionary that maps an edge id to a feature vector. """ - return {"skel": self.edge_skeletal(neurograph, computation_graph)} + return self.branch_skeletal(neurograph, computation_graph) def run_on_proposals(self, neurograph, proposals, radius): """ @@ -202,23 +202,14 @@ def run_on_proposals(self, neurograph, proposals, radius): Returns ------- dict - Dictionary whose keys are feature types (i.e. skeletal, profiles, - chunks) and values are a dictionary that maps a proposal id to a - feature vector. + Dictionary that maps a proposal id to a feature vector. """ - # Skeleton features - features = { - "skel": self.proposal_skeletal(neurograph, proposals, radius) - } - - # Image features - if self.use_img_embedding: - chunks = self.proposal_chunks(neurograph, proposals) - features.update({"chunks": chunks}) - else: + features = self.proposal_skeletal(neurograph, proposals, radius) + if not self.use_img_embedding: profiles = self.proposal_profiles(neurograph, proposals) - features.update({"profiles": profiles}) + for p in proposals: + features[p] = np.concatenate((features[p], profiles[p])) return features # -- Skeletal Features -- @@ -251,7 +242,7 @@ def node_skeletal(self, neurograph, computation_graph): ) return node_skeletal_features - def edge_skeletal(self, neurograph, computation_graph): + def branch_skeletal(self, neurograph, computation_graph): """ Generates skeleton-based features for edges in "computation_graph". @@ -268,15 +259,15 @@ def edge_skeletal(self, neurograph, computation_graph): Dictionary that maps an edge id to a feature vector. """ - edge_skeletal_features = dict() + branch_skeletal_features = dict() for edge in neurograph.edges: - edge_skeletal_features[frozenset(edge)] = np.array( + branch_skeletal_features[frozenset(edge)] = np.array( [ np.mean(neurograph.edges[edge]["radius"]), min(neurograph.edges[edge]["length"], 500) / 500, ], ) - return edge_skeletal_features + return branch_skeletal_features def proposal_skeletal(self, neurograph, proposals, radius): """ @@ -385,7 +376,7 @@ def proposal_profiles(self, neurograph, proposals): profiles.update(thread.result()) return profiles - def proposal_chunks(self, neurograph, proposals): + def proposal_patches(self, neurograph, proposals): """ Generates an image intensity profile along the proposal. @@ -410,7 +401,7 @@ def proposal_chunks(self, neurograph, proposals): labels = neurograph.proposal_labels(p) xyz_path = np.vstack(neurograph.proposal_xyz(p)) threads.append( - executor.submit(self.get_chunk, labels, xyz_path, p) + executor.submit(self.get_patch, labels, xyz_path, p) ) # Store results @@ -485,50 +476,50 @@ def transform_path(self, xyz_path): def get_bbox(self, voxels, is_img=True): center = np.round(np.mean(voxels, axis=0)).astype(int) - shape = self.img_chunk_shape if is_img else self.label_chunk_shape + shape = self.img_patch_shape if is_img else self.label_patch_shape bbox = { "min": [c - s // 2 for c, s in zip(center, shape)], "max": [c + s // 2 for c, s in zip(center, shape)], } return bbox - def get_chunk(self, labels, xyz_path, proposal): - # Read image chunk + def get_patch(self, labels, xyz_path, proposal): + # Initializations center = np.mean(xyz_path, axis=0) - img_chunk = self.read_img_chunk(center) - - # Read label chunk voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] - label_chunk = self.read_label_chunk(voxels, labels) - return {proposal: np.stack([img_chunk, label_chunk], axis=0)} - def read_img_chunk(self, xyz_centroid): + # Read patches + img_patch = self.read_img_patch(center) + label_patch = self.read_label_patch(voxels, labels) + return {proposal: np.stack([img_patch, label_patch], axis=0)} + + def read_img_patch(self, xyz_centroid): center = img_util.to_voxels(xyz_centroid, self.downsample_factor) - img_chunk = img_util.read_tensorstore( - self.img, center, self.img_chunk_shape + img_patch = img_util.read_tensorstore( + self.img, center, self.img_patch_shape ) - return img_util.normalize(img_chunk) + return img_util.normalize(img_patch) - def read_label_chunk(self, voxels, labels): + def read_label_patch(self, voxels, labels): bbox = self.get_bbox(voxels, is_img=False) - label_chunk = img_util.read_tensorstore_with_bbox(self.labels, bbox) + label_patch = img_util.read_tensorstore_with_bbox(self.labels, bbox) voxels = geometry.shift_path(voxels, bbox["min"]) - return self.relabel(label_chunk, voxels, labels) + return self.relabel(label_patch, voxels, labels) - def relabel(self, label_chunk, voxels, labels): + def relabel(self, label_patch, voxels, labels): # Initializations n_points = self.get_n_profile_points() scaling_factor = 2 ** self.downsample_factor - label_chunk = zoom(label_chunk, 1.0 / scaling_factor, order=0) + label_patch = zoom(label_patch, 1.0 / scaling_factor, order=0) for i, voxel in enumerate(voxels): voxels[i] = [v // scaling_factor for v in voxel] # Main - relabel_chunk = np.zeros(label_chunk.shape) - relabel_chunk[label_chunk == labels[0]] = 1 - relabel_chunk[label_chunk == labels[1]] = 2 + relabel_patch = np.zeros(label_patch.shape) + relabel_patch[label_patch == labels[0]] = 1 + relabel_patch[label_patch == labels[1]] = 2 line = geometry.make_line(voxels[0], voxels[-1], n_points) - return geometry.fill_path(relabel_chunk, line, val=-1) + return geometry.fill_path(relabel_patch, line, val=-1) # --- Profile utils --- @@ -578,76 +569,37 @@ def get_branching_path(neurograph, i): # --- Build feature matrix --- -def get_matrix(neurographs, features, sample_ids=None): - if sample_ids: - return stack_feature_matrices(neurographs, features, sample_ids) - else: - return get_feature_matrix(neurographs, features) +def get_matrix(features, gt_accepts=set()): + # Initialize matrices + key = sample(list(features.keys()), 1)[0] + X = np.zeros((len(features.keys()), len(features[key]))) + y = np.zeros((len(features.keys()))) + # Populate + idx_to_id = dict() + for i, id_i in enumerate(features): + idx_to_id[i] = id_i + X[i, :] = features[id_i] + y[i] = 1 if id_i in gt_accepts else 0 + return X, y, idx_to_id -def stack_feature_matrices(neurographs, features, blocks): - # Initialize - X = None - y = None - idx_transforms = {"block_to_idxs": dict(), "idx_to_edge": dict()} - # Feature extraction +def stack_matrices(neurographs, features, blocks): + idx_to_id = dict() + X, y = None, None for block_id in blocks: - # Extract feature matrix - idx_shift = 0 if X is None else X.shape[0] - X_i, y_i, idx_transforms_i = get_feature_matrix( - neurographs[block_id], features[block_id], shift=idx_shift - ) - - # Concatenate + X_i, y_i, _ = get_matrix(features[block_id]) if X is None: X = deepcopy(X_i) y = deepcopy(y_i) else: X = np.concatenate((X, X_i), axis=0) y = np.concatenate((y, y_i), axis=0) - - # Update dicts - idx_transforms["block_to_idxs"][block_id] = idx_transforms_i[ - "block_to_idxs" - ] - idx_transforms["idx_to_edge"].update(idx_transforms_i["idx_to_edge"]) - return X, y, idx_transforms - - -def get_feature_matrix(neurograph, features, shift=0): - # Initialize - features = combine_features(features) - key = sample(list(features.keys()), 1)[0] - X = np.zeros((len(features.keys()), len(features[key]))) - y = np.zeros((len(features.keys()))) - idx_transforms = {"block_to_idxs": set(), "idx_to_edge": dict()} - - # Build - for i, edge in enumerate(features.keys()): - X[i, :] = features[edge] - y[i] = 1 if edge in neurograph.target_edges else 0 - idx_transforms["block_to_idxs"].add(i + shift) - idx_transforms["idx_to_edge"][i + shift] = edge - return X, y, idx_transforms - - -def combine_features(features): - combined = dict() - for edge in features["skel"].keys(): - combined[edge] = None - for key in features.keys(): - if combined[edge] is None and key != "chunks": - combined[edge] = deepcopy(features[key][edge]) - elif key != "chunks": - combined[edge] = np.concatenate( - (combined[edge], features[key][edge]) - ) - return combined + return X, y # --- Utils --- -def n_node_features(): +def get_node_dict(use_img_embedding=False): """ Returns the number of features for different node types. @@ -664,7 +616,7 @@ def n_node_features(): return {"branch": 2, "proposal": 34} -def n_edge_features(): +def get_edge_dict(): """ Returns the number of features for different edge types. @@ -678,9 +630,9 @@ def n_edge_features(): A dictionary containing the number of features for each edge type """ - n_edge_features_dict = { + edge_dict = { ("proposal", "edge", "proposal"): 3, ("branch", "edge", "branch"): 3, ("branch", "edge", "proposal"): 3 } - return n_edge_features_dict + return edge_dict diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index b41afae..7c2beea 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -17,7 +17,8 @@ import torch from torch_geometric.data import HeteroData as HeteroGraphData -from deep_neurographs.machine_learning import datasets, feature_generation +from deep_neurographs.machine_learning import datasets +from deep_neurographs.machine_learning.feature_generation import get_matrix from deep_neurographs.utils import gnn_util DTYPE = torch.float32 @@ -44,17 +45,21 @@ def init(neurograph, features, computation_graph): Custom dataset. """ + # Check for groundtruth + if neurograph.gt_accepts is not None: + gt_accepts = neurograph.gt_accepts + else: + gt_accepts = set() + # Extract features - x_branches, _, idxs_branches = feature_generation.get_matrix( - neurograph, features["edge"] - ) - x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix( - neurograph, features["proposals"] + x_branches, _, idxs_branches = get_matrix(features["branches"]) + x_proposals, y_proposals, idxs_proposals = get_matrix( + features["proposals"], gt_accepts ) - x_nodes = feature_generation.combine_features(features["nodes"]) + x_nodes = features["nodes"] # Initialize dataset - proposals = list(features["proposals"]["skel"].keys()) + proposals = list(features["proposals"].keys()) heterograph_dataset = HeteroGraphDataset( computation_graph, proposals, @@ -63,7 +68,7 @@ def init(neurograph, features, computation_graph): x_proposals, y_proposals, idxs_branches, - idxs_proposals, + idxs_proposals ) return heterograph_dataset @@ -103,12 +108,9 @@ def __init__( Feature matrix generated from "proposals" in "computation_graph". y_proposals : numpy.ndarray Ground truth of proposals. - idxs_branches : dict - Dictionary that maps edges in "computation_graph" to an index that - represents the edge's position in "x_branches". - idxs_proposals : dict - Dictionary that maps "proposals" to an index that represents the - edge's position in "x_proposals". + idx_to_id : dict + Dictionary that maps an edge id in "computation_graph" to its + index in either x_branches or x_proposals. Returns ------- @@ -116,8 +118,8 @@ def __init__( """ # Conversion idxs - self.idxs_branches = datasets.init_idxs(idxs_branches) - self.idxs_proposals = datasets.init_idxs(idxs_proposals) + self.idxs_branches = datasets.init_idx_mapping(idxs_branches) + self.idxs_proposals = datasets.init_idx_mapping(idxs_proposals) self.computation_graph = computation_graph self.proposals = proposals @@ -214,11 +216,11 @@ def check_missing_edge_types(self): edges = [[n - 1, n - 2], [n - 2, n - 1]] self.data[edge_type].edge_index = gnn_util.toTensor(edges) if node_type == "branch": - self.idxs_branches["idx_to_edge"][n - 1] = e_1 - self.idxs_branches["idx_to_edge"][n - 2] = e_2 + self.idxs_branches["idx_to_id"][n - 1] = e_1 + self.idxs_branches["idx_to_id"][n - 2] = e_2 else: - self.idxs_proposals["idx_to_edge"][n - 1] = e_1 - self.idxs_proposals["idx_to_edge"][n - 2] = e_2 + self.idxs_proposals["idx_to_id"][n - 1] = e_1 + self.idxs_proposals["idx_to_id"][n - 2] = e_2 # -- Getters -- def n_branch_features(self): @@ -289,8 +291,8 @@ def proposal_to_proposal(self): edge_index = [] line_graph = gnn_util.init_line_graph(self.proposals) for e1, e2 in line_graph.edges: - v1 = self.idxs_proposals["edge_to_idx"][frozenset(e1)] - v2 = self.idxs_proposals["edge_to_idx"][frozenset(e2)] + v1 = self.idxs_proposals["id_to_idx"][frozenset(e1)] + v2 = self.idxs_proposals["id_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) return gnn_util.toTensor(edge_index) @@ -315,8 +317,8 @@ def branch_to_branch(self): e1_edge_bool = frozenset(e1) not in self.proposals e2_edge_bool = frozenset(e2) not in self.proposals if e1_edge_bool and e2_edge_bool: - v1 = self.idxs_branches["edge_to_idx"][frozenset(e1)] - v2 = self.idxs_branches["edge_to_idx"][frozenset(e2)] + v1 = self.idxs_branches["id_to_idx"][frozenset(e1)] + v2 = self.idxs_branches["id_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) return gnn_util.toTensor(edge_index) @@ -339,14 +341,14 @@ def branch_to_proposal(self): edge_index = [] for p in self.proposals: i, j = tuple(p) - v1 = self.idxs_proposals["edge_to_idx"][frozenset(p)] + v1 = self.idxs_proposals["id_to_idx"][frozenset(p)] for k in self.computation_graph.neighbors(i): if frozenset((i, k)) not in self.proposals: - v2 = self.idxs_branches["edge_to_idx"][frozenset((i, k))] + v2 = self.idxs_branches["id_to_idx"][frozenset((i, k))] edge_index.extend([[v2, v1]]) for k in self.computation_graph.neighbors(j): if frozenset((j, k)) not in self.proposals: - v2 = self.idxs_branches["edge_to_idx"][frozenset((j, k))] + v2 = self.idxs_branches["id_to_idx"][frozenset((j, k))] edge_index.extend([[v2, v1]]) return gnn_util.toTensor(edge_index) @@ -419,8 +421,8 @@ def node_intersection(idx_map, e1, e2): Common node between "e1" and "e2". """ - hat_e1 = idx_map["idx_to_edge"][int(e1)] - hat_e2 = idx_map["idx_to_edge"][int(e2)] + hat_e1 = idx_map["idx_to_id"][int(e1)] + hat_e2 = idx_map["idx_to_id"][int(e2)] node = list(hat_e1.intersection(hat_e2)) assert len(node) == 1, "Node intersection is not unique!" return node[0] @@ -444,8 +446,8 @@ def hetero_node_intersection(idx_map_1, idx_map_2, e1, e2): Common node between "e1" and "e2". """ - hat_e1 = idx_map_1["idx_to_edge"][int(e1)] - hat_e2 = idx_map_2["idx_to_edge"][int(e2)] + hat_e1 = idx_map_1["idx_to_id"][int(e1)] + hat_e2 = idx_map_2["idx_to_id"][int(e2)] node = list(hat_e1.intersection(hat_e2)) assert len(node) == 1, "Node intersection is empty or not unique!" return node[0] diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 0b5139b..aeab7d5 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -22,6 +22,12 @@ class HeteroGNN(torch.nn.Module): # change to HGAT Heterogeneous graph attention network that classifies proposals. """ + # Class attributes + relation_types = [ + ("proposal", "edge", "proposal"), + ("branch", "edge", "proposal"), + ("branch", "edge", "branch"), + ] def __init__( self, @@ -73,6 +79,11 @@ def __init__( # Initialize weights self.init_weights() + # --- Class methods --- + @classmethod + def get_relation_types(cls): + return cls.relation_types + # --- Initialize architecture --- def init_linear_layer(self, hidden_dim, my_dict): linear_layer = dict() @@ -81,23 +92,14 @@ def init_linear_layer(self, hidden_dim, my_dict): return linear_layer def init_gat_layer(self, hidden_dim, edge_dim, heads): - gat_layers = HeteroConv( - { - ("proposal", "edge", "proposal"): self.init_gat_layer_same( - hidden_dim, edge_dim, heads - ), - ("branch", "edge", "branch"): self.init_gat_layer_same( - hidden_dim, edge_dim, heads - ), - ("branch", "edge", "proposal"): self.init_gat_layer_mixed( - hidden_dim, edge_dim, heads - ), - }, - aggr="sum", - ) - return gat_layers - - def init_gat_layer_same(self, hidden_dim, edge_dim, heads): + gat_dict = dict() + for r in self.get_relation_types(): + is_same = True if r[0] == r[2] else False + init_gat = self.init_gat_same if is_same else self.init_gat_mixed + gat_dict[r] = init_gat(hidden_dim, edge_dim, heads) + return HeteroConv(gat_dict, aggr="sum") + + def init_gat_same(self, hidden_dim, edge_dim, heads): gat_layer = GATConv( -1, hidden_dim, @@ -107,7 +109,7 @@ def init_gat_layer_same(self, hidden_dim, edge_dim, heads): ) return gat_layer - def init_gat_layer_mixed(self, hidden_dim, edge_dim, heads): + def init_gat_mixed(self, hidden_dim, edge_dim, heads): gat_layer = GATConv( (hidden_dim, hidden_dim), hidden_dim, @@ -130,32 +132,14 @@ def init_weights(self): None """ - for layer in [self.input_nodes, self.output]: - for param in layer.parameters(): - if len(param.shape) > 1: - init.kaiming_normal_(param) - else: - init.zeros_(param) - - def activation(self, x_dict): - """ - Applies nonlinear activation - - Parameters - ---------- - x_dict : dict - Dictionary that maps node/edge types to feature matrices. - - Returns - ------- - dict - Feature matrices with activation applied. - - """ - x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} - x_dict = {key: self.dropout(x) for key, x in x_dict.items()} - return x_dict - + # Output layer + for params in self.output.parameters(): + if len(params.shape) > 1: + init.kaiming_normal_(params) + else: + init.zeros_(params) + + # --- Generate prediction --- def forward(self, x_dict, edge_index_dict, edge_attr_dict): # Input - Nodes x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} @@ -178,6 +162,25 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): x_dict = self.output(x_dict["proposal"]) return x_dict + def activation(self, x_dict): + """ + Applies nonlinear activation + + Parameters + ---------- + x_dict : dict + Dictionary that maps node/edge types to feature matrices. + + Returns + ------- + dict + Feature matrices with activation applied. + + """ + x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} + x_dict = {key: self.dropout(x) for key, x in x_dict.items()} + return x_dict + class MultiModalHGAT(HeteroGNN): pass diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index db5e66d..dc4bfbe 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -54,19 +54,15 @@ def __init__(self, img_bbox=None, node_spacing=1): super(NeuroGraph, self).__init__() # General class attributes self.leaf_kdtree = None - + self.node_cnt = 0 self.node_spacing = node_spacing + self.proposals = set() + self.merged_ids = set() self.soma_ids = dict() self.swc_ids = set() self.xyz_to_edge = dict() - # Nodes and Edges - self.junctions = set() - self.proposals = set() - self.target_edges = set() - self.node_cnt = 0 - # Bounding box (if applicable) self.bbox = img_bbox if self.bbox: @@ -133,7 +129,7 @@ def add_component(self, irreducibles): irreducibles : dict Dictionary containing the irreducibles of some connected component being added to "self". This dictionary must contain the keys: - 'leafs', 'junctions', 'edges', and 'swc_id'. + 'leaf', 'branching', 'edge', and 'swc_id'. Returns ------- @@ -144,11 +140,11 @@ def add_component(self, irreducibles): if swc_id not in self.swc_ids: # Nodes self.swc_ids.add(swc_id) - ids = self.__add_nodes(irreducibles, "leafs", dict()) - ids = self.__add_nodes(irreducibles, "junctions", ids) + ids = self.__add_nodes(irreducibles, "leaf", dict()) + ids = self.__add_nodes(irreducibles, "branching", ids) # Edges - for (i, j), attrs in irreducibles["edges"].items(): + for (i, j), attrs in irreducibles["edge"].items(): edge = (ids[i], ids[j]) idxs = util.spaced_idxs(attrs["radius"], self.node_spacing) for key in ["radius", "xyz"]: @@ -166,7 +162,7 @@ def __add_nodes(self, irreducibles, node_type, node_ids): being added to "self". node_type : str Type of node being added to "self". This value must be either - 'leafs' or 'junctions'. + 'leaf' or 'branching'. node_ids : dict Dictionary containing conversion from a node id in "irreducibles" to the corresponding node id in "self". @@ -349,8 +345,12 @@ def generate_proposals( progress_bar=progress_bar, trim_endpoints_bool=trim_endpoints_bool, ) + + # Establish groundtruth if groundtruth_graph: - self.target_edges = init_targets(self, groundtruth_graph) + self.gt_accepts = init_targets(self, groundtruth_graph) + else: + self.gt_accepts = set() def reset_proposals(self): """ diff --git a/src/deep_neurographs/train.py b/src/deep_neurographs/train.py index 99ed883..2b01470 100644 --- a/src/deep_neurographs/train.py +++ b/src/deep_neurographs/train.py @@ -25,15 +25,15 @@ from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning.feature_generation import FeatureGenerator from deep_neurographs.utils import gnn_util, img_util, ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader LR = 1e-3 -N_EPOCHS = 200 -SCHEDULER_GAMMA = 0.5 -SCHEDULER_STEP_SIZE = 1000 +N_EPOCHS = 500 +SCHEDULER_GAMMA = 0.7 +SCHEDULER_STEP_SIZE = 100 WEIGHT_DECAY = 1e-3 @@ -50,6 +50,7 @@ def __init__( model_type, criterion=None, output_dir=None, + use_img_embedding=False, validation_ids=None, save_model_bool=True, ): @@ -58,17 +59,18 @@ def __init__( raise ValueError("Must provide output_dir to save model.") # Set class attributes + self.feature_generators = dict() self.idx_to_ids = list() self.model = model self.model_type = model_type self.output_dir = output_dir self.save_model_bool = save_model_bool + self.use_img_embedding = use_img_embedding self.validation_ids = validation_ids # Set data structures for training examples self.gt_graphs = list() self.pred_graphs = list() - self.imgs = dict() self.train_dataset_list = list() self.validation_dataset_list = list() @@ -142,9 +144,16 @@ def load_example( } ) - def load_img(self, path, sample_id): - if sample_id not in self.imgs: - self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") + def load_img( + self, sample_id, img_path, downsample_factor, label_path=None + ): + if sample_id not in self.feature_generators: + self.feature_generators[sample_id] = FeatureGenerator( + img_path, + downsample_factor, + label_path=label_path, + use_img_embedding=self.use_img_embedding, + ) # --- main pipeline --- def run(self): @@ -186,8 +195,8 @@ def generate_proposals(self): sample_id = self.idx_to_ids[i]["sample_id"] example_id = self.idx_to_ids[i]["example_id"] n_proposals = self.pred_graphs[i].n_proposals() - n_targets = len(self.pred_graphs[i].target_edges) - p_accepts = round(n_targets / n_proposals, 4) + n_accepts = len(self.pred_graphs[i].gt_accepts) + p_accepts = round(n_accepts / n_proposals, 4) print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") def generate_features(self): @@ -200,10 +209,8 @@ def generate_features(self): # Generate features sample_id = self.idx_to_ids[i]["sample_id"] - features = feature_generation.run( + features = self.feature_generators[sample_id].run( self.pred_graphs[i], - self.imgs[sample_id], - self.model_type, proposals_dict, self.graph_config.search_radius, ) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index ebcb5cb..30efe08 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -13,10 +13,10 @@ Leaf: a node with degree 1. -Junction: a node with degree > 2. +Branching: a node with degree > 2. Irreducibles: the irreducibles of a graph consists of 1) leaf nodes, -2) junction nodes, and 3) edges connecting (1) and (2). +2) branching nodes, and 3) edges connecting (1) and (2). Branch: a sequence of nodes between two irreducible nodes. @@ -183,7 +183,7 @@ def get_irreducibles( ------- list[dict] List of irreducibles stored in a dictionary where key-values are type - of irreducible (i.e. leaf, junction, or edge) and the corresponding + of irreducible (i.e. leaf, branching, or edge) and the corresponding set of all irreducibles from the graph of that type. """ @@ -249,7 +249,7 @@ def get_component_irreducibles( ------- list List of irreducibles stored in a dictionary where key-values are type - of irreducible (i.e. leaf, junction, or edge) and corresponding set of + of irreducible (i.e. leaf, branching, or edge) and corresponding set of all irreducibles from the graph of that type. """ @@ -304,7 +304,7 @@ def clip_branches(graph, img_bbox): def prune_branches(graph, prune_depth): """ Prunes all short branches from "graph". A short branch is a path between a - leaf and junction node where the path length is less than "prune_depth". + leaf and branching node where the path length is less than "prune_depth". Parameters ---------- @@ -350,7 +350,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool, min_size): """ # Extract nodes - leafs, junctions = get_irreducible_nodes(graph) + leafs, branchings = get_irreducible_nodes(graph) assert len(leafs) > 0, "No leaf nodes!" # Extract edges @@ -372,7 +372,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool, min_size): # Visit j attrs = upd_edge_attrs(swc_dict, attrs, j) - if j in leafs or j in junctions: + if j in leafs or j in branchings: # Check whether to smooth attrs["length"] = cur_length attrs = to_numpy(attrs) @@ -392,9 +392,9 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool, min_size): # Output if total_length > min_size: irreducibles = { - "leafs": set_node_attrs(swc_dict, leafs), - "junctions": set_node_attrs(swc_dict, junctions), - "edges": edges, + "leaf": set_node_attrs(swc_dict, leafs), + "branching": set_node_attrs(swc_dict, branchings), + "edge": edges, "swc_id": swc_dict["swc_id"], } return irreducibles @@ -402,7 +402,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool, min_size): def get_irreducible_nodes(graph): """ - Gets irreducible nodes (i.e. leafs and junctions) of a graph. + Gets irreducible nodes (i.e. leafs and branchings) of a graph. Parameters ---------- @@ -416,13 +416,13 @@ def get_irreducible_nodes(graph): """ leafs = set() - junctions = set() + branchings = set() for i in graph.nodes: if graph.degree[i] == 1: leafs.add(i) elif graph.degree[i] > 2: - junctions.add(i) - return leafs, junctions + branchings.add(i) + return leafs, branchings # --- Refine graph --- @@ -686,7 +686,7 @@ def set_node_attrs(swc_dict, nodes): return attrs -def upd_node_attrs(swc_dict, leafs, junctions, i): +def upd_node_attrs(swc_dict, leafs, branchings, i): """ Updates node attributes by extracting values from "swc_dict". @@ -694,13 +694,13 @@ def upd_node_attrs(swc_dict, leafs, junctions, i): ---------- swc_dict : dict Contents of an swc file that contains the smoothed xyz coordinates of - corresponding to "leafs" and "junctions". Note xyz coordinates are + corresponding to "leafs" and "branchings". Note xyz coordinates are smoothed during edge extraction. leafs : dict Dictionary where keys are leaf node ids and values are attribute dictionaries. - junctions : dict - Dictionary where keys are junction node ids and values are attribute + branchings : dict + Dictionary where keys are branching node ids and values are attribute dictionaries. i : int Node to be updated. @@ -710,7 +710,7 @@ def upd_node_attrs(swc_dict, leafs, junctions, i): dict Updated dictionary if "i" was contained in "leafs.keys()". dict - Updated dictionary if "i" was contained in "junctions.keys()". + Updated dictionary if "i" was contained in "branchings.keys()". """ j = swc_dict["idx"][i] @@ -718,8 +718,8 @@ def upd_node_attrs(swc_dict, leafs, junctions, i): if i in leafs: leafs[i] = upd_attrs else: - junctions[i] = upd_attrs - return leafs, junctions + branchings[i] = upd_attrs + return leafs, branchings # -- miscellaneous -- diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index cd0f2be..9a6c918 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -13,6 +13,7 @@ import numpy as np import tensorstore as ts from skimage.color import label2rgb +from tifffile import imwrite from deep_neurographs.utils import util @@ -164,28 +165,8 @@ def read_profile(img, spec): Image profile. """ - img_chunk = normalize(read_tensorstore_with_bbox(img, spec["bbox"])) - return read_intensities(img_chunk, spec["profile_path"]) - - -def read_intensities(img, voxels): - """ - Reads the image intensities of voxels. - - Parameters - ---------- - img : tensorstore.TensorStore - Image to be read. - voxels : list - Voxels to be read. - - Returns - ------- - list - Image intensities. - - """ - return [img[voxel] for voxel in map(tuple, voxels)] + img_patch = normalize(read_tensorstore_with_bbox(img, spec["bbox"])) + return [img_patch[voxel] for voxel in map(tuple, spec["profile_path"])] def get_start_end(voxel, shape, from_center=True): @@ -384,7 +365,7 @@ def init_bbox(origin, shape): return None -def get_minimal_bbox(voxels): +def get_minimal_bbox(voxels, buffer=0): """ Gets the min and max coordinates of a bounding box that contains "voxels". @@ -403,8 +384,8 @@ def get_minimal_bbox(voxels): """ bbox = { - "min": np.floor(np.min(voxels, axis=0) - 1).astype(int), - "max": np.ceil(np.max(voxels, axis=0) + 1).astype(int), + "min": np.floor(np.min(voxels, axis=0) - buffer).astype(int), + "max": np.ceil(np.max(voxels, axis=0) + buffer).astype(int), } return bbox diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index 431ca49..6af655a 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -61,9 +61,7 @@ def save_model(path, model, model_type): # --- dataset utils --- -def init_dataset( - neurograph, features, is_gnn=True, computation_graph=None, sample_ids=None -): +def init_dataset(neurograph, features, is_gnn=True, computation_graph=None): """ Initializes a dataset given features generated from some set of proposals and neurograph. @@ -79,9 +77,6 @@ def init_dataset( computation_graph : networkx.Graph, optional Computation graph used by gnn if the "model_type" is either "GraphNeuralNet" or "HeteroGraphNeuralNet". The default is None. - sample_ids : list[str], optional - List of ids of samples if features were generated from distinct - predictions. The default is None. Returns ------- @@ -95,9 +90,7 @@ def init_dataset( neurograph, features, computation_graph ) else: - dataset = datasets.init( - neurograph, features, sample_ids=sample_ids - ) + dataset = datasets.init(neurograph, features) return dataset From 8389babec237c6df364831d6f4889cdb74f9daf0 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:51:45 -0700 Subject: [PATCH 6/6] fixed performance bug (#266) Co-authored-by: anna-grim --- src/deep_neurographs/generate_proposals.py | 2 +- .../groundtruth_generation.py | 2 +- .../machine_learning/heterograph_datasets.py | 2 +- .../machine_learning/heterograph_models.py | 65 ++++++++++--------- src/deep_neurographs/train.py | 2 +- src/deep_neurographs/utils/gnn_util.py | 16 ++--- src/deep_neurographs/utils/graph_util.py | 23 +------ src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/ml_util.py | 2 +- src/deep_neurographs/utils/util.py | 2 +- 10 files changed, 49 insertions(+), 69 deletions(-) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index 9267b8a..6c3881f 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -484,4 +484,4 @@ def tangent(branch, idx, depth): """ end = min(idx + depth, len(branch)) - return geometry.tangent(branch[idx:end]) + return geometry.tangent(branch[idx:end]) \ No newline at end of file diff --git a/src/deep_neurographs/groundtruth_generation.py b/src/deep_neurographs/groundtruth_generation.py index 5907a43..13818a2 100644 --- a/src/deep_neurographs/groundtruth_generation.py +++ b/src/deep_neurographs/groundtruth_generation.py @@ -300,4 +300,4 @@ def orient_branch(branch_i, branch_j): def upd_dict(node_to_target_id, nodes, target_id): for node in nodes: node_to_target_id[node] = target_id - return node_to_target_id + return node_to_target_id \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 7c2beea..a45d913 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -469,4 +469,4 @@ def n_edge_features(x): """ key = sample(list(x.keys()), 1)[0] - return x[key].shape[0] + return x[key].shape[0] \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index aeab7d5..0eeda2f 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -57,12 +57,16 @@ def __init__( self.dropout = dropout # Feature vector sizes - hidden_dim = scale_hidden_dim * np.max(list(node_dict.values())) + hidden_dim = scale_hidden_dim* np.max(list(node_dict.values())) output_dim = heads_1 * heads_2 * hidden_dim # Linear layers - self.input_nodes = self.init_linear_layer(hidden_dim, node_dict) - self.input_edges = self.init_linear_layer(hidden_dim, edge_dict) + self.input_nodes = nn.ModuleDict() + self.input_edges = dict() + for key, d in node_dict.items(): + self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device) + for key, d in edge_dict.items(): + self.input_edges[key] = nn.Linear(d, hidden_dim, device=device) self.output = Linear(output_dim, 1).to(device) # Message passing layers @@ -84,7 +88,7 @@ def __init__( def get_relation_types(cls): return cls.relation_types - # --- Initialize architecture --- + # --- Architecture --- def init_linear_layer(self, hidden_dim, my_dict): linear_layer = dict() for key, dim in my_dict.items(): @@ -132,14 +136,32 @@ def init_weights(self): None """ - # Output layer - for params in self.output.parameters(): - if len(params.shape) > 1: - init.kaiming_normal_(params) - else: - init.zeros_(params) - - # --- Generate prediction --- + for layer in [self.output, self.input_nodes]: + for param in layer.parameters(): + if len(param.shape) > 1: + init.kaiming_normal_(param) + else: + init.zeros_(param) + + def activation(self, x_dict): + """ + Applies nonlinear activation + + Parameters + ---------- + x_dict : dict + Dictionary that maps node/edge types to feature matrices. + + Returns + ------- + dict + Feature matrices with activation applied. + + """ + x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} + x_dict = {key: self.dropout(x) for key, x in x_dict.items()} + return x_dict + def forward(self, x_dict, edge_index_dict, edge_attr_dict): # Input - Nodes x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} @@ -162,25 +184,6 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): x_dict = self.output(x_dict["proposal"]) return x_dict - def activation(self, x_dict): - """ - Applies nonlinear activation - - Parameters - ---------- - x_dict : dict - Dictionary that maps node/edge types to feature matrices. - - Returns - ------- - dict - Feature matrices with activation applied. - - """ - x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} - x_dict = {key: self.dropout(x) for key, x in x_dict.items()} - return x_dict - class MultiModalHGAT(HeteroGNN): pass diff --git a/src/deep_neurographs/train.py b/src/deep_neurographs/train.py index 2b01470..24a223b 100644 --- a/src/deep_neurographs/train.py +++ b/src/deep_neurographs/train.py @@ -470,4 +470,4 @@ def get_predictions(hat_y, threshold=0.5): Binary predictions based on the given threshold. """ - return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() + return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() \ No newline at end of file diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index 9d89a84..9cb54a4 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -22,20 +22,16 @@ # --- Tensor Operations --- def get_inputs(data, device=None): - # Extract inputs x = data.x_dict edge_index = data.edge_index_dict edge_attr = data.edge_attr_dict + if device and torch.cuda.is_available(): + return toGPU(x), toGPU(edge_index), toGPU(edge_attr) + else: + return x, edge_index, edge_attr - # Move to gpu (if applicable) - if "cuda" in device and torch.cuda.is_available(): - x = toGPU(x, device) - edge_index = toGPU(edge_index, device) - edge_attr = toGPU(edge_attr, device) - return x, edge_index, edge_attr - -def toGPU(tensor_dict, device): +def toGPU(tensor_dict): """ Moves dictionary of tensors from CPU to GPU. @@ -301,4 +297,4 @@ def init_line_graph(edges): """ graph = nx.Graph() graph.add_edges_from(edges) - return nx.line_graph(graph) + return nx.line_graph(graph) \ No newline at end of file diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 30efe08..dd8a3dd 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -121,7 +121,7 @@ def run( """ # Load fragments and extract irreducibles - self.init_img_bbox(img_patch_origin, img_patch_shape) + self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape) swc_dicts = self.reader.load(fragments_pointer) irreducibles = get_irreducibles( swc_dicts, @@ -139,25 +139,6 @@ def run( neurograph.add_component(irreducible_set) return neurograph - def init_img_bbox(self, img_patch_origin, img_patch_shape): - """ - Sets the bounding box of an image patch as a class attriubte. - - Parameters - ---------- - img_patch_origin : tuple[int] - Origin of bounding box which is assumed to be top, front, left - corner. - img_patch_shape : tuple[int] - Shape of bounding box. - - Returns - ------- - None - - """ - self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape) - # --- Graph structure extraction --- def get_irreducibles( @@ -877,4 +858,4 @@ def largest_components(neurograph, k): node_ids.pop(-1) break i += 1 - return node_ids + return node_ids \ No newline at end of file diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 9a6c918..ebabe8c 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -412,4 +412,4 @@ def find_img_path(bucket_name, img_root, dataset_name): for subdir in util.list_gcs_subdirectories(bucket_name, img_root): if dataset_name in subdir: return subdir + "whole-brain/fused.zarr/" - raise f"Dataset not found in {bucket_name} - {img_root}" + raise f"Dataset not found in {bucket_name} - {img_root}" \ No newline at end of file diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index 6af655a..bb2e236 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -140,4 +140,4 @@ def get_kfolds(filenames, k): folds.append(samples_i) if n_samples > len(samples): break - return folds + return folds \ No newline at end of file diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 023b6a6..a82cb28 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -663,4 +663,4 @@ def spaced_idxs(container, k): idxs = np.arange(0, len(container) + k, k)[:-1] if len(container) % 2 == 0: idxs = np.append(idxs, len(container) - 1) - return idxs + return idxs \ No newline at end of file