From e842b59c31fe08a5708096e42bface55d705bae9 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 25 Sep 2024 00:55:50 +0000 Subject: [PATCH 01/13] minor upds --- pyproject.toml | 1 + src/deep_neurographs/delete_merges_gt.py | 1 + src/deep_neurographs/densegraph.py | 12 +++++++----- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 004c70c..76b984c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ 'torcheval', 'torchio', 'torch_geometric', + 'tqdm', 'zarr', ] diff --git a/src/deep_neurographs/delete_merges_gt.py b/src/deep_neurographs/delete_merges_gt.py index 936052c..118a34a 100644 --- a/src/deep_neurographs/delete_merges_gt.py +++ b/src/deep_neurographs/delete_merges_gt.py @@ -87,6 +87,7 @@ def delete_merges( # Finish if len(delete_nodes) > 0: graph.remove_nodes_from(delete_nodes) + print("Merge Detected:", swc_id) print("# Nodes Deleted:", len(delete_nodes)) print("") pred_densegraph.graphs[swc_id] = graph diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index 75f100f..273a7ef 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -15,7 +15,7 @@ from scipy.spatial import KDTree from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import swc_util, util +from deep_neurographs.utils import img_util, swc_util, util DELETION_RADIUS = 10 @@ -43,7 +43,7 @@ def __init__(self, swc_paths, img_patch_origin=None, img_patch_shape=None): None """ - self.bbox = util.get_img_bbox(img_patch_origin, img_patch_shape) + self.bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) self.init_graphs(swc_paths) self.init_kdtree() @@ -65,7 +65,7 @@ def init_graphs(self, paths): """ self.graphs = dict() self.xyz_to_swc = dict() - swc_dicts, _ = swc_util.process_local_paths(paths) + swc_dicts = swc_util.Reader().load(paths) for i, swc_dict in enumerate(swc_dicts): # Build graph swc_id = swc_dict["swc_id"] @@ -179,11 +179,13 @@ def make_entries(self, graph, component): node_to_idx[i] = 1 x, y, z = tuple(graph.nodes[i]["xyz"]) r = graph.nodes[i]["radius"] - entry_list.append([1, 2, x, y, z, r, -1]) + entry_list.append(f"1 2 {x} {y} {z} {r} -1") # Create entry node_to_idx[j] = len(entry_list) + 1 x, y, z = tuple(graph.nodes[j]["xyz"]) r = graph.nodes[j]["radius"] - entry_list.append([node_to_idx[j], 2, x, y, z, r, node_to_idx[i]]) + entry_list.append( + f"{node_to_idx[j]} 2 {x} {y} {z} {r} {node_to_idx[i]}" + ) return entry_list From 1f36be9b5bc4a7e35f0699a5f949e8ddd8754b9e Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 25 Sep 2024 21:56:45 +0000 Subject: [PATCH 02/13] refactor: training pipeline --- src/deep_neurographs/densegraph.py | 2 +- src/deep_neurographs/intake.py | 138 ---------------- ...aphtrace_pipeline.py => run_graphtrace.py} | 4 +- src/deep_neurographs/train_graphtrace.py | 81 ++++++++++ src/deep_neurographs/utils/gnn_util.py | 1 - src/deep_neurographs/utils/graph_util.py | 148 +++++++++++++++++- src/deep_neurographs/utils/swc_util.py | 8 +- src/deep_neurographs/utils/util.py | 59 +++---- 8 files changed, 261 insertions(+), 180 deletions(-) delete mode 100644 src/deep_neurographs/intake.py rename src/deep_neurographs/{run_graphtrace_pipeline.py => run_graphtrace.py} (99%) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index 273a7ef..c3f6537 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -15,7 +15,7 @@ from scipy.spatial 
import KDTree from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, swc_util, util +from deep_neurographs.utils import img_util, swc_util DELETION_RADIUS = 10 diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py deleted file mode 100644 index ea4d5bc..0000000 --- a/src/deep_neurographs/intake.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Created on Sat July 15 9:00:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Builds a neurograph for neuron reconstruction. - -""" - -from deep_neurographs.neurograph import NeuroGraph -from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, swc_util - -MIN_SIZE = 30 -NODE_SPACING = 1 -SMOOTH_BOOL = True -PRUNE_DEPTH = 16 -TRIM_DEPTH = 0 - - -class GraphBuilder: - """ - Class that is used to build an instance of FragmentsGraph. - - """ - - def __init__( - self, - anisotropy=[1.0, 1.0, 1.0], - min_size=MIN_SIZE, - node_spacing=NODE_SPACING, - prune_depth=PRUNE_DEPTH, - smooth_bool=SMOOTH_BOOL, - trim_depth=TRIM_DEPTH, - ): - """ - Builds a FragmentsGraph by reading swc files stored either on the - cloud or local machine, then extracting the irreducible components. - - Parameters - ---------- - anisotropy : list[float], optional - Scaling factors applied to xyz coordinates to account for - anisotropy of microscope. The default is [1.0, 1.0, 1.0]. - min_size : float, optional - Minimum path length of swc files which are stored as connected - components in the FragmentsGraph. The default is 30ums. - node_spacing : int, optional - Spacing (in microns) between nodes. The default is the global - variable "NODE_SPACING". - prune_depth : int, optional - Branches less than "prune_depth" microns are pruned if "prune" is - True. The default is the global variable "PRUNE_DEPTH". - smooth_bool : bool, optional - Indication of whether to smooth branches from swc files. The - default is the global variable "SMOOTH". - trim_depth : float, optional - Maximum path length (in microns) to trim from "branch". The default - is the global variable "TRIM_DEPTH". - - Returns - ------- - FragmentsGraph - FragmentsGraph generated from swc files. - - """ - self.anisotropy = anisotropy - self.min_size = min_size - self.node_spacing = node_spacing - self.prune_depth = prune_depth - self.smooth_bool = smooth_bool - self.trim_depth = trim_depth - - self.reader = swc_util.Reader(anisotropy, min_size) - - def run( - self, fragments_pointer, img_patch_origin=None, img_patch_shape=None - ): - """ - Builds a FragmentsGraph by reading swc files stored either on the - cloud or local machine, then extracting the irreducible components. - - Parameters - ---------- - fragments_pointer : dict, list, str - Pointer to swc files used to build an instance of FragmentsGraph, - see "swc_util.Reader" for further documentation. - img_patch_origin : list[int], optional - An xyz coordinate which is the upper, left, front corner of the - image patch that contains the swc files. The default is None. - img_patch_shape : list[int], optional - Shape of the image patch which contains the swc files. The default - is None. - - Returns - ------- - FragmentsGraph - FragmentsGraph generated from swc files. 
- - """ - # Load fragments and extract irreducibles - self.set_img_bbox(img_patch_origin, img_patch_shape) - swc_dicts = self.reader.load(fragments_pointer) - irreducibles = gutil.get_irreducibles( - swc_dicts, - self.min_size, - self.img_bbox, - self.prune_depth, - self.smooth_bool, - self.trim_depth, - ) - - # Build FragmentsGraph - neurograph = NeuroGraph(node_spacing=self.node_spacing) - while len(irreducibles): - irreducible_set = irreducibles.pop() - neurograph.add_component(irreducible_set) - return neurograph - - def set_img_bbox(self, img_patch_origin, img_patch_shape): - """ - Sets the bounding box of an image patch as a class attriubte. - - Parameters - ---------- - img_patch_origin : tuple[int] - Origin of bounding box which is assumed to be top, front, left - corner. - img_patch_shape : tuple[int] - Shape of bounding box. - - Returns - ------- - None - - """ - self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) diff --git a/src/deep_neurographs/run_graphtrace_pipeline.py b/src/deep_neurographs/run_graphtrace.py similarity index 99% rename from src/deep_neurographs/run_graphtrace_pipeline.py rename to src/deep_neurographs/run_graphtrace.py index 3e5e53d..4b4db4b 100644 --- a/src/deep_neurographs/run_graphtrace_pipeline.py +++ b/src/deep_neurographs/run_graphtrace.py @@ -33,9 +33,9 @@ import networkx as nx from deep_neurographs.graph_artifact_removal import remove_doubles -from deep_neurographs.intake import GraphBuilder from deep_neurographs.machine_learning.inference import InferenceEngine from deep_neurographs.utils import util +from deep_neurographs.utils.graph_util import GraphLoader class GraphTracePipeline: @@ -163,7 +163,7 @@ def build_graph(self, fragments_pointer): t0 = time() # Initialize Graph - graph_builder = GraphBuilder( + graph_builder = GraphLoader( anisotropy=self.graph_config.anisotropy, min_size=self.graph_config.min_size, node_spacing=self.graph_config.node_spacing, diff --git a/src/deep_neurographs/train_graphtrace.py b/src/deep_neurographs/train_graphtrace.py index 128160f..663ce35 100644 --- a/src/deep_neurographs/train_graphtrace.py +++ b/src/deep_neurographs/train_graphtrace.py @@ -8,3 +8,84 @@ This script trains the GraphTrace inference pipeline. """ + +from deep_neurographs.utils import util +from deep_neurographs.utils.graph_util import GraphLoader + + +class Trainer: + """ + Class that is used to train a machine learning model that classifies + proposals. 
+ + """ + def __init__( + self, config, model_type, output_dir=None, save_model_bool=True + ): + # Check for parameter errors + if save_model_bool and not output_dir: + raise ValueError("Must provide output_dir to save model.") + + # Set class attributes + self.idx_to_ids = list() + self.model_type = model_type + self.output_dir = output_dir + self.save_model_bool = save_model_bool + + # Set data structures for training examples + self.gt_graphs = list() + self.pred_graphs = list() + self.imgs = dict() + + # Extract config settings + self.graph_config = config.graph_config + self.ml_config = config.ml_config + self.graph_loader = GraphLoader( + min_size=self.graph_config.min_size, + progress_bar=False, + ) + + def n_examples(self): + return len(self.gt_graphs) + + def load_example( + self, + gt_pointer, + pred_pointer, + dataset_name, + example_id=None, + metadata_path=None, + ): + # Read metadata + if metadata_path: + origin, shape = util.read_metadata(metadata_path) + else: + origin, shape = None, None + + # Load graphs + self.gt_graphs.append(self.graph_loader.run(gt_pointer)) + self.pred_graphs.append( + self.graph_loader.run( + pred_pointer, + img_patch_origin=origin, + img_patch_shape=shape, + ) + ) + + # Set example ids + self.idx_to_ids.append( + {"dataset_name": dataset_name, "example_id": example_id} + ) + + def load_img(self, img_path, dataset_name): + pass + + def run(self): + pass + + def generate_features(self): + # check that every example has an image that was loaded! + pass + + def evaluate(self): + pass diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index 8b41092..2d8b5d9 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -108,7 +108,6 @@ def get_batch(graph, proposals, batch_size): """ batch = reset_batch() - cur_proposal_cnt = 0 visited = set() while len(proposals) > 0 and len(batch["proposals"]) < batch_size: root = tuple(util.sample_once(proposals)) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 0bc17fa..b01d9ce 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -5,7 +5,7 @@ @email: anna.grim@alleninstitute.org -Routines that extract the irreducible components of a graph. +Routines for loading fragments and building a neurograph. Terminology @@ -30,13 +30,147 @@ from tqdm import tqdm from deep_neurographs import geometry +from deep_neurographs.neurograph import NeuroGraph from deep_neurographs.utils import img_util, swc_util, util - +MIN_SIZE = 30 +NODE_SPACING = 1 +SMOOTH_BOOL = True +PRUNE_DEPTH = 16 +TRIM_DEPTH = 0 + + +class GraphLoader: + """ + Class that is used to build an instance of FragmentsGraph. + + """ + + def __init__( + self, + anisotropy=[1.0, 1.0, 1.0], + min_size=MIN_SIZE, + node_spacing=NODE_SPACING, + progress_bar=False, + prune_depth=PRUNE_DEPTH, + smooth_bool=SMOOTH_BOOL, + trim_depth=TRIM_DEPTH, + ): + """ + Builds a FragmentsGraph by reading swc files stored either on the + cloud or local machine, then extracting the irreducible components. + + Parameters + ---------- + anisotropy : list[float], optional + Scaling factors applied to xyz coordinates to account for + anisotropy of microscope. The default is [1.0, 1.0, 1.0]. + min_size : float, optional + Minimum path length of swc files which are stored as connected + components in the FragmentsGraph. The default is 30ums. + node_spacing : int, optional + Spacing (in microns) between nodes. 
The default is the global + variable "NODE_SPACING". + progress_bar : bool, optional + Indication of whether to print out a progress bar while building + graph. The default is False. + prune_depth : int, optional + Branches less than "prune_depth" microns are pruned if "prune" is + True. The default is the global variable "PRUNE_DEPTH". + smooth_bool : bool, optional + Indication of whether to smooth branches from swc files. The + default is the global variable "SMOOTH_BOOL". + trim_depth : float, optional + Maximum path length (in microns) to trim from "branch". The default + is the global variable "TRIM_DEPTH". + + Returns + ------- + FragmentsGraph + FragmentsGraph generated from swc files. + + """ + self.anisotropy = anisotropy + self.min_size = min_size + self.node_spacing = node_spacing + self.progress_bar = progress_bar + self.prune_depth = prune_depth + self.smooth_bool = smooth_bool + self.trim_depth = trim_depth + + self.reader = swc_util.Reader(anisotropy, min_size) + + def run( + self, fragments_pointer, img_patch_origin=None, img_patch_shape=None + ): + """ + Builds a FragmentsGraph by reading swc files stored either on the + cloud or local machine, then extracting the irreducible components. + + Parameters + ---------- + fragments_pointer : dict, list, str + Pointer to swc files used to build an instance of FragmentsGraph, + see "swc_util.Reader" for further documentation. + img_patch_origin : list[int], optional + An xyz coordinate which is the upper, left, front corner of the + image patch that contains the swc files. The default is None. + img_patch_shape : list[int], optional + Shape of the image patch which contains the swc files. The default + is None. + + Returns + ------- + FragmentsGraph + FragmentsGraph generated from swc files. + + """ + # Load fragments and extract irreducibles + self.set_img_bbox(img_patch_origin, img_patch_shape) + swc_dicts = self.reader.load(fragments_pointer) + irreducibles = get_irreducibles( + swc_dicts, + self.min_size, + self.img_bbox, + self.progress_bar, + self.prune_depth, + self.smooth_bool, + self.trim_depth, + ) + + # Build FragmentsGraph + neurograph = NeuroGraph(node_spacing=self.node_spacing) + while len(irreducibles): + irreducible_set = irreducibles.pop() + neurograph.add_component(irreducible_set) + return neurograph + + def set_img_bbox(self, img_patch_origin, img_patch_shape): + """ + Sets the bounding box of an image patch as a class attribute. + + Parameters + ---------- + img_patch_origin : tuple[int] + Origin of bounding box which is assumed to be top, front, left + corner. + img_patch_shape : tuple[int] + Shape of bounding box.
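For orientation, a hypothetical usage sketch of the relocated GraphLoader; the SWC directory and image-patch geometry below are placeholders, not values from this patch:

    from deep_neurographs.utils.graph_util import GraphLoader

    # Build a FragmentsGraph from a local directory of SWC fragments; the
    # bounding box arguments are optional and only needed when the fragments
    # come from a cropped image patch.
    loader = GraphLoader(min_size=30, node_spacing=1, progress_bar=True)
    neurograph = loader.run(
        "/path/to/swc_fragments/",
        img_patch_origin=[10000, 12000, 5000],
        img_patch_shape=[1024, 1024, 512],
    )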
+ + Returns + ------- + None + + """ + self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) + + +# --- Graph structure extraction --- def get_irreducibles( swc_dicts, min_size, img_bbox=None, + progress_bar=True, prune_depth=16.0, smooth_bool=True, trim_depth=0.0, @@ -78,11 +212,15 @@ def get_irreducibles( i += 1 # Store results - with tqdm(total=len(processes), desc="Extract Graphs") as pbar: - irreducibles = list() + irreducibles = list() + if progress_bar: + with tqdm(total=len(processes), desc="Extract Graphs") as pbar: + for process in as_completed(processes): + irreducibles.extend(process.result()) + pbar.update(1) + else: for process in as_completed(processes): irreducibles.extend(process.result()) - pbar.update(1) return irreducibles diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index 38a2986..e279700 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -85,7 +85,7 @@ def load(self, swc_pointer): if ".swc" in swc_pointer: return self.load_from_local_path(swc_pointer) if os.path.isdir(swc_pointer): - paths = util.list_paths(swc_pointer, ext=".swc") + paths = util.list_paths(swc_pointer, extension=".swc") return self.load_from_local_paths(paths) raise Exception("SWC Pointer is not Valid!") @@ -307,7 +307,7 @@ def process_content(self, content): Offset of swc file. """ - offset = [1.0, 1.0, 1.0] + offset = [0.0, 0.0, 0.0] for i, line in enumerate(content): if line.startswith("# OFFSET"): offset = self.read_xyz(line.split()[2:5]) @@ -364,7 +364,7 @@ def write(path, content, color=None): elif type(content) is nx.Graph: write_graph(path, content, color=color) else: - raise ExceptionType("Unable to write {} to swc".format(type(content))) + raise Exception("Unable to write {} to swc".format(type(content))) def write_list(path, entry_list, color=None): @@ -558,7 +558,7 @@ def set_radius(graph, i): """ try: radius = graph[i]["radius"] - except: + except ValueError: radius = 1.0 return radius diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index c59fbfd..79d820c 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -8,7 +8,6 @@ """ -import boto3 import json import math import os @@ -18,6 +17,7 @@ from time import time from zipfile import ZipFile +import boto3 import numpy as np import psutil @@ -216,33 +216,32 @@ def rmdir(path): shutil.rmtree(path) -def listdir(path, ext=None): +def listdir(path, extension=None): """ Lists all files in the directory at "path". If an extension "ext" is - provided, then only files containing "ext" are returned. + provided, then only files containing "extension" are returned. Parameters ---------- path : str Path to directory to be searched. - - ext : str, optional + extension : str, optional Extension of file type of interest. The default is None. Returns ------- list - Files in directory at "path" with extension "ext" if provided. + Files in directory at "path" with extension "extension" if provided. Otherwise, list of all files in directory. """ - if ext is None: + if extension is None: return [f for f in os.listdir(path)] else: - return [f for f in os.listdir(path) if ext in f] + return [f for f in os.listdir(path) if f.endswith(extension)] -def list_subdirs(path, keyword=None): +def list_subdirs(path, keyword=None, return_paths=False): """ Creates list of all subdirectories at "path". 
If "keyword" is provided, then only subdirectories containing "keyword" are contained in list. @@ -251,10 +250,12 @@ def list_subdirs(path, keyword=None): ---------- path : str Path to directory containing subdirectories to be listed. - keyword : str, optional Only subdirectories containing "keyword" are contained in list that is returned. The default is None. + return_paths : bool + Indication of whether to return full path of subdirectories. The + default is False. Returns ------- @@ -263,27 +264,27 @@ def list_subdirs(path, keyword=None): """ subdirs = list() - for d in os.listdir(path): - if os.path.isdir(os.path.join(path, d)): - if keyword is None: - subdirs.append(d) - elif keyword in d: - subdirs.append(d) - subdirs.sort() - return subdirs + for subdir in os.listdir(path): + is_dir = os.path.isdir(os.path.join(path, subdir)) + is_hidden = subdir.startswith('.') + if is_dir and not is_hidden: + subdir = os.path.join(path, subdir) if return_paths else subdir + if (keyword and keyword in subdir) or not keyword: + subdirs.append(subdir) + return sorted(subdirs) -def list_paths(directory, ext=None): +def list_paths(directory, extension=None): """ - Lists all paths within "directory". + Lists all paths within "directory" that end with "extension" if provided. Parameters ---------- directory : str Directory to be searched. - ext : str, optional - If provided, only paths of files with the extension "ext" are - returned. The default is None. + extension : str, optional + If provided, only paths of files with the extension are returned. The + default is None. Returns ------- @@ -292,12 +293,12 @@ def list_paths(directory, ext=None): """ paths = list() - for f in listdir(directory, ext=ext): + for f in listdir(directory, extension=extension): paths.append(os.path.join(directory, f)) return paths -def set_path(dir_name, filename, ext): +def set_path(dir_name, filename, extension): """ Sets the path for a file in a directory. If a file with the same name exists, then this routine finds a suffix to append to the filename. @@ -308,7 +309,7 @@ def set_path(dir_name, filename, ext): Name of directory that path will be generated to point to. filename : str Name of file that path will contain. - ext : str + extension : str Extension of file. 
Returns @@ -319,10 +320,10 @@ def set_path(dir_name, filename, ext): """ cnt = 0 - ext = ext.replace(".", "") - path = os.path.join(dir_name, f"{filename}.{ext}") + extension = extension.replace(".", "") + path = os.path.join(dir_name, f"{filename}.{extension}") while os.path.exists(path): - path = os.path.join(dir_name, f"{filename}.{cnt}.{ext}") + path = os.path.join(dir_name, f"{filename}.{cnt}.{extension}") cnt += 1 return path From ea7e9e9954ff78220764b8c15583919a0892169c Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 00:08:02 +0000 Subject: [PATCH 03/13] feat: find gcs image path --- src/deep_neurographs/generate_proposals.py | 13 ++++- .../machine_learning/heterograph_trainer.py | 1 + .../machine_learning/inference.py | 3 +- src/deep_neurographs/neurograph.py | 5 ++ .../{run_graphtrace.py => run_pipeline.py} | 0 ...{train_graphtrace.py => train_pipeline.py} | 0 src/deep_neurographs/utils/img_util.py | 29 ++++++++++- src/deep_neurographs/utils/util.py | 50 ++++++++++++++++--- 8 files changed, 90 insertions(+), 11 deletions(-) rename src/deep_neurographs/{run_graphtrace.py => run_pipeline.py} (100%) rename src/deep_neurographs/{train_graphtrace.py => train_pipeline.py} (100%) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index 1f9c057..3a5bd79 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -25,6 +25,7 @@ def run( radius, complex_bool=False, long_range_bool=True, + progress_bar=True, trim_endpoints_bool=True, ): """ @@ -43,6 +44,9 @@ def run( Indication of whether to generate simple proposals within distance of "LONG_RANGE_FACTOR" * radius of leaf from leaf without any proposals. The default is False. + progress_bar : bool, optional + Indication of whether to print out a progress bar while generating + proposals. The default is True. trim_endpoints_bool : bool, optional Indication of whether to endpoints of branches with exactly one proposal. The default is True. 
@@ -52,10 +56,17 @@ def run( None """ + # Initializations connections = dict() kdtree = init_kdtree(neurograph, complex_bool) radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0 - for leaf in tqdm(neurograph.leafs, desc="Proposals"): + if progress_bar: + iterable = tqdm(neurograph.leafs, desc="Proposals") + else: + iterable = neurograph.leafs + + # Main + for leaf in iterable: # Generate potential proposals candidates = get_candidates( neurograph, diff --git a/src/deep_neurographs/machine_learning/heterograph_trainer.py b/src/deep_neurographs/machine_learning/heterograph_trainer.py index 3554932..f0a2230 100644 --- a/src/deep_neurographs/machine_learning/heterograph_trainer.py +++ b/src/deep_neurographs/machine_learning/heterograph_trainer.py @@ -129,6 +129,7 @@ def run_on_graphs(self, datasets, augment=False): y, hat_y = [], [] self.model.train() for graph_id in train_ids: + print(graph_id) y_i, hat_y_i = self.train( datasets[graph_id], epoch, augment=augment ) diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 0ecf87d..3cfbf3e 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -80,8 +80,7 @@ def __init__( self.threshold = confidence_threshold # Load image and model - driver = "n5" if ".n5" in img_path else "zarr" - self.img = img_util.open_tensorstore(img_path, driver=driver) + self.img = img_util.open_tensorstore(img_path, driver="zarr") self.model = ml_util.load_model(model_path) def run(self, neurograph, proposals): diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 8bd1dc7..d9cad37 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -281,6 +281,7 @@ def generate_proposals( complex_bool=False, groundtruth_graph=None, long_range_bool=False, + progress_bar=True, proposals_per_leaf=3, return_trimmed_proposals=False, trim_endpoints_bool=False, @@ -300,6 +301,9 @@ def generate_proposals( long_range_bool : bool, optional Indication of whether to generate long range proposals. The default is False. + progress_bar : bool, optional + Indication of whether to print out a progress bar while generating + proposals. The default is True. proposals_per_leaf : int, optional Maximum number of proposals generated for each leaf. The default is 3. 
@@ -322,6 +326,7 @@ def generate_proposals( search_radius, complex_bool=complex_bool, long_range_bool=long_range_bool, + progress_bar=progress_bar, trim_endpoints_bool=trim_endpoints_bool, ) if groundtruth_graph: diff --git a/src/deep_neurographs/run_graphtrace.py b/src/deep_neurographs/run_pipeline.py similarity index 100% rename from src/deep_neurographs/run_graphtrace.py rename to src/deep_neurographs/run_pipeline.py diff --git a/src/deep_neurographs/train_graphtrace.py b/src/deep_neurographs/train_pipeline.py similarity index 100% rename from src/deep_neurographs/train_graphtrace.py rename to src/deep_neurographs/train_pipeline.py diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 04acb0d..80f660f 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -9,11 +9,13 @@ """ from copy import deepcopy +from skimage.color import label2rgb import fastremap import numpy as np import tensorstore as ts -from skimage.color import label2rgb + +from deep_neurographs.utils import util ANISOTROPY = [0.748, 0.748, 1.0] SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"] @@ -453,3 +455,28 @@ def get_chunk_labels(path, xyz, shape, from_center=True): img = open_tensorstore(path) img = read_tensorstore(img, xyz, shape, from_center=from_center) return set(fastremap.unique(img).astype(int)) + + +def find_img_path(bucket_name, img_root, dataset_name): + """ + Find the path of a specific dataset in a GCS bucket. + + Parameters: + ---------- + bucket_name : str + Name of the GCS bucket where the images are stored. + img_root : str + Root directory path in the GCS bucket where the images are located. + dataset_name : str + Name of the dataset to be searched for within the subdirectories. + + Returns: + ------- + str + Path of the found dataset subdirectory within the specified GCS bucket. + + """ + for subdir in util.list_gcs_subdirectories(bucket_name, img_root): + if dataset_name in subdir: + return subdir + "fused.zarr/" + raise(f"Dataset not found in {bucket_name} - {img_root}") diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 79d820c..44a9ab3 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -8,18 +8,19 @@ """ -import json -import math -import os -import shutil +from google.cloud import storage from io import BytesIO from random import sample from time import time from zipfile import ZipFile import boto3 +import json +import math import numpy as np +import os import psutil +import shutil # --- dictionary utils --- @@ -348,7 +349,7 @@ def list_files_in_zip(zip_content): return zip_file.namelist() -def list_gcs_filenames(bucket, cloud_path, extension): +def list_gcs_filenames(bucket, prefix, extension): """ Lists all files in a GCS bucket with the given extension. @@ -356,7 +357,7 @@ def list_gcs_filenames(bucket, cloud_path, extension): ---------- bucket : google.cloud.client Name of bucket to be read from. - cloud_path : str + prefix : str Path to directory in "bucket". extension : str File extension of filenames to be listed. @@ -367,10 +368,45 @@ def list_gcs_filenames(bucket, cloud_path, extension): Filenames stored at "cloud" path with the given extension. 
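A hypothetical sketch of the new GCS helpers introduced in this commit; the bucket, prefix, and dataset names are placeholders, and a google.cloud storage client with valid credentials is assumed:

    from deep_neurographs.utils import img_util, util

    # Direct subdirectories under the image root, one per sample.
    subdirs = util.list_gcs_subdirectories("my-bucket", "exaspim/images/")

    # Resolve the fused Zarr path for a particular dataset; the call falls
    # through to a raise when no subdirectory matches the dataset name.
    img_path = img_util.find_img_path("my-bucket", "exaspim/images/", "653158")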
""" - blobs = bucket.list_blobs(prefix=cloud_path) + blobs = bucket.list_blobs(prefix=prefix) return [blob.name for blob in blobs if extension in blob.name] +def list_gcs_subdirectories(bucket_name, prefix): + """ + Lists all direct subdirectories of a given prefix in a GCS bucket. + + Parameters + ---------- + bucket : str + Name of bucket to be read from. + prefix : str + Path to directory in "bucket". + + Returns + ------- + list[str] + List of direct subdirectories. + + """ + # Load blobs + storage_client = storage.Client() + blobs = storage_client.list_blobs( + bucket_name, prefix=prefix, delimiter="/" + ) + [blob.name for blob in blobs] + + # Parse directory contents + prefix_depth = len(prefix.split("/")) + subdirs = list() + for prefix in blobs.prefixes: + is_dir = prefix.endswith("/") + is_direct_subdir = len(prefix.split("/")) - 1 == prefix_depth + if is_dir and is_direct_subdir: + subdirs.append(prefix) + return subdirs + + # -- io utils -- def read_json(path): """ From 64179e142d8cf58e82d86f705fe95998d4f0a75f Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 01:43:05 +0000 Subject: [PATCH 04/13] feat: feature generation in trainer --- src/deep_neurographs/train_pipeline.py | 31 ++++++++++++++++++++++---- src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/util.py | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index 663ce35..cc670c4 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -9,7 +9,7 @@ """ -from deep_neurographs.utils import util +from deep_neurographs.utils import img_util, util from deep_neurographs.utils.graph_util import GraphLoader @@ -77,11 +77,34 @@ def load_example( {"dataset_name": dataset_name, "example_id": example_id} ) - def load_img(self, img_path, dataset_name): - pass + def load_img(self, path, dataset_name): + if dataset_name not in self.imgs: + self.imgs[dataset_name] = img_util.open_tensorstore(path, "zarr") def run(self): - pass + self.generate_proposals() + + def generate_proposals(self): + print("dataset_name - example_id - # proposals - % accepted") + for i in range(self.n_examples()): + # Run + self.pred_graphs[i].generate_proposals( + self.graph_config.search_radius, + complex_bool=self.graph_config.complex_bool, + groundtruth_graph=self.gt_graphs[i], + long_range_bool=self.graph_config.long_range_bool, + progress_bar=False, + proposals_per_leaf=self.graph_config.proposals_per_leaf, + trim_endpoints_bool=self.graph_config.trim_endpoints_bool, + ) + + # Report results + dataset_name = self.idx_to_ids[i]["dataset_name"] + example_id = self.idx_to_ids[i]["example_id"] + n_proposals = self.pred_graphs[i].n_proposals() + n_targets = len(self.pred_graphs[i].target_edges) + p_accepts = round(n_targets / n_proposals, 4) + print(f"{dataset_name} {example_id} {n_proposals} {p_accepts}") def generate_features(self): # check that every example has an image that was loaded! 
diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 80f660f..41c10c5 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -478,5 +478,5 @@ def find_img_path(bucket_name, img_root, dataset_name): """ for subdir in util.list_gcs_subdirectories(bucket_name, img_root): if dataset_name in subdir: - return subdir + "fused.zarr/" + return subdir + "whole-brain/fused.zarr/" raise(f"Dataset not found in {bucket_name} - {img_root}") diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 44a9ab3..6e165bb 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -389,7 +389,7 @@ def list_gcs_subdirectories(bucket_name, prefix): List of direct subdirectories. """ - # Load blobs + # Load blobs storage_client = storage.Client() blobs = storage_client.list_blobs( bucket_name, prefix=prefix, delimiter="/" From 9484330b71a9d2bb0cc266da4584f371b80794f7 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 06:00:55 +0000 Subject: [PATCH 05/13] feat: validation sets in training --- src/deep_neurographs/train_pipeline.py | 110 ++++++++++++++++++++++--- src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/ml_util.py | 23 ++++++ src/deep_neurographs/utils/util.py | 10 +-- 4 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index cc670c4..a3efb63 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -9,7 +9,15 @@ """ -from deep_neurographs.utils import img_util, util +import os +from datetime import datetime +from random import sample + +import numpy as np +from torch.nn import BCEWithLogitsLoss + +from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.utils import img_util, ml_util, util from deep_neurographs.utils.graph_util import GraphLoader @@ -20,7 +28,14 @@ class Trainer: """ def __init__( - self, config, model_type, output_dir=None, save_model_bool=True + self, + config, + model_type, + criterion=None, + output_dir=None, + validation_ids=None, + validation_split=0.15, + save_model_bool=True, ): # Check for parameter errors if save_model_bool and not output_dir: @@ -31,11 +46,19 @@ def __init__( self.model_type = model_type self.output_dir = output_dir self.save_model_bool = save_model_bool + self.validation_ids = validation_ids + self.validation_split = validation_split # Set data structures for training examples self.gt_graphs = list() self.pred_graphs = list() self.imgs = dict() + self.train_dataset = list() + self.validation_dataset = list() + + # Train parameters + self.criterion = criterion if criterion else BCEWithLogitsLoss() + self.validation_ids = validation_ids # Extract config settings self.graph_config = config.graph_config @@ -45,15 +68,36 @@ def __init__( progress_bar=False, ) + # --- getters/setters --- def n_examples(self): return len(self.gt_graphs) + def n_train_examples(self): + return len(self.train_dataset) + + def n_validation_samples(self): + return len(self.validation_dataset) + + def set_validation_idxs(self): + if self.validation_ids is None: + k = int(self.validation_split * self.n_examples()) + self.validation_idxs = sample(np.arange(self.n_examples), k) + else: + self.validation_idxs = list() + for ids in self.validation_ids: + for i in range(self.n_examples()): + same = all([ids[k] == self.idx_to_ids[i][k] for k in 
ids]) + if same: + self.validation_idxs.append(i) + + # --- loaders --- def load_example( self, gt_pointer, pred_pointer, - dataset_name, + sample_id, example_id=None, + pred_id=None, metadata_path=None, ): # Read metadata @@ -74,18 +118,25 @@ def load_example( # Set example ids self.idx_to_ids.append( - {"dataset_name": dataset_name, "example_id": example_id} + { + "sample_id": sample_id, + "example_id": example_id, + "pred_id": pred_id, + } ) - def load_img(self, path, dataset_name): - if dataset_name not in self.imgs: - self.imgs[dataset_name] = img_util.open_tensorstore(path, "zarr") + def load_img(self, path, sample_id): + if sample_id not in self.imgs: + self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") + # --- main pipeline --- def run(self): self.generate_proposals() + self.generate_features() + self.train_model() def generate_proposals(self): - print("dataset_name - example_id - # proposals - % accepted") + print("sample_id - example_id - # proposals - % accepted") for i in range(self.n_examples()): # Run self.pred_graphs[i].generate_proposals( @@ -99,16 +150,49 @@ def generate_proposals(self): ) # Report results - dataset_name = self.idx_to_ids[i]["dataset_name"] + sample_id = self.idx_to_ids[i]["sample_id"] example_id = self.idx_to_ids[i]["example_id"] n_proposals = self.pred_graphs[i].n_proposals() n_targets = len(self.pred_graphs[i].target_edges) p_accepts = round(n_targets / n_proposals, 4) - print(f"{dataset_name} {example_id} {n_proposals} {p_accepts}") + print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") def generate_features(self): - # check that every example has an image that was loaded! - pass + self.set_validation_idxs() + for i in range(self.n_examples()): + # Get proposals + proposals_dict = { + "proposals": self.pred_graphs[i].list_proposals(), + "graph": self.pred_graphs[i].copy_graph() + } + + # Generate features + sample_id = self.idx_to_ids[i]["sample_id"] + features = feature_generation.run( + self.pred_graphs[i], + self.imgs[sample_id], + self.model_type, + proposals_dict, + self.graph_config.search_radius, + ) - def evaluate(self): + # Initialize train and validation datasets + dataset = ml_util.init_dataset( + self.pred_graphs[i], + features, + self.model_type, + computation_graph=proposals_dict["graph"] + ) + if i in self.validation_ids: + self.validation_dataset.append(dataset) + else: + self.train_dataset.append(dataset) + + def train_model(self): pass + + def save_model(self, model): + name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d') + extension = ".pth" if "Net" in self.model_type else ".joblib" + path = os.path.join(self.output_dir, name + extension) + util.save_model(path, model) diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 41c10c5..c6dd8ac 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -9,11 +9,11 @@ """ from copy import deepcopy -from skimage.color import label2rgb import fastremap import numpy as np import tensorstore as ts +from skimage.color import label2rgb from deep_neurographs.utils import util diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index 64eb50c..347c039 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -81,6 +81,29 @@ def load_model(path): return joblib.load(path) if ".joblib" in path else torch.load(path) +def save_model(path, model, model_type): + """ + Saves a machine learning model. 
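To make the intended flow of the reworked training pipeline concrete, a hedged sketch of how this patch's Trainer might be driven end to end; the config object, paths, and IDs are placeholders, and train_model is still a stub at this point in the series:

    from deep_neurographs.train_pipeline import Trainer

    # "config" is assumed to be a deep_neurographs.config.Config carrying
    # graph_config and ml_config, as in run_pipeline.py.
    trainer = Trainer(
        config,
        model_type="GraphNeuralNet",
        output_dir="/results/models",
        save_model_bool=True,
    )
    trainer.load_example(
        "/data/gt/block_000.swc",
        "/data/pred/block_000/",
        sample_id="653158",
        example_id="block_000",
        metadata_path="/data/pred/block_000/metadata.json",
    )
    trainer.load_img("gs://my-bucket/653158/whole-brain/fused.zarr/", "653158")
    trainer.run()  # generate_proposals -> generate_features -> train_model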
+ + Parameters + ---------- + path : str + Path that model parameters will be written to. + model : object + Model to be saved. + + Returns + ------- + None + + """ + print("Model saved!") + if "Net" in model_type: + torch.save(model, path) + else: + joblib.dump(model, path) + + # --- dataset utils --- def init_dataset( neurograph, features, model_type, computation_graph=None, sample_ids=None diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 6e165bb..9c4809a 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -8,19 +8,19 @@ """ -from google.cloud import storage +import json +import math +import os +import shutil from io import BytesIO from random import sample from time import time from zipfile import ZipFile import boto3 -import json -import math import numpy as np -import os import psutil -import shutil +from google.cloud import storage # --- dictionary utils --- From 82098ff5e86a8c05f564a5fa657a96293ecca4f5 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 17:16:56 +0000 Subject: [PATCH 06/13] bug: hgraph forward passes with missing edge types --- src/deep_neurographs/config.py | 4 ++++ .../machine_learning/datasets.py | 4 ++-- .../machine_learning/feature_generation.py | 6 ++--- .../machine_learning/heterograph_datasets.py | 24 ++++++++++++++++++- .../heterograph_feature_generation.py | 17 +++++++++++-- .../machine_learning/heterograph_models.py | 20 ++++++++++------ src/deep_neurographs/train_pipeline.py | 6 ++--- src/deep_neurographs/utils/gnn_util.py | 24 +++++++++---------- 8 files changed, 74 insertions(+), 31 deletions(-) diff --git a/src/deep_neurographs/config.py b/src/deep_neurographs/config.py index fbb403c..3421473 100644 --- a/src/deep_neurographs/config.py +++ b/src/deep_neurographs/config.py @@ -97,8 +97,12 @@ class MLConfig: batch_size: int = 2000 downsample_factor: int = 1 high_threshold: float = 0.9 + lr: float = 1e-4 threshold: float = 0.6 model_type: str = "GraphNeuralNet" + n_epochs: int = 1000 + validation_split: float = 0.15 + weight_decay: float = 1e-3 class Config: diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index eed5bf2..9e05a1a 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -295,8 +295,8 @@ def reformat(arr): def init_idxs(idxs): """ - Adds dictionary item called "edge_to_index" which maps an edge in a - neurograph to an that represents the edge's position in the feature + Adds dictionary item called "edge_to_index" which maps a branch/proposal + in a neurograph to an idx that represents it's position in the feature matrix. Parameters diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index faec027..6ed110e 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -515,15 +515,13 @@ def stack_chunks(neurograph, features, shift=0): # -- util -- -def count_features(model_type): +def count_features(): """ Counts number of features based on the "model_type". Parameters ---------- - model_type : str - Indication of model to be trained. Options include: AdaBoost, - RandomForest, FeedForwardNet, MultiModalNet. 
+ None Returns ------- diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 5cb2569..963e2b8 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -137,8 +137,10 @@ def __init__( # Edges self.init_edges() + self.check_missing_edge_type() self.init_edge_attrs(x_nodes) self.n_edge_attrs = n_edge_features(x_nodes) + def init_edges(self): """ @@ -192,6 +194,23 @@ def init_edge_attrs(self, x_nodes): x_nodes, edge_type, self.idxs_branches, self.idxs_proposals ) + def check_missing_edge_type(self): + edge_type = ("branch", "edge", "branch") + if len(self.data[edge_type].edge_index) == 0: + # Add dummy features + dtype = self.data["branch"].x.dtype + zeros = torch.zeros(2, self.n_branch_features(), dtype=dtype) + self.data["branch"].x = torch.cat( + (self.data["branch"].x, zeros), dim=0 + ) + + # Update edge_index + n = self.data["branch"]["x"].size(0) + edge_index = [[n - 1, n - 2], [n - 2, n - 1]] + self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index) + self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2}) + self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3}) + # -- Getters -- def n_branch_features(self): """ @@ -342,7 +361,10 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map): for i in range(self.data[edge_type].edge_index.size(1)): e1, e2 = self.data[edge_type].edge_index[:, i] v = node_intersection(idx_map, e1, e2) - attrs.append(x_nodes[v]) + if v < 0: + attrs.append(torch.zeros(self.n_branch_features() + 1)) + else: + attrs.append(x_nodes[v]) arrs = torch.tensor(np.array(attrs), dtype=DTYPE) self.data[edge_type].edge_attr = arrs diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py index 1ab1f57..557a9ef 100644 --- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py +++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py @@ -16,10 +16,10 @@ from deep_neurographs.machine_learning import feature_generation as feats from deep_neurographs.utils import img_util -WINDOW = [5, 5, 5] + N_PROFILE_PTS = 16 NODE_PROFILE_DEPTH = 16 - +WINDOW = [5, 5, 5] def generate_hgnn_features( neurograph, img, proposals_dict, radius, downsample_factor @@ -458,3 +458,16 @@ def check_degenerate(voxels): [voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)] ) return voxels + + +def n_node_features(): + return {'branch': 2, 'proposal': 34} + + +def n_edge_features(): + n_edge_features_dict = { + ('proposal', 'edge', 'proposal'): 3, + ('branch', 'edge', 'branch'): 3, + ('branch', 'edge', 'proposal'): 3 + } + return n_edge_features_dict diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 38a2df7..a27cd52 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -8,6 +8,7 @@ """ +import numpy as np import torch import torch.nn.init as init from torch import nn @@ -15,6 +16,8 @@ from torch_geometric.nn import GATv2Conv as GATConv from torch_geometric.nn import HEATConv, HeteroConv, Linear +from deep_neurographs.machine_learning import heterograph_feature_generation + CONV_TYPES = ["GATConv", "GCNConv"] DROPOUT = 0.3 HEADS_1 = 1 @@ -29,9 +32,7 @@ class HeteroGNN(torch.nn.Module): def 
__init__( self, - node_dict, - edge_dict, - hidden_dim, + scale_hidden_dim=2, dropout=DROPOUT, heads_1=HEADS_1, heads_2=HEADS_2, @@ -41,6 +42,11 @@ def __init__( """ super().__init__() + # Feature vector sizes + node_dict = heterograph_feature_generation.n_node_features() + edge_dict = heterograph_feature_generation.n_edge_features() + hidden_dim = scale_hidden_dim * np.max(list(node_dict.values())) + # Linear layers output_dim = heads_1 * heads_2 * hidden_dim self.input_nodes = nn.ModuleDict( @@ -161,9 +167,8 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): x_dict = self.activation(x_dict) # Input - Edges - edge_attr_dict = { - key: f(edge_attr_dict[key]) for key, f in self.input_edges.items() - } + for key, f in self.input_edges.items(): + edge_attr_dict[key] = f(edge_attr_dict[key]) edge_attr_dict = self.activation(edge_attr_dict) # Convolutional layers @@ -218,7 +223,8 @@ def __init__( metadata=metadata, ) """ - x in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. + x in_channels (int) – Size of each input sample, or -1 to + derive the size from the first input(s) to the forward method. x out_channels (int) – Size of each output sample. x num_node_types (int) – The number of node types. x num_edge_types (int) – The number of edge types. diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index a3efb63..7fad91a 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -30,11 +30,11 @@ class Trainer: def __init__( self, config, + model, model_type, criterion=None, output_dir=None, validation_ids=None, - validation_split=0.15, save_model_bool=True, ): # Check for parameter errors @@ -43,11 +43,11 @@ def __init__( # Set class attributes self.idx_to_ids = list() + self.model = model self.model_type = model_type self.output_dir = output_dir self.save_model_bool = save_model_bool self.validation_ids = validation_ids - self.validation_split = validation_split # Set data structures for training examples self.gt_graphs = list() @@ -80,7 +80,7 @@ def n_validation_samples(self): def set_validation_idxs(self): if self.validation_ids is None: - k = int(self.validation_split * self.n_examples()) + k = int(self.ml_config.validation_split * self.n_examples()) self.validation_idxs = sample(np.arange(self.n_examples), k) else: self.validation_idxs = list() diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index 2d8b5d9..2019f63 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -18,6 +18,18 @@ from deep_neurographs.utils import util +def get_inputs(data, model_type): + if "Hetero" in model_type: + x = data.x_dict + edge_index = data.edge_index_dict + edge_attr_dict = data.edge_attr_dict + return x, edge_index, edge_attr_dict + else: + x = data.x + edge_index = data.edge_index + return x, edge_index + + def toCPU(tensor): """ Moves tensor from GPU to CPU. @@ -35,18 +47,6 @@ def toCPU(tensor): return tensor.detach().cpu().tolist() -def get_inputs(data, model_type): - if "Hetero" in model_type: - x = data.x_dict - edge_index = data.edge_index_dict - edge_attr_dict = data.edge_attr_dict - return x, edge_index, edge_attr_dict - else: - x = data.x - edge_index = data.edge_index - return x, edge_index - - def to_tensor(my_list): """ Converts a list to a tensor with contiguous memory. 
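A short hypothetical sketch of how the reworked HeteroGNN and gnn_util.get_inputs fit together; "dataset" stands for a heterograph dataset produced by ml_util.init_dataset elsewhere in this series:

    from deep_neurographs.machine_learning.heterograph_models import HeteroGNN
    from deep_neurographs.utils import gnn_util

    # Node/edge feature sizes are now read from heterograph_feature_generation,
    # so only the hidden-dimension scale factor is supplied here.
    model = HeteroGNN(scale_hidden_dim=2)

    x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs(
        dataset.data, "HeteroGNN"
    )
    hat_y = model(x_dict, edge_index_dict, edge_attr_dict)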
From 10c79f225ead2b9709243b8375a934d0c1388fef Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 22:03:17 +0000 Subject: [PATCH 07/13] refactor: hgnn trainer --- .../machine_learning/heterograph_datasets.py | 1 - .../heterograph_feature_generation.py | 1 + .../machine_learning/heterograph_trainer.py | 149 +++--------------- 3 files changed, 23 insertions(+), 128 deletions(-) diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 963e2b8..c585d48 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -140,7 +140,6 @@ def __init__( self.check_missing_edge_type() self.init_edge_attrs(x_nodes) self.n_edge_attrs = n_edge_features(x_nodes) - def init_edges(self): """ diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py index 557a9ef..32fc446 100644 --- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py +++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py @@ -21,6 +21,7 @@ NODE_PROFILE_DEPTH = 16 WINDOW = [5, 5, 5] + def generate_hgnn_features( neurograph, img, proposals_dict, radius, downsample_factor ): diff --git a/src/deep_neurographs/machine_learning/heterograph_trainer.py b/src/deep_neurographs/machine_learning/heterograph_trainer.py index f0a2230..8dc5866 100644 --- a/src/deep_neurographs/machine_learning/heterograph_trainer.py +++ b/src/deep_neurographs/machine_learning/heterograph_trainer.py @@ -10,7 +10,7 @@ """ from copy import deepcopy -from random import sample, shuffle +from random import shuffle import numpy as np import torch @@ -22,25 +22,16 @@ ) from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from torch_geometric.utils import subgraph from deep_neurographs.utils import gnn_util, ml_util from deep_neurographs.utils.gnn_util import toCPU -# Training -FEATURE_DTYPE = torch.float32 -MODEL_TYPE = "HeteroGNN" LR = 1e-3 N_EPOCHS = 200 SCHEDULER_GAMMA = 0.5 SCHEDULER_STEP_SIZE = 1000 -TEST_PERCENT = 0.15 WEIGHT_DECAY = 1e-3 -# Augmentation -MAX_PROPOSAL_DROPOUT = 0.1 -SCALING_FACTOR = 0.05 - class HeteroGraphTrainer: """ @@ -54,8 +45,6 @@ def __init__( criterion, lr=LR, n_epochs=N_EPOCHS, - max_proposal_dropout=MAX_PROPOSAL_DROPOUT, - scaling_factor=SCALING_FACTOR, weight_decay=WEIGHT_DECAY, ): """ @@ -90,10 +79,6 @@ def __init__( self.init_scheduler() self.writer = SummaryWriter() - # Augmentation - self.scaling_factor = scaling_factor - self.max_proposal_dropout = max_proposal_dropout - def init_scheduler(self): self.scheduler = StepLR( self.optimizer, @@ -101,16 +86,14 @@ def init_scheduler(self): gamma=SCHEDULER_GAMMA, ) - def run_on_graphs(self, datasets, augment=False): + def run(self, train_dataset_list, validation_dataset_list): """ Trains a graph neural network in the case where "datasets" is a dictionary of datasets such that each corresponds to a distinct graph. Parameters ---------- - datasets : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. + ... Returns ------- @@ -118,32 +101,36 @@ def run_on_graphs(self, datasets, augment=False): Graph neural network that has been fit onto "datasets". 
""" - # Initializations best_score = -np.inf best_ckpt = None - - # Main - train_ids, test_ids = train_test_split(list(datasets.keys())) for epoch in range(self.n_epochs): # Train y, hat_y = [], [] self.model.train() - for graph_id in train_ids: - print(graph_id) - y_i, hat_y_i = self.train( - datasets[graph_id], epoch, augment=augment - ) + for graph_dataset in train_dataset_list: + # Forward pass + hat_y_i, y_i = self.predict(graph_dataset.data) + loss = self.criterion(hat_y_i, y_i) + self.writer.add_scalar("loss", loss, epoch) + + # Backward pass + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Store predictions y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) + self.compute_metrics(y, hat_y, "train", epoch) self.scheduler.step() - # Test + # Validate if epoch % 10 == 0: y, hat_y = [], [] self.model.eval() - for graph_id in test_ids: - y_i, hat_y_i = self.forward(datasets[graph_id].data) + for graph_dataset in validation_dataset_list: + hat_y_i, y_i = self.predict(graph_dataset.data) y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) test_score = self.compute_metrics(y, hat_y, "val", epoch) @@ -155,52 +142,7 @@ def run_on_graphs(self, datasets, augment=False): self.model.load_state_dict(best_ckpt) return self.model - def run_on_graph(self): - """ - Trains a graph neural network in the case where "dataset" is a - graph that may contain multiple connected components. - - Parameters - ---------- - dataset : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - None - - """ - pass - - def train(self, dataset, epoch, augment=False): - """ - Performs the forward pass and backpropagation to update the model's - weights. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - epoch : int - Current epoch. - augment : bool, optional - Indication of whether to augment data. Default is False. - - Returns - ------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - - """ - # if augment: - y, hat_y = self.forward(dataset.data) - self.backpropagate(y, hat_y, epoch) - return y, hat_y - - def forward(self, data): + def predict(self, data): """ Runs "data" through "self.model" to generate a prediction. @@ -219,37 +161,13 @@ def forward(self, data): """ # Run model x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs( - data, MODEL_TYPE + data, "HeteroGNN" ) - self.optimizer.zero_grad() hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict) # Output y = data["proposal"]["y"] - return y, truncate(hat_y, y) - - def backpropagate(self, y, hat_y, epoch): - """ - Runs backpropagation to update the model's weights. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - epoch : int - Current epoch. - - Returns - ------- - None - - """ - loss = self.criterion(hat_y, y) - loss.backward() - self.optimizer.step() - self.writer.add_scalar("loss", loss, epoch) + return truncate(hat_y, y), y def compute_metrics(self, y, hat_y, prefix, epoch): """ @@ -312,29 +230,6 @@ def shuffler(my_list): return my_list -def train_test_split(graph_ids): - """ - Split a list of graph IDs into training and testing sets. - - Parameters - ---------- - graph_ids : list[str] - A list containing unique identifiers (IDs) for graphs. - - Returns - ------- - list - A list containing IDs for the training set. - list - A list containing IDs for the testing set. 
- - """ - n_test_examples = int(len(graph_ids) * TEST_PERCENT) - test_ids = ["block_000", "block_002"] # sample(graph_ids, n_test_examples) - train_ids = list(set(graph_ids) - set(test_ids)) - return train_ids, test_ids - - def truncate(hat_y, y): """ Truncates "hat_y" so that this tensor has the same shape as "y". Note this From abc0ee7570dc693c29a05d03198f2a785b02be38 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 03:11:48 +0000 Subject: [PATCH 08/13] feat: functional training pipeline --- ...{heterograph_trainer.py => gnn_trainer.py} | 10 +- .../machine_learning/graph_trainer.py | 420 ------------------ .../machine_learning/heterograph_datasets.py | 1 - .../heterograph_feature_generation.py | 1 - src/deep_neurographs/train_pipeline.py | 39 +- 5 files changed, 31 insertions(+), 440 deletions(-) rename src/deep_neurographs/machine_learning/{heterograph_trainer.py => gnn_trainer.py} (96%) delete mode 100644 src/deep_neurographs/machine_learning/graph_trainer.py diff --git a/src/deep_neurographs/machine_learning/heterograph_trainer.py b/src/deep_neurographs/machine_learning/gnn_trainer.py similarity index 96% rename from src/deep_neurographs/machine_learning/heterograph_trainer.py rename to src/deep_neurographs/machine_learning/gnn_trainer.py index 8dc5866..bd32650 100644 --- a/src/deep_neurographs/machine_learning/heterograph_trainer.py +++ b/src/deep_neurographs/machine_learning/gnn_trainer.py @@ -33,7 +33,7 @@ WEIGHT_DECAY = 1e-3 -class HeteroGraphTrainer: +class Trainer: """ Custom class that trains graph neural networks. @@ -107,9 +107,9 @@ def run(self, train_dataset_list, validation_dataset_list): # Train y, hat_y = [], [] self.model.train() - for graph_dataset in train_dataset_list: + for dataset in train_dataset_list: # Forward pass - hat_y_i, y_i = self.predict(graph_dataset.data) + hat_y_i, y_i = self.predict(dataset.data) loss = self.criterion(hat_y_i, y_i) self.writer.add_scalar("loss", loss, epoch) @@ -129,8 +129,8 @@ def run(self, train_dataset_list, validation_dataset_list): if epoch % 10 == 0: y, hat_y = [], [] self.model.eval() - for graph_dataset in validation_dataset_list: - hat_y_i, y_i = self.predict(graph_dataset.data) + for dataset in validation_dataset_list: + hat_y_i, y_i = self.predict(dataset.data) y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) test_score = self.compute_metrics(y, hat_y, "val", epoch) diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py deleted file mode 100644 index cc9b9ef..0000000 --- a/src/deep_neurographs/machine_learning/graph_trainer.py +++ /dev/null @@ -1,420 +0,0 @@ -""" -Created on Sat April 12 11:00:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Routines for training graph neural networks that classify edge proposals. - -""" - -from copy import deepcopy -from random import sample, shuffle - -import numpy as np -import torch -from sklearn.metrics import ( - accuracy_score, - f1_score, - precision_score, - recall_score, -) -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter -from torch_geometric.utils import subgraph - -from deep_neurographs.utils import gnn_util, ml_util - -# Training -LR = 1e-3 -MODEL_TYPE = "GraphNeuralNet" -N_EPOCHS = 200 -SCHEDULER_GAMMA = 0.5 -SCHEDULER_STEP_SIZE = 1000 -TEST_PERCENT = 0.15 -WEIGHT_DECAY = 1e-3 - -# Augmentation -MAX_PROPOSAL_DROPOUT = 0.1 -SCALING_FACTOR = 0.05 - - -class GraphTrainer: - """ - Custom class that trains graph neural networks. 
- - """ - - def __init__( - self, - model, - criterion, - lr=LR, - n_epochs=N_EPOCHS, - max_proposal_dropout=MAX_PROPOSAL_DROPOUT, - scaling_factor=SCALING_FACTOR, - weight_decay=WEIGHT_DECAY, - ): - """ - Constructs a GraphTrainer object. - - Parameters - ---------- - model : torch.nn.Module - Graph neural network. - criterion : torch.nn.Module._Loss - Loss function. - lr : float, optional - Learning rate. The default is the global variable LR. - n_epochs : int - Number of epochs. The default is the global variable N_EPOCHS. - weight_decay : float - Weight decay used in optimizer. The default is the global variable - WEIGHT_DECAY. - - Returns - ------- - None. - - """ - # Training - self.model = model # .to("cuda:0") - self.criterion = criterion - self.n_epochs = n_epochs - self.optimizer = torch.optim.Adam( - model.parameters(), lr=lr, weight_decay=weight_decay - ) - self.init_scheduler() - self.writer = SummaryWriter() - - # Augmentation - self.scaling_factor = scaling_factor - self.max_proposal_dropout = max_proposal_dropout - - def init_scheduler(self): - self.scheduler = StepLR( - self.optimizer, - step_size=SCHEDULER_STEP_SIZE, - gamma=SCHEDULER_GAMMA, - ) - - def run_on_graphs(self, datasets, augment=False): - """ - Trains a graph neural network in the case where "datasets" is a - dictionary of datasets such that each corresponds to a distinct graph. - - Parameters - ---------- - datasets : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - torch.nn.Module - Graph neural network that has been fit onto "datasets". - - """ - # Initializations - best_score = -np.inf - best_ckpt = None - - # Main - train_ids, test_ids = train_test_split(list(datasets.keys())) - for epoch in range(self.n_epochs): - # Train - y, hat_y = [], [] - self.model.train() - for graph_id in train_ids: - print(graph_id) - y_i, hat_y_i = self.train( - datasets[graph_id], epoch, augment=augment - ) - y.extend(gnn_util.toCPU(y_i)) - hat_y.extend(gnn_util.toCPU(hat_y_i)) - self.compute_metrics(y, hat_y, "train", epoch) - self.scheduler.step() - - # Test - if epoch % 10 == 0: - y, hat_y = [], [] - self.model.eval() - for graph_id in test_ids: - y_i, hat_y_i = self.forward(datasets[graph_id].data) - y.extend(gnn_util.toCPU(y_i)) - hat_y.extend(gnn_util.toCPU(hat_y_i)) - test_score = self.compute_metrics(y, hat_y, "val", epoch) - - # Check for best - if test_score > best_score: - best_score = test_score - best_ckpt = deepcopy(self.model.state_dict()) - self.model.load_state_dict(best_ckpt) - return self.model - - def run_on_graph(self): - """ - Trains a graph neural network in the case where "dataset" is a - graph that may contain multiple connected components. - - Parameters - ---------- - dataset : dict - Dictionary where each key is a graph id and the value is the - corresponding graph dataset. - - Returns - ------- - None - - """ - pass - - def train(self, dataset, epoch, augment=False): - """ - Performs the forward pass and backpropagation to update the model's - weights. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - epoch : int - Current epoch. - augment : bool, optional - Indication of whether to augment data. Default is False. - - Returns - ------- - torch.Tensor - Ground truth. - torch.Tensor - Prediction. 
- - """ - # Data augmentation (if applicable) - if self.augment: - data = self.augment(dataset) - else: - data = deepcopy(dataset.data) - - # Forward - y, hat_y = self.forward(data) - self.backpropagate(y, hat_y, epoch) - return y, hat_y - - def augment(self, dataset): - augmented_data = rescale_data(dataset, self.scaling_factor) - # augmented_data = proposal_dropout(data, self.max_proposal_dropout) - return augmented_data - - def forward(self, data): - """ - Runs "data" through "self.model" to generate a prediction. - - Parameters - ---------- - data : GraphDataset - Graph dataset that corresponds to a single connected component. - - Returns - ------- - torch.Tensor - Ground truth. - torch.Tensor - Prediction. - - """ - self.optimizer.zero_grad() - x, edge_index = gnn_util.get_inputs(data, MODEL_TYPE) - hat_y = self.model(x, edge_index) - y = data.y # .to("cuda:0", dtype=torch.float32) - return y, truncate(hat_y, y) - - def backpropagate(self, y, hat_y, epoch): - """ - Runs backpropagation to update the model's weights. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - epoch : int - Current epoch. - - Returns - ------- - None - - """ - loss = self.criterion(hat_y, y) - loss.backward() - self.optimizer.step() - self.writer.add_scalar("loss", loss, epoch) - - def compute_metrics(self, y, hat_y, prefix, epoch): - """ - Computes and logs evaluation metrics for binary classification. - - Parameters - ---------- - y : torch.Tensor - Ground truth. - hat_y : torch.Tensor - Prediction. - prefix : str - Prefix to be added to the metric names when logging. - epoch : int - Current epoch. - - Returns - ------- - float - F1 score. - - """ - # Initializations - y = np.array(y, dtype=int).tolist() - hat_y = get_predictions(hat_y) - - # Compute - accuracy = accuracy_score(y, hat_y) - accuracy_dif = accuracy - np.sum(y) / len(y) - precision = precision_score(y, hat_y) - recall = recall_score(y, hat_y) - f1 = f1_score(y, hat_y) - - # Log - self.writer.add_scalar(prefix + "_accuracy:", accuracy, epoch) - self.writer.add_scalar(prefix + "_accuracy_df:", accuracy_dif, epoch) - self.writer.add_scalar(prefix + "_precision:", precision, epoch) - self.writer.add_scalar(prefix + "_recall:", recall, epoch) - self.writer.add_scalar(prefix + "_f1:", f1, epoch) - return f1 - - -# -- util -- -def shuffler(my_list): - """ - Shuffles a list of items. - - Parameters - ---------- - my_list : list - List to be shuffled. - - Returns - ------- - list - Shuffled list. - - """ - shuffle(my_list) - return my_list - - -def train_test_split(graph_ids): - """ - Split a list of graph IDs into training and testing sets. - - Parameters - ---------- - graph_ids : list[str] - A list containing unique identifiers (IDs) for graphs. - - Returns - ------- - train_ids : list - A list containing IDs for the training set. - test_ids : list - A list containing IDs for the testing set. - - """ - n_test_examples = int(len(graph_ids) * TEST_PERCENT) - test_ids = ["block_000", "block_002"] # sample(graph_ids, n_test_examples) - train_ids = list(set(graph_ids) - set(test_ids)) - return train_ids, test_ids - - -def truncate(hat_y, y): - """ - Truncates "hat_y" so that this tensor has the same shape as "y". Note this - operation removes the predictions corresponding to branches so that loss - is computed over proposals. - - Parameters - ---------- - hat_y : torch.Tensor - Tensor to be truncated. - y : torch.Tensor - Tensor used as a reference. - - Returns - ------- - torch.Tensor - Truncated "hat_y". 
- - """ - return hat_y[: y.size(0), 0] - - -def get_predictions(hat_y, threshold=0.5): - """ - Generate binary predictions based on the input probabilities. - - Parameters - ---------- - hat_y : torch.Tensor - Predicted probabilities generated by "self.model". - threshold : float, optional - The threshold value for binary classification. The default is 0.5. - - Returns - ------- - list[int] - Binary predictions based on the given threshold. - - """ - return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() - - -def connected_components(data): - cc_list = [] - cc_idxs = torch.unique(data.edge_index[0], return_inverse=True)[1] - for i in range(cc_idxs.max().item() + 1): - cc_list.append(torch.nonzero(cc_idxs == i, as_tuple=False).view(-1)) - return cc_list - - -def rescale_data(dataset, scaling_factor): - # Get scaling factor - low = 1.0 - scaling_factor - high = 1.0 + scaling_factor - scaling_factor = torch.tensor(np.random.uniform(low=low, high=high)) - - # Rescale - n = count_proposals(dataset) - data = deepcopy(dataset.data) - data.x[0:n, 1] = scaling_factor * data.x[0:n, 1] - return data - - -def proposal_dropout(data, max_proposal_dropout): - n_dropout_edges = len(data.dropout_edges) // 2 - dropout_prob = np.random.uniform(low=0, high=max_proposal_dropout) - n_remove = int(dropout_prob * n_dropout_edges) - remove_edges = sample(data.dropout_edges, n_remove) - for edge in remove_edges: - reversed_edge = [edge[1], edge[0]] - edges_to_remove = torch.tensor([edge, reversed_edge], dtype=torch.long) - edges_mask = torch.all( - data.data.edge_index.T == edges_to_remove[:, None], dim=2 - ).any(dim=0) - data.data.edge_index = data.data.edge_index[:, ~edges_mask] - return data - - -def count_proposals(dataset): - return dataset.data.y.size(0) diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 963e2b8..c585d48 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -140,7 +140,6 @@ def __init__( self.check_missing_edge_type() self.init_edge_attrs(x_nodes) self.n_edge_attrs = n_edge_features(x_nodes) - def init_edges(self): """ diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py index 32fc446..e79abc3 100644 --- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py +++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py @@ -16,7 +16,6 @@ from deep_neurographs.machine_learning import feature_generation as feats from deep_neurographs.utils import img_util - N_PROFILE_PTS = 16 NODE_PROFILE_DEPTH = 16 WINDOW = [5, 5, 5] diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index 7fad91a..f400840 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -17,11 +17,12 @@ from torch.nn import BCEWithLogitsLoss from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning.gnn_trainer import Trainer from deep_neurographs.utils import img_util, ml_util, util from deep_neurographs.utils.graph_util import GraphLoader -class Trainer: +class TrainingPipeline: """ Class that is used to train a machine learning model that classifies proposals. 
@@ -53,8 +54,8 @@ def __init__( self.gt_graphs = list() self.pred_graphs = list() self.imgs = dict() - self.train_dataset = list() - self.validation_dataset = list() + self.train_dataset_list = list() + self.validation_dataset_list = list() # Train parameters self.criterion = criterion if criterion else BCEWithLogitsLoss() @@ -73,10 +74,10 @@ def n_examples(self): return len(self.gt_graphs) def n_train_examples(self): - return len(self.train_dataset) + return len(self.train_dataset_list) def n_validation_samples(self): - return len(self.validation_dataset) + return len(self.validation_dataset_list) def set_validation_idxs(self): if self.validation_ids is None: @@ -131,9 +132,24 @@ def load_img(self, path, sample_id): # --- main pipeline --- def run(self): + # Initialize training data self.generate_proposals() self.generate_features() - self.train_model() + + # Train model + trainer = Trainer( + self.model, + self.criterion, + lr=self.ml_config.lr, + n_epochs=self.ml_config.n_epochs, + ) + self.model = trainer.run( + self.train_dataset_list, self.validation_dataset_list + ) + + # Save model (if applicable) + if self.save_model_bool: + self.save_model() def generate_proposals(self): print("sample_id - example_id - # proposals - % accepted") @@ -184,15 +200,12 @@ def generate_features(self): computation_graph=proposals_dict["graph"] ) if i in self.validation_ids: - self.validation_dataset.append(dataset) + self.validation_dataset_list.append(dataset) else: - self.train_dataset.append(dataset) - - def train_model(self): - pass + self.train_dataset_list.append(dataset) - def save_model(self, model): + def save_model(self): name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d') extension = ".pth" if "Net" in self.model_type else ".joblib" path = os.path.join(self.output_dir, name + extension) - util.save_model(path, model) + ml_util.save_model(path, self.model, self.model_type) From d647268085bfe3321528f82a0dceb14fad9c1b8d Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 04:44:45 +0000 Subject: [PATCH 09/13] bug: set validation data --- .../graph_artifact_removal.py | 6 +- .../machine_learning/gnn_trainer.py | 17 +- .../groundtruth_generation.py | 10 +- src/deep_neurographs/neurograph.py | 17 +- src/deep_neurographs/train_pipeline.py | 5 +- src/deep_neurographs/utils/graph_util.py | 7 +- src/deep_neurographs/utils/util.py | 209 ++++-------------- 7 files changed, 75 insertions(+), 196 deletions(-) diff --git a/src/deep_neurographs/graph_artifact_removal.py b/src/deep_neurographs/graph_artifact_removal.py index 421adba..627b65d 100644 --- a/src/deep_neurographs/graph_artifact_removal.py +++ b/src/deep_neurographs/graph_artifact_removal.py @@ -8,6 +8,8 @@ other from a NeuroGraph. """ +from collections import defaultdict + import numpy as np from networkx import connected_components from tqdm import tqdm @@ -93,7 +95,7 @@ def compute_projections(neurograph, kdtree, edge): projection distances. 
""" - hits = dict() + hits = defaultdict(list) query_id = neurograph.edges[edge]["swc_id"] for i, xyz in enumerate(neurograph.edges[edge]["xyz"]): # Compute projections @@ -108,7 +110,7 @@ def compute_projections(neurograph, kdtree, edge): # Store best if best_id: - hits = util.append_dict_value(hits, best_id, best_dist) + hits[best_id].append(best_dist) elif i == 15 and len(hits) == 0: return hits return hits diff --git a/src/deep_neurographs/machine_learning/gnn_trainer.py b/src/deep_neurographs/machine_learning/gnn_trainer.py index bd32650..bcd849b 100644 --- a/src/deep_neurographs/machine_learning/gnn_trainer.py +++ b/src/deep_neurographs/machine_learning/gnn_trainer.py @@ -118,7 +118,7 @@ def run(self, train_dataset_list, validation_dataset_list): loss.backward() self.optimizer.step() - # Store predictions + # Store prediction y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) @@ -159,13 +159,8 @@ def predict(self, data): Prediction. """ - # Run model - x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs( - data, "HeteroGNN" - ) - hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict) - - # Output + x, edge_index, edge_attr = gnn_util.get_inputs(data, "HeteroGNN") + hat_y = self.model(x, edge_index, edge_attr) y = data["proposal"]["y"] return truncate(hat_y, y), y @@ -197,9 +192,9 @@ def compute_metrics(self, y, hat_y, prefix, epoch): # Compute accuracy = accuracy_score(y, hat_y) accuracy_dif = accuracy - np.sum(y) / len(y) - precision = precision_score(y, hat_y) - recall = recall_score(y, hat_y) - f1 = f1_score(y, hat_y) + precision = precision_score(y, hat_y, zero_division=1.0) + recall = recall_score(y, hat_y, zero_division=1.0) + f1 = f1_score(y, hat_y, zero_division=1.0) # Log self.writer.add_scalar(prefix + "_accuracy:", accuracy, epoch) diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index c272143..0095630 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -9,6 +9,8 @@ """ +from collections import defaultdict + import networkx as nx import numpy as np @@ -145,13 +147,13 @@ def is_component_aligned(target_graph, pred_graph, component, kdtree): """ # Compute distances - dists = dict() + dists = defaultdict(list) for edge in pred_graph.subgraph(component).edges: for xyz in pred_graph.edges[edge]["xyz"]: hat_xyz = geometry.kdtree_query(kdtree, xyz) hat_swc_id = target_graph.xyz_to_swc(hat_xyz) d = get_dist(hat_xyz, xyz) - dists = util.append_dict_value(dists, hat_swc_id, d) + dists[hat_swc_id].append(d) # Deterine whether aligned hat_swc_id = util.find_best(dists) @@ -212,14 +214,14 @@ def is_valid(target_graph, pred_graph, kdtree, target_id, edge): def proj_branch(target_graph, pred_graph, kdtree, target_id, i): # Compute projections - hits = dict() + hits = defaultdict(list) for branch in pred_graph.get_branches(i): for xyz in branch: hat_xyz = geometry.kdtree_query(kdtree, xyz) swc_id = target_graph.xyz_to_swc(hat_xyz) if swc_id == target_id: hat_edge = target_graph.xyz_to_edge[hat_xyz] - hits = util.append_dict_value(hits, hat_edge, hat_xyz) + hits[hat_edge].append(hat_xyz) # Determine closest edge min_dist = np.inf diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d9cad37..e131fe6 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -892,12 +892,6 @@ def leaf_neighbor(self, i): assert 
self.is_leaf(i) return list(self.neighbors(i))[0] - """ - def get_edge_attr(self, edge, key): - xyz_arr = gutil.get_edge_attr(self, edge, key) - return xyz_arr[0], xyz_arr[-1] - """ - def to_patch_coords(self, edge, midpoint, chunk_size): patch_coords = list() for xyz in self.edges[edge]["xyz"]: @@ -917,6 +911,7 @@ def xyz_to_swc(self, xyz, return_node=False): else: return None + """ def component_cardinality(self, root): cardinality = 0 queue = [(-1, root)] @@ -933,6 +928,7 @@ def component_cardinality(self, root): if frozenset((j, k)) not in visited: queue.append((j, k)) return cardinality + """ # --- write graph to swcs --- def to_zipped_swcs(self, zip_path, color=None): @@ -956,8 +952,7 @@ def to_zipped_swc(self, zip_writer, nodes, color): swc_id = self.nodes[i]["swc_id"] x, y, z = tuple(self.nodes[i]["xyz"]) r = self.nodes[i]["radius"] - if color != "1.0 0.0 0.0": - r += 1.5 + text_buffer.write("\n" + f"1 2 {x} {y} {z} {r} -1") node_to_idx[i] = 1 n_entries += 1 @@ -1056,11 +1051,11 @@ def branch_to_zip(self, text_buffer, n_entries, i, j, parent, color): branch_radius = np.flip(branch_radius, axis=0) # Make entries - for k in range(1, len(branch_xyz)): + idxs = np.arange(1, len(branch_xyz)) + for k in util.spaced_idxs(idxs, 4): x, y, z = tuple(branch_xyz[k]) r = branch_radius[k] - if color != "1.0 0.0 0.0": - r += 1 + node_id = n_entries + 1 parent = n_entries if k > 1 else parent text_buffer.write("\n" + f"{node_id} 2 {x} {y} {z} {r} {parent}") diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index f400840..c25bb77 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -135,6 +135,8 @@ def run(self): # Initialize training data self.generate_proposals() self.generate_features() + self.set_validation_idxs() + assert len(self.validation_dataset_list) > 0, "No validation data!" 
# Train model trainer = Trainer( @@ -174,7 +176,6 @@ def generate_proposals(self): print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") def generate_features(self): - self.set_validation_idxs() for i in range(self.n_examples()): # Get proposals proposals_dict = { @@ -199,7 +200,7 @@ def generate_features(self): self.model_type, computation_graph=proposals_dict["graph"] ) - if i in self.validation_ids: + if i in self.validation_idxs: self.validation_dataset_list.append(dataset) else: self.train_dataset_list.append(dataset) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index b01d9ce..fa2b8ce 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -22,6 +22,7 @@ """ +from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, as_completed from random import sample @@ -372,7 +373,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): # Extract edges edges = dict() - nbs = dict() + nbs = defaultdict(list) root = None for (i, j) in nx.dfs_edges(graph, source=source): # Check if start of path is valid @@ -390,8 +391,8 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): ) else: edges[(root, j)] = attrs - nbs = util.append_dict_value(nbs, root, j) - nbs = util.append_dict_value(nbs, j, root) + nbs[root].append(j) # = util.append_dict_value(nbs, root, j) + nbs[j].append(root) # = util.append_dict_value(nbs, j, root) root = None # Output diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 9c4809a..e8c9a68 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -14,7 +14,6 @@ import shutil from io import BytesIO from random import sample -from time import time from zipfile import ZipFile import boto3 @@ -23,158 +22,6 @@ from google.cloud import storage -# --- dictionary utils --- -def remove_item(my_set, item): - """ - Removes item from a set. - - Parameters - ---------- - my_set : set - Set to be queried. - item : - Value to query. - - Returns - ------- - set - Set "my_set" with "item" removed if it existed. - - """ - if item in my_set: - my_set.remove(item) - return my_set - - -def check_key(my_dict, key): - """ - Checks whether "key" is contained in "my_dict". If so, returns the - corresponding value. - - Parameters - ---------- - my_dict : dict - Dictionary to be checked - key : hashable data type - - Returns - ------- - dict value or bool - If "key" is a key in "my_dict", then the associated value is returned. - Otherwise, the bool "False" is returned. - - """ - if key in my_dict.keys(): - return my_dict[key] - else: - return False - - -def remove_key(my_dict, key): - """ - Removes "key" from "my_dict" in the case when key may need to be reversed. - - Parameters - ---------- - my_dict : dict - Dictionary to be queried - key : hashable data type - Key to query. - - Returns - ------- - dict - Updated dictionary. - - """ - if check_key(my_dict, key): - my_dict.pop(key) - elif check_key(my_dict, (key[1], key[0])): - my_dict.pop((key[1], key[0])) - return my_dict - - -def remove_items(my_dict, keys): - """ - Removes dictionary items corresponding to "keys". - - Parameters - ---------- - my_dict : dict - Dictionary to be edited. - keys : list - List of keys to be deleted from "my_dict". - - Returns - ------- - dict - Updated dictionary. 
- - """ - for key in keys: - if key in my_dict.keys(): - del my_dict[key] - return my_dict - - -def append_dict_value(my_dict, key, value): - """ - Appends "value" to the list stored at "key". - - Parameters - ---------- - my_dict : dict - Dictionary to be queried. - key : hashable data type - Key to be query. - value : list item type - Value to append to list stored at "key". - - Returns - ------- - dict - Updated dictionary. - - """ - if key in my_dict.keys(): - my_dict[key].append(value) - else: - my_dict[key] = [value] - return my_dict - - -def find_best(my_dict, maximize=True): - """ - Given a dictionary where each value is either a list or int (i.e. cnt), - finds the key associated with the longest list or largest integer. - - Parameters - ---------- - my_dict : dict - Dictionary to be searched. - maximize : bool, optional - Indication of whether to find the largest value or highest vote cnt. - - Returns - ------- - hashable data type - Key associated with the longest list or largest integer in "my_dict". - - """ - best_key = None - best_vote_cnt = 0 if maximize else np.inf - for key in my_dict.keys(): - val_type = type(my_dict[key]) - vote_cnt = my_dict[key] if val_type == float else len(my_dict[key]) - if vote_cnt > best_vote_cnt and maximize: - best_key = key - best_vote_cnt = vote_cnt - elif vote_cnt < best_vote_cnt and not maximize: - best_key = key - best_vote_cnt = vote_cnt - return best_key - - # --- os utils --- def mkdir(path, delete=False): """ @@ -633,26 +480,63 @@ def sample_once(my_container): return sample(my_container, 1)[0] -# --- runtime --- -def init_timers(): +# --- dictionary utils --- +def remove_items(my_dict, keys): """ - Initializes two timers. + Removes dictionary items corresponding to "keys". Parameters ---------- - None + my_dict : dict + Dictionary to be edited. + keys : list + List of keys to be deleted from "my_dict". Returns ------- - time.time - Timer. - time.time - Timer. + dict + Updated dictionary. + + """ + for key in keys: + if key in my_dict: + del my_dict[key] + return my_dict + +def find_best(my_dict, maximize=True): """ - return time(), time() + Given a dictionary where each value is either a list or int (i.e. cnt), + finds the key associated with the longest list or largest integer. + Parameters + ---------- + my_dict : dict + Dictionary to be searched. + maximize : bool, optional + Indication of whether to find the largest value or highest vote cnt. + Returns + ------- + hashable data type + Key associated with the longest list or largest integer in "my_dict". + + """ + best_key = None + best_vote_cnt = 0 if maximize else np.inf + for key in my_dict.keys(): + val_type = type(my_dict[key]) + vote_cnt = my_dict[key] if val_type == float else len(my_dict[key]) + if vote_cnt > best_vote_cnt and maximize: + best_key = key + best_vote_cnt = vote_cnt + elif vote_cnt < best_vote_cnt and not maximize: + best_key = key + best_vote_cnt = vote_cnt + return best_key + + +# --- miscellaneous --- def time_writer(t, unit="seconds"): """ Converts a runtime "t" to a larger unit of time if applicable. @@ -683,7 +567,6 @@ def time_writer(t, unit="seconds"): return t, unit -# --- miscellaneous --- def get_swc_id(path): """ Gets segment id of the swc file at "path". 
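
For reference, a minimal driver sketch for the training pipeline as it stands after the two patches above (TrainingPipeline in train_pipeline.py plus the Trainer in machine_learning/gnn_trainer.py). The config object, model, id strings, and file paths below are illustrative assumptions, not values taken from these patches:

    from deep_neurographs.train_pipeline import TrainingPipeline

    config = ...  # assumed: a config object exposing graph_config and ml_config
    model = ...   # assumed: a proposal classifier, e.g. a heterogeneous GNN

    pipeline = TrainingPipeline(
        config,
        model,
        model_type="HeteroGNN",        # assumed model-type string
        output_dir="/results/models",  # hypothetical output directory
        validation_ids=[{"sample_id": "sample-01", "example_id": "block_002"}],
    )

    # Register an example (ground-truth + predicted SWCs) and its image; in
    # practice this is repeated for every training block.
    pipeline.load_example(
        "/data/gt/block_002",    # hypothetical gt_pointer
        "/data/pred/block_002",  # hypothetical pred_pointer
        "sample-01",
        example_id="block_002",
        metadata_path="/data/pred/block_002/metadata.json",  # hypothetical
    )
    pipeline.load_img("/data/sample-01/fused.zarr", "sample-01")

    # Generates proposals and features, trains the model, and saves a checkpoint.
    pipeline.run()

Patch 10 below folds this class and the trainer into machine_learning/train.py as TrainPipeline and TrainEngine, but the call pattern stays the same.
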
From 851aabe87c2b119fbf3785d6928918102f38ea53 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 18:06:14 +0000 Subject: [PATCH 10/13] refactor: combined train engine and pipeline --- ...{run_pipeline.py => inference_pipeline.py} | 2 +- .../groundtruth_generation.py | 2 +- .../machine_learning/heterograph_models.py | 6 +- .../{gnn_trainer.py => train.py} | 209 ++++++++++++++++- .../machine_learning/trainer.py | 157 ------------- src/deep_neurographs/train_pipeline.py | 212 ------------------ 6 files changed, 209 insertions(+), 379 deletions(-) rename src/deep_neurographs/{run_pipeline.py => inference_pipeline.py} (99%) rename src/deep_neurographs/machine_learning/{gnn_trainer.py => train.py} (50%) delete mode 100644 src/deep_neurographs/machine_learning/trainer.py delete mode 100644 src/deep_neurographs/train_pipeline.py diff --git a/src/deep_neurographs/run_pipeline.py b/src/deep_neurographs/inference_pipeline.py similarity index 99% rename from src/deep_neurographs/run_pipeline.py rename to src/deep_neurographs/inference_pipeline.py index 4b4db4b..b7628bc 100644 --- a/src/deep_neurographs/run_pipeline.py +++ b/src/deep_neurographs/inference_pipeline.py @@ -38,7 +38,7 @@ from deep_neurographs.utils.graph_util import GraphLoader -class GraphTracePipeline: +class InferencePipeline: """ Class that executes the full GraphTrace inference pipeline. diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index 0095630..ad12616 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -19,7 +19,7 @@ from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import util -ALIGNED_THRESHOLD = 3.5 +ALIGNED_THRESHOLD = 4 MIN_INTERSECTION = 10 diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index a27cd52..fd9794e 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -16,7 +16,7 @@ from torch_geometric.nn import GATv2Conv as GATConv from torch_geometric.nn import HEATConv, HeteroConv, Linear -from deep_neurographs.machine_learning import heterograph_feature_generation +from deep_neurographs import machine_learning as ml CONV_TYPES = ["GATConv", "GCNConv"] DROPOUT = 0.3 @@ -43,8 +43,8 @@ def __init__( """ super().__init__() # Feature vector sizes - node_dict = heterograph_feature_generation.n_node_features() - edge_dict = heterograph_feature_generation.n_edge_features() + node_dict = ml.heterograph_feature_generation.n_node_features() + edge_dict = ml.heterograph_feature_generation.n_edge_features() hidden_dim = scale_hidden_dim * np.max(list(node_dict.values())) # Linear layers diff --git a/src/deep_neurographs/machine_learning/gnn_trainer.py b/src/deep_neurographs/machine_learning/train.py similarity index 50% rename from src/deep_neurographs/machine_learning/gnn_trainer.py rename to src/deep_neurographs/machine_learning/train.py index 2c17c2a..99ed883 100644 --- a/src/deep_neurographs/machine_learning/gnn_trainer.py +++ b/src/deep_neurographs/machine_learning/train.py @@ -4,13 +4,14 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Routines for training heterogeneous graph neural networks that classify -edge proposals. +Routines for training machine learning models that classify proposals. 
""" +import os from copy import deepcopy -from random import shuffle +from datetime import datetime +from random import sample, shuffle import numpy as np import torch @@ -20,11 +21,14 @@ precision_score, recall_score, ) +from torch.nn import BCEWithLogitsLoss from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.utils import gnn_util, ml_util +from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.utils import gnn_util, img_util, ml_util, util from deep_neurographs.utils.gnn_util import toCPU +from deep_neurographs.utils.graph_util import GraphLoader LR = 1e-3 N_EPOCHS = 200 @@ -33,7 +37,197 @@ WEIGHT_DECAY = 1e-3 -class Trainer: +class TrainPipeline: + """ + Class that is used to train a machine learning model that classifies + proposals. + + """ + def __init__( + self, + config, + model, + model_type, + criterion=None, + output_dir=None, + validation_ids=None, + save_model_bool=True, + ): + # Check for parameter errors + if save_model_bool and not output_dir: + raise ValueError("Must provide output_dir to save model.") + + # Set class attributes + self.idx_to_ids = list() + self.model = model + self.model_type = model_type + self.output_dir = output_dir + self.save_model_bool = save_model_bool + self.validation_ids = validation_ids + + # Set data structures for training examples + self.gt_graphs = list() + self.pred_graphs = list() + self.imgs = dict() + self.train_dataset_list = list() + self.validation_dataset_list = list() + + # Train parameters + self.criterion = criterion if criterion else BCEWithLogitsLoss() + self.validation_ids = validation_ids + + # Extract config settings + self.graph_config = config.graph_config + self.ml_config = config.ml_config + self.graph_loader = GraphLoader( + min_size=self.graph_config.min_size, + progress_bar=False, + ) + + # --- getters/setters --- + def n_examples(self): + return len(self.gt_graphs) + + def n_train_examples(self): + return len(self.train_dataset_list) + + def n_validation_samples(self): + return len(self.validation_dataset_list) + + def set_validation_idxs(self): + if self.validation_ids is None: + k = int(self.ml_config.validation_split * self.n_examples()) + self.validation_idxs = sample(np.arange(self.n_examples), k) + else: + self.validation_idxs = list() + for ids in self.validation_ids: + for i in range(self.n_examples()): + same = all([ids[k] == self.idx_to_ids[i][k] for k in ids]) + if same: + self.validation_idxs.append(i) + assert len(self.validation_idxs) > 0, "No validation data!" 
+ + # --- loaders --- + def load_example( + self, + gt_pointer, + pred_pointer, + sample_id, + example_id=None, + pred_id=None, + metadata_path=None, + ): + # Read metadata + if metadata_path: + origin, shape = util.read_metadata(metadata_path) + else: + origin, shape = None, None + + # Load graphs + self.gt_graphs.append(self.graph_loader.run(gt_pointer)) + self.pred_graphs.append( + self.graph_loader.run( + pred_pointer, + img_patch_origin=origin, + img_patch_shape=shape, + ) + ) + + # Set example ids + self.idx_to_ids.append( + { + "sample_id": sample_id, + "example_id": example_id, + "pred_id": pred_id, + } + ) + + def load_img(self, path, sample_id): + if sample_id not in self.imgs: + self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") + + # --- main pipeline --- + def run(self): + # Initialize training data + self.set_validation_idxs() + self.generate_proposals() + self.generate_features() + + # Train model + train_engine = TrainEngine( + self.model, + self.criterion, + lr=self.ml_config.lr, + n_epochs=self.ml_config.n_epochs, + ) + self.model = train_engine.run( + self.train_dataset_list, self.validation_dataset_list + ) + + # Save model (if applicable) + if self.save_model_bool: + self.save_model() + + def generate_proposals(self): + print("sample_id - example_id - # proposals - % accepted") + for i in range(self.n_examples()): + # Run + self.pred_graphs[i].generate_proposals( + self.graph_config.search_radius, + complex_bool=self.graph_config.complex_bool, + groundtruth_graph=self.gt_graphs[i], + long_range_bool=self.graph_config.long_range_bool, + progress_bar=False, + proposals_per_leaf=self.graph_config.proposals_per_leaf, + trim_endpoints_bool=self.graph_config.trim_endpoints_bool, + ) + + # Report results + sample_id = self.idx_to_ids[i]["sample_id"] + example_id = self.idx_to_ids[i]["example_id"] + n_proposals = self.pred_graphs[i].n_proposals() + n_targets = len(self.pred_graphs[i].target_edges) + p_accepts = round(n_targets / n_proposals, 4) + print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") + + def generate_features(self): + for i in range(self.n_examples()): + # Get proposals + proposals_dict = { + "proposals": self.pred_graphs[i].list_proposals(), + "graph": self.pred_graphs[i].copy_graph() + } + + # Generate features + sample_id = self.idx_to_ids[i]["sample_id"] + features = feature_generation.run( + self.pred_graphs[i], + self.imgs[sample_id], + self.model_type, + proposals_dict, + self.graph_config.search_radius, + ) + + # Initialize train and validation datasets + dataset = ml_util.init_dataset( + self.pred_graphs[i], + features, + self.model_type, + computation_graph=proposals_dict["graph"] + ) + if i in self.validation_idxs: + self.validation_dataset_list.append(dataset) + else: + self.train_dataset_list.append(dataset) + + def save_model(self): + name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d') + extension = ".pth" if "Net" in self.model_type else ".joblib" + path = os.path.join(self.output_dir, name + extension) + ml_util.save_model(path, self.model, self.model_type) + + +class TrainEngine: """ Custom class that trains graph neural networks. 
@@ -205,6 +399,11 @@ def compute_metrics(self, y, hat_y, prefix, epoch): return f1 +def fit_random_forest(model, dataset): + model.fit(dataset.data.x, dataset.data.y) + return model + + # -- util -- def shuffler(my_list): """ diff --git a/src/deep_neurographs/machine_learning/trainer.py b/src/deep_neurographs/machine_learning/trainer.py deleted file mode 100644 index 954f125..0000000 --- a/src/deep_neurographs/machine_learning/trainer.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Created on Sat November 04 15:30:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Routines for training models that classify edge proposals. - -""" - -import logging - -import lightning.pytorch as pl -import torch -import torch.nn as nn -import torch.utils.data as torch_data -from lightning.pytorch.callbacks import ModelCheckpoint -from torch.nn.functional import sigmoid -from torch.utils.data import DataLoader -from torcheval.metrics.functional import ( - binary_f1_score, - binary_precision, - binary_recall, -) - -logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) - -BATCH_SIZE = 32 -SHUFFLE = True -SUPPORTED_MODELS = [ - "AdaBoost", - "RandomForest", - "FeedForwardNet", - "ConvNet", - "MultiModalNet", -] - - -def fit_model(model, dataset): - model.fit(dataset.data.x, dataset.data.y) - return model - - -def fit_deep_model( - model, - dataset, - batch_size=BATCH_SIZE, - criterion=None, - logger=False, - lr=1e-3, - max_epochs=1000, -): - """ - Fits a neural network to a dataset. - - Parameters - ---------- - model : ... - ... - dataset : ... - ... - lr : float, optional - Learning rate to be used if model is a neural network. The default is - 1e-3. - logger : bool, optional - Indication of whether to log performance stats while neural network - trains. The default is False. - max_epochs : int, optional - Maximum number of epochs used to train neural network. The default is - 50. - - Returns - ------- - ... 
- """ - # Load data - train_set, valid_set = random_split(dataset.data) - train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) - valid_loader = DataLoader(valid_set, batch_size=batch_size) - - # Configure trainer - lit_model = LitModel(criterion=criterion, model=model, lr=lr) - ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_f1", mode="max") - - # Fit model - pylightning_trainer = pl.Trainer( - accelerator="gpu", - callbacks=[ckpt_callback], - devices=1, - enable_model_summary=False, - enable_progress_bar=False, - logger=logger, - log_every_n_steps=1, - max_epochs=max_epochs, - ) - pylightning_trainer.fit(lit_model, train_loader, valid_loader) - - # Return best model - ckpt = torch.load(ckpt_callback.best_model_path) - lit_model.model.load_state_dict(ckpt["state_dict"]) - return lit_model.model - - -def random_split(train_set, train_ratio=0.8): - train_set_size = int(len(train_set) * train_ratio) - valid_set_size = len(train_set) - train_set_size - return torch_data.random_split(train_set, [train_set_size, valid_set_size]) - - -# -- Lightning Module -- -class LitModel(pl.LightningModule): - def __init__(self, criterion=None, model=None, lr=1e-3): - super().__init__() - self.model = model - self.lr = lr - if criterion: - self.criterion = criterion - else: - pos_weight = torch.tensor([1.0], device=0) - self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) - - def forward(self, batch): - x = self.get_example(batch, "inputs") - return self.model(x) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer - - def training_step(self, batch, batch_idx): - X = self.get_example(batch, "inputs") - y = self.get_example(batch, "targets") - y_hat = self.model(X) - - loss = self.criterion(y_hat, y) - self.log("train_loss", loss) - self.compute_stats(y_hat, y, prefix="train_") - return loss - - def validation_step(self, batch, batch_idx): - X = self.get_example(batch, "inputs") - y = self.get_example(batch, "targets") - y_hat = self.model(X) - self.compute_stats(y_hat, y, prefix="val_") - - def compute_stats(self, y_hat, y, prefix=""): - y_hat = torch.flatten(sigmoid(y_hat)) - y = torch.flatten(y).to(torch.int) - self.log(prefix + "precision", binary_precision(y_hat, y)) - self.log(prefix + "recall", binary_recall(y_hat, y)) - self.log(prefix + "f1", binary_f1_score(y_hat, y)) - - def get_example(self, batch, key): - return batch[key] - - def state_dict(self, destination=None, prefix="", keep_vars=False): - return self.model.state_dict(destination, prefix + "", keep_vars) diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py deleted file mode 100644 index c25bb77..0000000 --- a/src/deep_neurographs/train_pipeline.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -Created on Sat Sept 16 11:30:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - - -This script trains the GraphTrace inference pipeline. - -""" - -import os -from datetime import datetime -from random import sample - -import numpy as np -from torch.nn import BCEWithLogitsLoss - -from deep_neurographs.machine_learning import feature_generation -from deep_neurographs.machine_learning.gnn_trainer import Trainer -from deep_neurographs.utils import img_util, ml_util, util -from deep_neurographs.utils.graph_util import GraphLoader - - -class TrainingPipeline: - """ - Class that is used to train a machine learning model that classifies - proposals. 
- - """ - def __init__( - self, - config, - model, - model_type, - criterion=None, - output_dir=None, - validation_ids=None, - save_model_bool=True, - ): - # Check for parameter errors - if save_model_bool and not output_dir: - raise ValueError("Must provide output_dir to save model.") - - # Set class attributes - self.idx_to_ids = list() - self.model = model - self.model_type = model_type - self.output_dir = output_dir - self.save_model_bool = save_model_bool - self.validation_ids = validation_ids - - # Set data structures for training examples - self.gt_graphs = list() - self.pred_graphs = list() - self.imgs = dict() - self.train_dataset_list = list() - self.validation_dataset_list = list() - - # Train parameters - self.criterion = criterion if criterion else BCEWithLogitsLoss() - self.validation_ids = validation_ids - - # Extract config settings - self.graph_config = config.graph_config - self.ml_config = config.ml_config - self.graph_loader = GraphLoader( - min_size=self.graph_config.min_size, - progress_bar=False, - ) - - # --- getters/setters --- - def n_examples(self): - return len(self.gt_graphs) - - def n_train_examples(self): - return len(self.train_dataset_list) - - def n_validation_samples(self): - return len(self.validation_dataset_list) - - def set_validation_idxs(self): - if self.validation_ids is None: - k = int(self.ml_config.validation_split * self.n_examples()) - self.validation_idxs = sample(np.arange(self.n_examples), k) - else: - self.validation_idxs = list() - for ids in self.validation_ids: - for i in range(self.n_examples()): - same = all([ids[k] == self.idx_to_ids[i][k] for k in ids]) - if same: - self.validation_idxs.append(i) - - # --- loaders --- - def load_example( - self, - gt_pointer, - pred_pointer, - sample_id, - example_id=None, - pred_id=None, - metadata_path=None, - ): - # Read metadata - if metadata_path: - origin, shape = util.read_metadata(metadata_path) - else: - origin, shape = None, None - - # Load graphs - self.gt_graphs.append(self.graph_loader.run(gt_pointer)) - self.pred_graphs.append( - self.graph_loader.run( - pred_pointer, - img_patch_origin=origin, - img_patch_shape=shape, - ) - ) - - # Set example ids - self.idx_to_ids.append( - { - "sample_id": sample_id, - "example_id": example_id, - "pred_id": pred_id, - } - ) - - def load_img(self, path, sample_id): - if sample_id not in self.imgs: - self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") - - # --- main pipeline --- - def run(self): - # Initialize training data - self.generate_proposals() - self.generate_features() - self.set_validation_idxs() - assert len(self.validation_dataset_list) > 0, "No validation data!" 
- - # Train model - trainer = Trainer( - self.model, - self.criterion, - lr=self.ml_config.lr, - n_epochs=self.ml_config.n_epochs, - ) - self.model = trainer.run( - self.train_dataset_list, self.validation_dataset_list - ) - - # Save model (if applicable) - if self.save_model_bool: - self.save_model() - - def generate_proposals(self): - print("sample_id - example_id - # proposals - % accepted") - for i in range(self.n_examples()): - # Run - self.pred_graphs[i].generate_proposals( - self.graph_config.search_radius, - complex_bool=self.graph_config.complex_bool, - groundtruth_graph=self.gt_graphs[i], - long_range_bool=self.graph_config.long_range_bool, - progress_bar=False, - proposals_per_leaf=self.graph_config.proposals_per_leaf, - trim_endpoints_bool=self.graph_config.trim_endpoints_bool, - ) - - # Report results - sample_id = self.idx_to_ids[i]["sample_id"] - example_id = self.idx_to_ids[i]["example_id"] - n_proposals = self.pred_graphs[i].n_proposals() - n_targets = len(self.pred_graphs[i].target_edges) - p_accepts = round(n_targets / n_proposals, 4) - print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") - - def generate_features(self): - for i in range(self.n_examples()): - # Get proposals - proposals_dict = { - "proposals": self.pred_graphs[i].list_proposals(), - "graph": self.pred_graphs[i].copy_graph() - } - - # Generate features - sample_id = self.idx_to_ids[i]["sample_id"] - features = feature_generation.run( - self.pred_graphs[i], - self.imgs[sample_id], - self.model_type, - proposals_dict, - self.graph_config.search_radius, - ) - - # Initialize train and validation datasets - dataset = ml_util.init_dataset( - self.pred_graphs[i], - features, - self.model_type, - computation_graph=proposals_dict["graph"] - ) - if i in self.validation_idxs: - self.validation_dataset_list.append(dataset) - else: - self.train_dataset_list.append(dataset) - - def save_model(self): - name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d') - extension = ".pth" if "Net" in self.model_type else ".joblib" - path = os.path.join(self.output_dir, name + extension) - ml_util.save_model(path, self.model, self.model_type) From c57bc77b035c3a8b4de7b9a7e702bb0495e286f2 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 20:18:23 +0000 Subject: [PATCH 11/13] refactor: infernce pipeline, evaluation --- src/deep_neurographs/geometry.py | 2 +- .../machine_learning/evaluation.py | 65 ------------------- .../machine_learning/feature_generation.py | 2 +- .../machine_learning/inference.py | 17 +++-- src/deep_neurographs/neurograph.py | 2 +- src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/swc_util.py | 2 +- 7 files changed, 15 insertions(+), 77 deletions(-) diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 989806a..21c03c2 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -280,7 +280,7 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]): """ profile = [] for xyz in xyz_arr: - if type(img) == ts.TensorStore: + if type(img) is ts.TensorStore: profile.append(np.max(util.read_tensorstore(img, xyz, window))) else: profile.append(np.max(util.get_chunk(img, xyz, window))) diff --git a/src/deep_neurographs/machine_learning/evaluation.py b/src/deep_neurographs/machine_learning/evaluation.py index 2abe7d2..052f88d 100644 --- a/src/deep_neurographs/machine_learning/evaluation.py +++ b/src/deep_neurographs/machine_learning/evaluation.py @@ -82,74 +82,9 @@ def run_evaluation(neurograph, 
accepts, proposals): stats["Overall"][metric].append(overall_stats[metric]) stats["Simple"][metric].append(simple_stats[metric]) stats["Complex"][metric].append(complex_stats[metric]) - return stats -def run_evaluation_blocks(neurographs, blocks, accepts): - """ - Runs an evaluation on the accuracy of the predictions generated by an edge - classication model for a given list of blocks. - - Parameters - ---------- - neurographs : list[NeuroGraph] - Predicted neurographs. - blocks : list[str], optional - List of block_ids that indicate which predictions to evaluate. - accepts : list - Accepted proposals. - - Returns - ------- - dict[dict] - Acuracy of the edge classification model on all edges, simple edges, - and complex edges. The metrics contained in a sub-dictionary where the - keys are identical to "METRICS_LIST"]. - - """ - avg_wgts = {"Overall": [], "Simple": [], "Complex": []} - stats = { - "Overall": init_stats(), - "Simple": init_stats(), - "Complex": init_stats(), - } - for block_id in blocks: - # Compute accuracy - overall_stats_i = get_stats( - neurographs[block_id], - neurographs[block_id].proposals, - accepts[block_id], - ) - - simple_stats_i = get_stats( - neurographs[block_id], - neurographs[block_id].simple_proposals(), - accepts[block_id], - ) - - complex_stats_i = get_stats( - neurographs[block_id], - neurographs[block_id].complex_proposals(), - accepts[block_id], - ) - - # Store results - avg_wgts["Overall"].append(len(neurographs[block_id].proposals)) - avg_wgts["Simple"].append( - len(neurographs[block_id].simple_proposals()) - ) - avg_wgts["Complex"].append( - len(neurographs[block_id].complex_proposals()) - ) - for metric in METRICS_LIST: - stats["Overall"][metric].append(overall_stats_i[metric]) - stats["Simple"][metric].append(simple_stats_i[metric]) - stats["Complex"][metric].append(complex_stats_i[metric]) - - return stats, avg_wgts - - def get_stats(neurograph, proposals, accepts): """ Accuracy of the predictions generated by an edge classication model on a diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 6ed110e..f161491 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -593,7 +593,7 @@ def generate_chunks(neurograph, proposals, img, labels): def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None): # Extract chunks midpoint = geometry.get_midpoint(voxel_1, voxel_2).astype(int) - if type(img) == ts.TensorStore: + if type(img) is ts.TensorStore: chunk = util.read_tensorstore(img, midpoint, CHUNK_SIZE) labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SIZE) else: diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 3cfbf3e..cf962db 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -40,7 +40,7 @@ def __init__( search_radius, batch_size=BATCH_SIZE, confidence_threshold=CONFIDENCE_THRESHOLD, - downsample_factor=0, + downsample_factor=1, ): """ Initializes an inference engine by loading images and setting class @@ -122,8 +122,10 @@ def run(self, neurograph, proposals): preds = self.run_model(dataset) # Update graph - batch_accepts = get_accepted_proposals(neurograph, preds) - for proposal in map(frozenset, batch_accepts): + batch_accepts = get_accepted_proposals( + neurograph, preds, self.threshold + ) + for proposal in 
batch_accepts: neurograph.merge_proposal(proposal) # Finish @@ -222,7 +224,7 @@ def run_model(self, dataset): # Filter preds idxs = dataset.idxs_proposals["idx_to_edge"] - return {idxs[i]: p for i, p in enumerate(preds) if p > self.threshold} + return {idxs[i]: p for i, p in enumerate(preds)} # --- run machine learning model --- @@ -257,7 +259,7 @@ def run_gnn_model(data, model, model_type): # --- Accepting proposals --- -def get_accepted_proposals(neurograph, preds, high_threshold=0.9): +def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9): """ Determines which proposals to accept based on prediction scores and the specified threshold. @@ -280,6 +282,7 @@ def get_accepted_proposals(neurograph, preds, high_threshold=0.9): """ # Partition proposals into best and the rest + preds = {k: v for k, v in preds.items() if v > threshold} best_proposals, proposals = separate_best( preds, neurograph.simple_proposals(), high_threshold ) @@ -359,8 +362,8 @@ def filter_proposals(graph, proposals): created_cycle, _ = gutil.creates_cycle(subgraph, (i, j)) if not created_cycle: graph.add_edge(i, j) - accepts.append((i, j)) - graph.remove_edges_from(accepts) + accepts.append(frozenset({i, j})) + graph.remove_edges_from(map(tuple, accepts)) return accepts diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index e131fe6..cceebdb 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -852,7 +852,7 @@ def branch_contained(self, xyz_list): def to_voxels(self, node_or_xyz, shift=False): shift = self.origin if shift else np.zeros((3)) - if type(node_or_xyz) == int: + if type(node_or_xyz) is int: coord = img_util.to_voxels(self.nodes[node_or_xyz]["xyz"]) else: coord = img_util.to_voxels(node_or_xyz) diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index c6dd8ac..da2bed2 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -479,4 +479,4 @@ def find_img_path(bucket_name, img_root, dataset_name): for subdir in util.list_gcs_subdirectories(bucket_name, img_root): if dataset_name in subdir: return subdir + "whole-brain/fused.zarr/" - raise(f"Dataset not found in {bucket_name} - {img_root}") + raise f"Dataset not found in {bucket_name} - {img_root}" diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index e279700..bc5af3d 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -609,7 +609,7 @@ def to_graph(swc_dict, swc_id=None, set_attrs=False): graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) if set_attrs: xyz = swc_dict["xyz"] - if type(swc_dict["xyz"]) == np.ndarray: + if type(swc_dict["xyz"]) is np.ndarray: xyz = util.numpy_to_hashable(swc_dict["xyz"]) graph = __add_attributes(swc_dict, graph) xyz_to_node = dict(zip(xyz, swc_dict["id"])) From 1004f8562c8d6e2a43090315c27465ff7039fc8d Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 27 Sep 2024 23:49:35 +0000 Subject: [PATCH 12/13] moved files --- src/deep_neurographs/delete_merges_gt.py | 330 ---------------- src/deep_neurographs/densegraph.py | 191 ---------- .../{machine_learning => }/evaluation.py | 0 .../groundtruth_generation.py | 0 .../{machine_learning => }/inference.py | 353 ++++++++++++++++- src/deep_neurographs/inference_pipeline.py | 354 ------------------ src/deep_neurographs/neurograph.py | 2 +- .../{machine_learning => }/train.py | 0 8 files changed, 349 
insertions(+), 881 deletions(-) delete mode 100644 src/deep_neurographs/delete_merges_gt.py delete mode 100644 src/deep_neurographs/densegraph.py rename src/deep_neurographs/{machine_learning => }/evaluation.py (100%) rename src/deep_neurographs/{machine_learning => }/groundtruth_generation.py (100%) rename src/deep_neurographs/{machine_learning => }/inference.py (51%) delete mode 100644 src/deep_neurographs/inference_pipeline.py rename src/deep_neurographs/{machine_learning => }/train.py (100%) diff --git a/src/deep_neurographs/delete_merges_gt.py b/src/deep_neurographs/delete_merges_gt.py deleted file mode 100644 index 118a34a..0000000 --- a/src/deep_neurographs/delete_merges_gt.py +++ /dev/null @@ -1,330 +0,0 @@ -""" -Created on Sat March 26 17:30:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Deletes merges from predicted swc files in the case when there are ground -truth swc files. - -""" - -import os - -import networkx as nx -import numpy as np - -from deep_neurographs import geometry -from deep_neurographs.densegraph import DenseGraph -from deep_neurographs.utils import swc_util, util - -CLOSE_DISTANCE_THRESHOLD = 3.5 -DELETION_RADIUS = 3 -MERGE_DIST_THRESHOLD = 30 -MIN_INTERSECTION = 10 - - -def delete_merges( - target_swc_paths, - pred_swc_paths, - output_dir, - img_patch_origin=None, - img_patch_shape=None, - radius=DELETION_RADIUS, - save_sites=False, -): - """ - Deletes merges from predicted swc files in the case when there are ground - truth swc files. - - Parameters - ---------- - target_swc_paths : list[str] - List of paths to ground truth swc files. - pred_swc_paths : list[str] - List of paths to predicted swc files. - output_dir : str - Directory that updated graphs and merge sites are written to. - img_patch_origin : list[float], optional - An xyz coordinate in the image which is the upper, left, front corner - of am image patch that contains the swc files. The default is None. - img_patch_shape : list[float], optional - The xyz dimensions of the bounding box which contains the swc files. - The default is None. - radius : int, optional - Each node within "radius" is deleted. The default is the global - variable "DELETION_RADIUS". - save_sites : bool, optional - Indication of whether to save merge sites. The default is False. - - Returns - ------- - None - - """ - # Initializations - target_densegraph = DenseGraph(target_swc_paths) - pred_densegraph = DenseGraph( - pred_swc_paths, - img_patch_origin=img_patch_origin, - img_patch_shape=img_patch_shape, - ) - if save_sites: - util.mkdir(os.path.join(output_dir, "merge_sites")) - - # Run merge deletion - for swc_id in pred_densegraph.graphs.keys(): - # Detection - graph = pred_densegraph.graphs[swc_id] - delete_nodes = detect_merges_neuron( - target_densegraph, - graph, - radius, - output_dir=output_dir, - save=save_sites, - ) - - # Finish - if len(delete_nodes) > 0: - graph.remove_nodes_from(delete_nodes) - print("Merge Detected:", swc_id) - print("# Nodes Deleted:", len(delete_nodes)) - print("") - pred_densegraph.graphs[swc_id] = graph - - # Save - pred_densegraph.save(output_dir) - - -def detect_merges_neuron( - target_densegraph, graph, radius, output_dir=None, save=False -): - """ - Determines whether the "graph" contains merge mistakes. This routine - projects each node in "graph" onto "target_neurograph", then computes - the projection distance. ... - - Parameters - ---------- - target_densegraph : DenseGraph - Graph built from ground truth swc files. 
- graph : networkx.Graph - Graph build from a predicted swc file. - radius : int - Each node within "radius" is deleted. - output_dir : str, optional - Directory that merge sites are saved in swc files. The default is - None. - save : bool, optional - Indication of whether to save merge sites. The default is False. - - Returns - ------- - delete_nodes : set - Nodes that are part of a merge mistake. - - """ - delete_nodes = set() - for component in nx.connected_components(graph): - hits = detect_intersections(target_densegraph, graph, component) - sites = detect_merges( - target_densegraph, graph, hits, radius, output_dir, save - ) - delete_nodes = delete_nodes.union(sites) - return delete_nodes - - -def detect_intersections(target_densegraph, graph, component): - """ - Projects each node in "component" onto the closest node in - "target_densegraph". If the projection distance for a given node is less - than "CLOSE_DISTANCE_THRESHOLD", then this node is said to 'intersect' - with ground truth neuron corresponding to "hat_swc_id". - - Parameters - ---------- - target_densegraph : DenseGraph - Graph built from ground truth swc files. - graph : networkx.Graph - Graph build from a predicted swc file. - component : iterator - Nodes that comprise a connected component. - - Returns - ------- - dict - Dictionary that records intersections between "component" and ground - truth graphs stored in "target_densegraph". Each item consists of the - swc_id of a neuron from the ground truth and the nodes from - "component" that intersect that neuron. - - """ - # Compute projections - hits = dict() - for i in component: - xyz = tuple(graph.nodes[i]["xyz"]) - hat_xyz = target_densegraph.get_projection(xyz) - hat_swc_id = target_densegraph.xyz_to_swc[hat_xyz] - if geometry.dist(hat_xyz, xyz) < CLOSE_DISTANCE_THRESHOLD: - hits = util.append_dict_value(hits, hat_swc_id, i) - - # Remove spurious intersections - keys = [key for key in hits.keys() if len(hits[key]) < MIN_INTERSECTION] - return util.remove_items(hits, keys) - - -def detect_merges(target_densegraph, graph, hits, radius, output_dir, save): - """ - Detects merge mistakes in "graph" (i.e. whether "graph" is closely aligned - with two distinct connected components in "target_densegraph". - - Parameters - ---------- - target_densegraph : DenseGraph - Graph built from ground truth swc files. - graph : networkx.Graph - Graph build from a predicted swc file. - hits : dict - Dictionary that stores intersections between "target_densegraph" and - "graph", where the keys are swc ids from "target_densegraph" and - values are nodes from "graph". - radius : int - Each node within "radius" is deleted. - output_dir : str, optional - Directory that merge sites are saved in swc files. The default is - None. - save : bool, optional - Indication of whether to save merge sites. - - Returns - ------- - merge_sites : set - Nodes that are part of a merge site. 
- - """ - merge_sites = set() - if len(hits.keys()) > 0: - visited = set() - for id_1 in hits.keys(): - for id_2 in hits.keys(): - # Determine whether to visit - pair = frozenset((id_1, id_2)) - if id_1 == id_2 or pair in visited: - continue - - # Check for merge site - min_dist, sites = locate_site(graph, hits[id_1], hits[id_2]) - visited.add(pair) - if min_dist < MERGE_DIST_THRESHOLD: - merge_nbhd = get_merged_nodes(graph, sites, radius) - merge_sites = merge_sites.union(merge_nbhd) - if save: - dir_name = f"{output_dir}/merge_sites/" - filename = "merge-" + graph.nodes[sites[0]]["swc_id"] - path = util.set_path(dir_name, filename, "swc") - xyz = get_point(graph, sites) - swc_util.save_point(path, xyz) - return merge_sites - - -def locate_site(graph, merged_1, merged_2): - """ - Locates the approximate site of where a merge between two neurons occurs. - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched. - merged_1 : list - List of nodes part of merge. - merged_2 : list - List of nodes part of merge. - - Returns - ------- - node_pair : tuple - Closest nodes from "merged_1" and "merged_2" - min_dist : float - Euclidean distance between nodes in "node_pair". - - """ - min_dist = np.inf - node_pair = (None, None) - for i in merged_1: - for j in merged_2: - xyz_i = graph.nodes[i]["xyz"] - xyz_j = graph.nodes[j]["xyz"] - if geometry.dist(xyz_i, xyz_j) < min_dist: - min_dist = geometry.dist(xyz_i, xyz_j) - node_pair = [i, j] - return min_dist, node_pair - - -def get_merged_nodes(graph, sites, radius): - """ - Gets nodes that are falsely merged. - - Parameters - ---------- - graph : networkx.Graph - Graph that contains a merge at "sites". - sites : list - Nodes in "graph" that are part of a merge mistake. - radius : int - Radius about node to be searched. - - Returns - ------- - merged_nodes : set - Nodes that are falsely merged. - - """ - i, j = tuple(sites) - merged_nodes = set(nx.shortest_path(graph, source=i, target=j)) - merged_nodes = merged_nodes.union(get_nbhd(graph, i, radius)) - merged_nodes = merged_nodes.union(get_nbhd(graph, j, radius)) - return merged_nodes - - -def get_nbhd(graph, i, radius): - """ - Gets all nodes within a path length of "radius" from node "i". - - Parameters - ---------- - graph : networkx.Graph - Graph to searched. - i : node - Node that is root of neighborhood to be returned. - radius : int - Radius about node to be searched. - - Returns - ------- - set - Nodes within a path length of "radius" from node "i". - - """ - return set(nx.dfs_tree(graph, source=i, depth_limit=radius)) - - -def get_point(graph, sites): - """ - Gets midpoint of merge site defined by the pair contained in "sites". - - Parameters - ---------- - graph : networkx.Graph - Graph that contains a merge at "sites". - sites : list - Nodes in "graph" that are part of a merge mistake. - - Returns - ------- - numpy.ndarray - Midpoint between pair of xyz coordinates in "sites". - - """ - xyz_0 = graph.nodes[sites[0]]["xyz"] - xyz_1 = graph.nodes[sites[1]]["xyz"] - return geometry.get_midpoint(xyz_0, xyz_1) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py deleted file mode 100644 index c3f6537..0000000 --- a/src/deep_neurographs/densegraph.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Created on Sat November 04 15:30:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Class of graphs built from swc files where each entry in the swc file -corresponds to a node in the graph. 
- -""" - -import os - -import networkx as nx -from scipy.spatial import KDTree - -from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, swc_util - -DELETION_RADIUS = 10 - - -class DenseGraph: - """ - Class of graphs built from swc files. Each swc file is stored as a - distinct graph and each node in this graph. - - """ - - def __init__(self, swc_paths, img_patch_origin=None, img_patch_shape=None): - """ - Constructs a DenseGraph object from a directory of swc files. - - Parameters - ---------- - swc_paths : list[str] - List of paths to swc files which are used to construct a hash - table in which the entries are filename-graph pairs. - ... - - Returns - ------- - None - - """ - self.bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) - self.init_graphs(swc_paths) - self.init_kdtree() - - def init_graphs(self, paths): - """ - Initializes graphs by reading swc files in "paths". Graphs are - stored in a hash table where the entries are filename-graph pairs. - - Parameters - ---------- - paths : list[str] - List of paths to swc files that are used to construct a dictionary - in which the items are filename-graph pairs. - - Returns - ------- - None - - """ - self.graphs = dict() - self.xyz_to_swc = dict() - swc_dicts = swc_util.Reader().load(paths) - for i, swc_dict in enumerate(swc_dicts): - # Build graph - swc_id = swc_dict["swc_id"] - graph, _ = swc_util.to_graph(swc_dict, set_attrs=True) - if self.bbox: - graph = gutil.trim_branches(graph, self.bbox) - - # Store graph - self.store_xyz_swc(graph, swc_id) - self.graphs[swc_id] = graph - - def store_xyz_swc(self, graph, swc_id): - """ - Stores (xyz, swc_id) as an item in the dictionary "self.xyz_to_swc". - - Parameters - ---------- - graph : netowrkx.Graph - Graph to parsed. - swc_id : str - swc_id corresponding to "graph". - - Returns - ------- - None - - """ - for i in graph.nodes: - self.xyz_to_swc[tuple(graph.nodes[i]["xyz"])] = swc_id - - def init_kdtree(self): - """ - Builds a KD-Tree from the xyz coordinates from every node stored in - self.graphs. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - self.kdtree = KDTree(list(self.xyz_to_swc.keys())) - - def get_projection(self, xyz): - """ - Projects "xyz" onto "self by finding the closest point. - - Parameters - ---------- - xyz : numpy.ndarray - xyz coordinate to be queried. - - Returns - ------- - numpy.ndarray - Projection of "xyz". - - """ - _, idx = self.kdtree.query(xyz, k=1) - return tuple(self.kdtree.data[idx]) - - def save(self, output_dir): - """ - Saves "self" to an swc file. - - Parameters - ---------- - output_dir : str - Path to directory that swc files are written to. - - Returns - ------- - None - - """ - for swc_id, graph in self.graphs.items(): - cnt = 0 - for component in nx.connected_components(graph): - entry_list = self.make_entries(graph, component) - path = os.path.join(output_dir, f"{swc_id}.swc") - while os.path.exists(path): - path = os.path.join(output_dir, f"{swc_id}.{cnt}.swc") - cnt += 1 - swc_util.write(path, entry_list) - - def make_entries(self, graph, component): - """ - Makes swc entries corresponding to nodes in "component". - - Parameters - ---------- - graph : networkx.Graph - Graph that "component" is a connected component of. - component : set - Connected component of "graph". - - Returns - ------- - entry_list - List of swc entries generated from nodes in "component". 
- - """ - node_to_idx = dict() - entry_list = [] - for i, j in nx.dfs_edges(graph.subgraph(component)): - # Initialize - if len(entry_list) == 0: - node_to_idx[i] = 1 - x, y, z = tuple(graph.nodes[i]["xyz"]) - r = graph.nodes[i]["radius"] - entry_list.append(f"1 2 {x} {y} {z} {r} -1") - - # Create entry - node_to_idx[j] = len(entry_list) + 1 - x, y, z = tuple(graph.nodes[j]["xyz"]) - r = graph.nodes[j]["radius"] - entry_list.append( - f"{node_to_idx[j]} 2 {x} {y} {z} {r} {node_to_idx[i]}" - ) - return entry_list diff --git a/src/deep_neurographs/machine_learning/evaluation.py b/src/deep_neurographs/evaluation.py similarity index 100% rename from src/deep_neurographs/machine_learning/evaluation.py rename to src/deep_neurographs/evaluation.py diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/groundtruth_generation.py similarity index 100% rename from src/deep_neurographs/machine_learning/groundtruth_generation.py rename to src/deep_neurographs/groundtruth_generation.py diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/inference.py similarity index 51% rename from src/deep_neurographs/machine_learning/inference.py rename to src/deep_neurographs/inference.py index cf962db..406bd36 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/inference.py @@ -4,27 +4,369 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Routines for running inference with a model that classifies edge proposals. +Routines for running inference with a machine model that classifies edge proposals. """ -import networkx as nx -import numpy as np -import torch +from datetime import datetime +from time import time from torch.nn.functional import sigmoid from torch.utils.data import DataLoader from tqdm import tqdm +import networkx as nx +import numpy as np +import os +import torch + +from deep_neurographs.graph_artifact_removal import remove_doubles from deep_neurographs.machine_learning import feature_generation from deep_neurographs.utils import gnn_util from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import img_util, ml_util +from deep_neurographs.utils import util +from deep_neurographs.utils.graph_util import GraphLoader from deep_neurographs.utils.gnn_util import toCPU BATCH_SIZE = 2000 CONFIDENCE_THRESHOLD = 0.7 +class InferencePipeline: + """ + Class that executes the full GraphTrace inference pipeline that performs + the following steps: + + 1. Graph Construction + Builds a graph representation from fragmented neuron segments. + + 2. Connection Proposals + Generates proposals for potential connections between fragments. + + 3. Feature Generation + Extracts relevant features from the proposals and graph to be used by + a machine learning model. + + 4. Inference + Applies a machine learning model classify proposals as accept/reject + based on the learned features. + + 5. Graph Update + Integrates the inference results to refine and merge the fragments + into a cohesive structure. + + """ + + def __init__( + self, dataset, pred_id, img_path, model_path, output_dir, config + ): + """ + Initializes an object that executes the full GraphTrace inference + pipeline. + + Parameters + ---------- + dataset : int + Identifier for the dataset to be used in the inference pipeline. + pred_id : str + Identifier for the predicted segmentation to be processed by the + inference pipeline. + img_path : str + Path to the raw image of whole brain stored on a GCS bucket. 
+ model_path : str + Path to machine learning model parameters. + output_dir : str + Directory where the results of the inference will be saved. + config : Config + Configuration object containing parameters and settings required + for the inference pipeline. + + Returns + ------- + None + + """ + # Class attributes + self.accepted_proposals = list() + self.dataset = dataset + self.pred_id = pred_id + self.img_path = img_path + self.model_path = model_path + + # Extract config settings + self.graph_config = config.graph_config + self.ml_config = config.ml_config + + # Set output directory + date = datetime.today().strftime("%Y-%m-%d") + self.output_dir = f"{output_dir}/{pred_id}-{date}" + util.mkdir(self.output_dir, delete=True) + + # --- Core --- + def run(self, fragments_pointer): + """ + Executes the full inference pipeline. + + Parameters + ---------- + fragments_pointer : dict, list, str + Pointer to swc files used to build an instance of FragmentGraph, + see "swc_util.Reader" for further documentation. + + Returns + ------- + None + + """ + # Initializations + print("\nExperiment Details") + print("-----------------------------------------------") + print("Dataset:", self.dataset) + print("Pred_ID:", self.pred_id) + print("") + self.write_metadata() + t0 = time() + + # Main + self.build_graph(fragments_pointer) + self.generate_proposals() + self.run_inference() + self.save_results() + + t, unit = util.time_writer(time() - t0) + print(f"Total Runtime: {round(t, 4)} {unit}\n") + + def run_schedule(self, fragments_pointer, search_radius_schedule): + # Initializations + print("\nExperiment Details") + print("-----------------------------------------------") + print("Dataset:", self.dataset) + print("Pred_ID:", self.pred_id) + print("") + t0 = time() + + # Main + self.build_graph(fragments_pointer) + for round_id, search_radius in enumerate(search_radius_schedule): + print(f"--- Round {round_id + 1}: Radius = {search_radius} ---") + round_id += 1 + self.generate_proposals(search_radius=search_radius) + self.run_inference() + self.save_results(round_id=round_id) + t, unit = util.time_writer(time() - t0) + print(f"Total Runtime: {round(t, 4)} {unit}\n") + + def build_graph(self, fragments_pointer): + """ + Initializes and constructs the fragments graph based on the provided + fragment data. + + Parameters + ---------- + fragment_pointer : dict, list, str + Pointer to swc files used to build an instance of FragmentGraph, + see "swc_util.Reader" for further documentation. + + Returns + ------- + None + + """ + print("(1) Building FragmentGraph") + t0 = time() + + # Initialize Graph + graph_builder = GraphLoader( + anisotropy=self.graph_config.anisotropy, + min_size=self.graph_config.min_size, + node_spacing=self.graph_config.node_spacing, + trim_depth=self.graph_config.trim_depth, + ) + self.graph = graph_builder.run(fragments_pointer) + + # Remove doubles (if applicable) + if self.graph_config.remove_doubles_bool: + remove_doubles(self.graph, 200, self.graph_config.node_spacing) + + # Save valid labels and current graph + swcs_path = os.path.join(self.output_dir, "processed-swcs.zip") + labels_path = os.path.join(self.output_dir, "valid_labels.txt") + self.graph.to_zipped_swcs(swcs_path) + self.graph.save_labels(labels_path) + + t, unit = util.time_writer(time() - t0) + print(f"Module Runtime: {round(t, 4)} {unit}\n") + self.print_graph_overview() + + def generate_proposals(self, search_radius=None): + """ + Generates proposals for the fragment graph based on the specified + configuration. 
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        """
+        # Initializations
+        print("(2) Generate Proposals")
+        if not search_radius:
+            search_radius = self.graph_config.search_radius
+
+        # Main
+        t0 = time()
+        self.graph.generate_proposals(
+            search_radius,
+            complex_bool=self.graph_config.complex_bool,
+            long_range_bool=self.graph_config.long_range_bool,
+            proposals_per_leaf=self.graph_config.proposals_per_leaf,
+            trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
+        )
+        n_proposals = util.reformat_number(self.graph.n_proposals())
+
+        # Report results
+        t, unit = util.time_writer(time() - t0)
+        print("# Proposals:", n_proposals)
+        print(f"Module Runtime: {round(t, 4)} {unit}\n")
+
+    def run_inference(self):
+        """
+        Executes the inference process using the configured inference engine
+        and updates the graph.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        """
+        print("(3) Run Inference")
+        t0 = time()
+        n_proposals = self.graph.n_proposals()
+        inference_engine = InferenceEngine(
+            self.img_path,
+            self.model_path,
+            self.ml_config.model_type,
+            self.graph_config.search_radius,
+            confidence_threshold=self.ml_config.threshold,
+            downsample_factor=self.ml_config.downsample_factor,
+        )
+        self.graph, accepts = inference_engine.run(
+            self.graph, self.graph.list_proposals()
+        )
+        self.accepted_proposals.extend(accepts)
+        print("# Accepted:", util.reformat_number(len(accepts)))
+        print("% Accepted:", len(accepts) / n_proposals)
+
+        t, unit = util.time_writer(time() - t0)
+        print(f"Module Runtime: {round(t, 4)} {unit}\n")
+
+    def save_results(self, round_id=None):
+        """
+        Saves the processed results from running the inference pipeline,
+        namely the corrected swc files and a list of the merged swc ids.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        """
+        suffix = f"-{round_id}" if round_id else ""
+        filename = f"corrected-processed-swcs{suffix}.zip"
+        path = os.path.join(self.output_dir, filename)
+        self.graph.to_zipped_swcs(path)
+        self.save_connections(round_id=round_id)
+        self.write_metadata()
+
+    # --- io ---
+    def save_connections(self, round_id=None):
+        """
+        Saves predicted connections between connected components in a txt file.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        """
+        suffix = f"-{round_id}" if round_id else ""
+        path = os.path.join(self.output_dir, f"connections{suffix}.txt")
+        with open(path, "w") as f:
+            for id_1, id_2 in self.graph.merged_ids:
+                f.write(f"{id_1}, {id_2}" + "\n")
+
+    def write_metadata(self):
+        """
+        Writes metadata about the current pipeline run to a JSON file.
+ + Parameters + ---------- + None + + Returns + ------- + None + + """ + metadata = { + "date": datetime.today().strftime("%Y-%m-%d"), + "dataset": self.dataset, + "pred_id": self.pred_id, + "min_fragment_size": f"{self.graph_config.min_size}um", + "model_type": self.ml_config.model_type, + "model_name": os.path.basename(self.model_path), + "complex_proposals": self.graph_config.complex_bool, + "long_range_bool": self.graph_config.long_range_bool, + "proposals_per_leaf": self.graph_config.proposals_per_leaf, + "search_radius": f"{self.graph_config.search_radius}um", + "confidence_threshold": self.ml_config.threshold, + "node_spacing": self.graph_config.node_spacing, + "remove_doubles": self.graph_config.remove_doubles_bool, + "trim_depth": self.graph_config.trim_depth, + } + path = os.path.join(self.output_dir, "metadata.json") + util.write_json(path, metadata) + + # --- Summaries --- + def print_graph_overview(self): + """ + Prints an overview of the graph's structure and memory usage. + + Parameters + ---------- + None + + Returns + ------- + None + + """ + # Compute values + n_components = nx.number_connected_components(self.graph) + usage = round(util.get_memory_usage(), 2) + + # Print overview + print("Graph Overview...") + print("# Connected Components:", util.reformat_number(n_components)) + print("# Nodes:", util.reformat_number(self.graph.number_of_nodes())) + print("# Edges:", util.reformat_number(self.graph.number_of_edges())) + print(f"Memory Consumption: {usage} GBs\n") + + class InferenceEngine: """ Class that runs inference with a machine learning model that has been @@ -80,7 +422,8 @@ def __init__( self.threshold = confidence_threshold # Load image and model - self.img = img_util.open_tensorstore(img_path, driver="zarr") + driver = "n5" if ".n5" in img_path else "zarr" + self.img = img_util.open_tensorstore(img_path, driver=driver) self.model = ml_util.load_model(model_path) def run(self, neurograph, proposals): diff --git a/src/deep_neurographs/inference_pipeline.py b/src/deep_neurographs/inference_pipeline.py deleted file mode 100644 index b7628bc..0000000 --- a/src/deep_neurographs/inference_pipeline.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -Created on Sat Sept 16 11:30:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -This script executes the full GraphTrace inference pipeline for processing -neuron segmentation data. It performs the following steps: - - 1. Graph Construction - Builds a graph representation from fragmented neuron segments. - - 2. Connection Proposals - Generates proposals for potential connections between fragments. - - 3. Feature Generation - Extracts relevant features from the proposals and graph to be used by - a machine learning model. - - 4. Inference - Applies a machine learning model classify proposals as accept/reject - based on the learned features. - - 5. Graph Update - Integrates the inference results to refine and merge the fragments - into a cohesive structure. - -""" -import os -from datetime import datetime -from time import time - -import networkx as nx - -from deep_neurographs.graph_artifact_removal import remove_doubles -from deep_neurographs.machine_learning.inference import InferenceEngine -from deep_neurographs.utils import util -from deep_neurographs.utils.graph_util import GraphLoader - - -class InferencePipeline: - """ - Class that executes the full GraphTrace inference pipeline. 
- - """ - - def __init__( - self, dataset, pred_id, img_path, model_path, output_dir, config - ): - """ - Initializes an object that executes the full GraphTrace inference - pipeline. - - Parameters - ---------- - dataset : int - Identifier for the dataset to be used in the inference pipeline. - pred_id : str - Identifier for the predicted segmentation to be processed by the - inference pipeline. - img_path : str - Path to the raw image of whole brain stored on a GCS bucket. - model_path : str - Path to machine learning model parameters. - output_dir : str - Directory where the results of the inference will be saved. - config : Config - Configuration object containing parameters and settings required - for the inference pipeline. - - Returns - ------- - None - - """ - # Class attributes - self.accepted_proposals = list() - self.dataset = dataset - self.pred_id = pred_id - self.img_path = img_path - self.model_path = model_path - - # Extract config settings - self.graph_config = config.graph_config - self.ml_config = config.ml_config - - # Set output directory - date = datetime.today().strftime("%Y-%m-%d") - self.output_dir = f"{output_dir}/{pred_id}-{date}" - util.mkdir(self.output_dir, delete=True) - - # --- Core --- - def run(self, fragments_pointer): - """ - Executes the full inference pipeline. - - Parameters - ---------- - fragments_pointer : dict, list, str - Pointer to swc files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. - - Returns - ------- - None - - """ - # Initializations - print("\nExperiment Details") - print("-----------------------------------------------") - print("Dataset:", self.dataset) - print("Pred_ID:", self.pred_id) - print("") - self.write_metadata() - t0 = time() - - # Main - self.build_graph(fragments_pointer) - self.generate_proposals() - self.run_inference() - self.save_results() - - t, unit = util.time_writer(time() - t0) - print(f"Total Runtime: {round(t, 4)} {unit}\n") - - def run_schedule(self, fragments_pointer, search_radius_schedule): - # Initializations - print("\nExperiment Details") - print("-----------------------------------------------") - print("Dataset:", self.dataset) - print("Pred_ID:", self.pred_id) - print("") - t0 = time() - - # Main - self.build_graph(fragments_pointer) - for round_id, search_radius in enumerate(search_radius_schedule): - print(f"--- Round {round_id + 1}: Radius = {search_radius} ---") - round_id += 1 - self.generate_proposals(search_radius=search_radius) - self.run_inference() - self.save_results(round_id=round_id) - t, unit = util.time_writer(time() - t0) - print(f"Total Runtime: {round(t, 4)} {unit}\n") - - def build_graph(self, fragments_pointer): - """ - Initializes and constructs the fragments graph based on the provided - fragment data. - - Parameters - ---------- - fragment_pointer : dict, list, str - Pointer to swc files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. 
- - Returns - ------- - None - - """ - print("(1) Building FragmentGraph") - t0 = time() - - # Initialize Graph - graph_builder = GraphLoader( - anisotropy=self.graph_config.anisotropy, - min_size=self.graph_config.min_size, - node_spacing=self.graph_config.node_spacing, - trim_depth=self.graph_config.trim_depth, - ) - self.graph = graph_builder.run(fragments_pointer) - - # Remove doubles (if applicable) - if self.graph_config.remove_doubles_bool: - remove_doubles(self.graph, 200, self.graph_config.node_spacing) - - # Save valid labels and current graph - swcs_path = os.path.join(self.output_dir, "processed-swcs.zip") - labels_path = os.path.join(self.output_dir, "valid_labels.txt") - self.graph.to_zipped_swcs(swcs_path) - self.graph.save_labels(labels_path) - - t, unit = util.time_writer(time() - t0) - print(f"Module Runtime: {round(t, 4)} {unit}\n") - self.print_graph_overview() - - def generate_proposals(self, search_radius=None): - """ - Generates proposals for the fragment graph based on the specified - configuration. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - # Initializations - print("(2) Generate Proposals") - if not search_radius: - search_radius = self.graph_config.search_radius, - - # Main - t0 = time() - self.graph.generate_proposals( - search_radius, - complex_bool=self.graph_config.complex_bool, - long_range_bool=self.graph_config.long_range_bool, - proposals_per_leaf=self.graph_config.proposals_per_leaf, - trim_endpoints_bool=self.graph_config.trim_endpoints_bool, - ) - n_proposals = util.reformat_number(self.graph.n_proposals()) - - # Report results - t, unit = util.time_writer(time() - t0) - print("# Proposals:", n_proposals) - print(f"Module Runtime: {round(t, 4)} {unit}\n") - - def run_inference(self): - """ - Executes the inference process using the configured inference engine - and updates the graph. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - print("(3) Run Inference") - t0 = time() - n_proposals = self.graph.n_proposals() - inference_engine = InferenceEngine( - self.img_path, - self.model_path, - self.ml_config.model_type, - self.graph_config.search_radius, - confidence_threshold=self.ml_config.threshold, - downsample_factor=self.ml_config.downsample_factor, - ) - self.graph, accepts = inference_engine.run( - self.graph, self.graph.list_proposals() - ) - self.accepted_proposals.extend(accepts) - print("# Accepted:", util.reformat_number(len(accepts))) - print("% Accepted:", len(accepts) / n_proposals) - - t, unit = util.time_writer(time() - t0) - print(f"Module Runtime: {round(t, 4)} {unit}\n") - - def save_results(self, round_id=None): - """ - Saves the processed results from running the inference pipeline, - namely the corrected swc files and a list of the merged swc ids. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - suffix = f"-{round_id}" if round_id else "" - filename = f"corrected-processed-swcs{suffix}.zip" - path = os.path.join(self.output_dir, filename) - self.graph.to_zipped_swcs(path) - self.save_connections(round_id=round_id) - self.write_metadata() - - # --- io --- - def save_connections(self, round_id=None): - """ - Saves predicted connections between connected components in a txt file. 
- - Parameters - ---------- - None - - Returns - ------- - None - - """ - suffix = f"-{round_id}" if round_id else "" - path = os.path.join(self.output_dir, f"connections{suffix}.txt") - with open(path, "w") as f: - for id_1, id_2 in self.graph.merged_ids: - f.write(f"{id_1}, {id_2}" + "\n") - - def write_metadata(self): - """ - Writes metadata about the current pipeline run to a JSON file. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - metadata = { - "date": datetime.today().strftime("%Y-%m-%d"), - "dataset": self.dataset, - "pred_id": self.pred_id, - "min_fragment_size": f"{self.graph_config.min_size}um", - "model_type": self.ml_config.model_type, - "model_name": os.path.basename(self.model_path), - "complex_proposals": self.graph_config.complex_bool, - "long_range_bool": self.graph_config.long_range_bool, - "proposals_per_leaf": self.graph_config.proposals_per_leaf, - "search_radius": f"{self.graph_config.search_radius}um", - "confidence_threshold": self.ml_config.threshold, - "node_spacing": self.graph_config.node_spacing, - "remove_doubles": self.graph_config.remove_doubles_bool, - "trim_depth": self.graph_config.trim_depth, - } - path = os.path.join(self.output_dir, "metadata.json") - util.write_json(path, metadata) - - # --- Summaries --- - def print_graph_overview(self): - """ - Prints an overview of the graph's structure and memory usage. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - # Compute values - n_components = nx.number_connected_components(self.graph) - usage = round(util.get_memory_usage(), 2) - - # Print overview - print("Graph Overview...") - print("# Connected Components:", util.reformat_number(n_components)) - print("# Nodes:", util.reformat_number(self.graph.number_of_nodes())) - print("# Edges:", util.reformat_number(self.graph.number_of_edges())) - print(f"Memory Consumption: {usage} GBs\n") diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index cceebdb..ab1416c 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -20,7 +20,7 @@ from deep_neurographs import generate_proposals, geometry from deep_neurographs.geometry import dist as get_dist -from deep_neurographs.machine_learning.groundtruth_generation import ( +from deep_neurographs.groundtruth_generation import ( init_targets, ) from deep_neurographs.utils import graph_util as gutil diff --git a/src/deep_neurographs/machine_learning/train.py b/src/deep_neurographs/train.py similarity index 100% rename from src/deep_neurographs/machine_learning/train.py rename to src/deep_neurographs/train.py From 2f4e386e61c4f602a6fcf96efd5bc10e282a97d0 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 1 Oct 2024 01:04:51 +0000 Subject: [PATCH 13/13] upds --- src/deep_neurographs/utils/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index e8c9a68..023b6a6 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -372,7 +372,7 @@ def write_txt(path, contents): f.close() -def write_to_s3(local_path, bucket_name, s3_key): +def write_to_s3(local_path, bucket_name, prefix): """ Writes a single file on local machine to an s3 bucket. @@ -382,7 +382,7 @@ def write_to_s3(local_path, bucket_name, s3_key): Path to file to be written to s3. bucket_name : str Name of s3 bucket. - s3_key : str + prefix : str Path within s3 bucket. 
Returns @@ -391,7 +391,7 @@ def write_to_s3(local_path, bucket_name, s3_key): """ s3 = boto3.client('s3') - s3.upload_file(local_path, bucket_name, s3_key) + s3.upload_file(local_path, bucket_name, prefix) # --- math utils ---
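
A minimal usage sketch of the updated helper, for reference. The bucket name
and destination key below are illustrative placeholders, not values taken from
this patch series:

    from deep_neurographs.utils import util

    # Upload one of the result files written by the inference pipeline; the
    # third argument ("prefix") is forwarded unchanged to boto3's upload_file,
    # so it is the full destination object key inside the bucket.
    util.write_to_s3(
        "corrected-processed-swcs.zip",             # local_path
        "example-results-bucket",                   # bucket_name (placeholder)
        "graphtrace/corrected-processed-swcs.zip",  # prefix, i.e. object key
    )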