From e842b59c31fe08a5708096e42bface55d705bae9 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Wed, 25 Sep 2024 00:55:50 +0000
Subject: [PATCH 1/6] minor upds

---
 pyproject.toml                           |  1 +
 src/deep_neurographs/delete_merges_gt.py |  1 +
 src/deep_neurographs/densegraph.py       | 12 +++++++-----
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 004c70c..76b984c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     'torcheval',
     'torchio',
     'torch_geometric',
+    'tqdm',
     'zarr',
 ]

diff --git a/src/deep_neurographs/delete_merges_gt.py b/src/deep_neurographs/delete_merges_gt.py
index 936052c..118a34a 100644
--- a/src/deep_neurographs/delete_merges_gt.py
+++ b/src/deep_neurographs/delete_merges_gt.py
@@ -87,6 +87,7 @@ def delete_merges(
     # Finish
     if len(delete_nodes) > 0:
         graph.remove_nodes_from(delete_nodes)
+        print("Merge Detected:", swc_id)
         print("# Nodes Deleted:", len(delete_nodes))
         print("")
         pred_densegraph.graphs[swc_id] = graph

diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py
index 75f100f..273a7ef 100644
--- a/src/deep_neurographs/densegraph.py
+++ b/src/deep_neurographs/densegraph.py
@@ -15,7 +15,7 @@
 from scipy.spatial import KDTree

 from deep_neurographs.utils import graph_util as gutil
-from deep_neurographs.utils import swc_util, util
+from deep_neurographs.utils import img_util, swc_util, util

 DELETION_RADIUS = 10
@@ -43,7 +43,7 @@ def __init__(self, swc_paths, img_patch_origin=None, img_patch_shape=None):
             None

         """
-        self.bbox = util.get_img_bbox(img_patch_origin, img_patch_shape)
+        self.bbox = img_util.get_bbox(img_patch_origin, img_patch_shape)
         self.init_graphs(swc_paths)
         self.init_kdtree()
@@ -65,7 +65,7 @@ def init_graphs(self, paths):
         """
         self.graphs = dict()
         self.xyz_to_swc = dict()
-        swc_dicts, _ = swc_util.process_local_paths(paths)
+        swc_dicts = swc_util.Reader().load(paths)
         for i, swc_dict in enumerate(swc_dicts):
             # Build graph
             swc_id = swc_dict["swc_id"]
@@ -179,11 +179,13 @@ def make_entries(self, graph, component):
                 node_to_idx[i] = 1
                 x, y, z = tuple(graph.nodes[i]["xyz"])
                 r = graph.nodes[i]["radius"]
-                entry_list.append([1, 2, x, y, z, r, -1])
+                entry_list.append(f"1 2 {x} {y} {z} {r} -1")

             # Create entry
             node_to_idx[j] = len(entry_list) + 1
             x, y, z = tuple(graph.nodes[j]["xyz"])
             r = graph.nodes[j]["radius"]
-            entry_list.append([node_to_idx[j], 2, x, y, z, r, node_to_idx[i]])
+            entry_list.append(
+                f"{node_to_idx[j]} 2 {x} {y} {z} {r} {node_to_idx[i]}"
+            )
     return entry_list
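Review note on the `make_entries` change above: entries are now emitted as preformatted SWC strings rather than lists, so downstream writers no longer need to join fields themselves. A minimal sketch of the row format, using a hypothetical standalone helper rather than the class method itself:

```python
# Hypothetical helper illustrating the SWC row format assumed by the new
# f-string entries; field order is: id, type, x, y, z, radius, parent_id.
def make_swc_entry(node_id, xyz, radius, parent_id):
    x, y, z = xyz
    return f"{node_id} 2 {x} {y} {z} {radius} {parent_id}"

# Root nodes take parent_id = -1, matching the "1 2 {x} {y} {z} {r} -1"
# entry written for the first node of each component.
print(make_swc_entry(1, (53.2, 10.1, 7.5), 1.0, -1))
```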
From 1f36be9b5bc4a7e35f0699a5f949e8ddd8754b9e Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Wed, 25 Sep 2024 21:56:45 +0000
Subject: [PATCH 2/6] refactor: training pipeline

---
 src/deep_neurographs/densegraph.py            |   2 +-
 src/deep_neurographs/intake.py                | 138 ----------------
 ...aphtrace_pipeline.py => run_graphtrace.py} |   4 +-
 src/deep_neurographs/train_graphtrace.py      |  81 ++++++++++
 src/deep_neurographs/utils/gnn_util.py        |   1 -
 src/deep_neurographs/utils/graph_util.py      | 148 +++++++++++++++++-
 src/deep_neurographs/utils/swc_util.py        |   8 +-
 src/deep_neurographs/utils/util.py            |  59 +++----
 8 files changed, 261 insertions(+), 180 deletions(-)
 delete mode 100644 src/deep_neurographs/intake.py
 rename src/deep_neurographs/{run_graphtrace_pipeline.py => run_graphtrace.py} (99%)

diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py
index 273a7ef..c3f6537 100644
--- a/src/deep_neurographs/densegraph.py
+++ b/src/deep_neurographs/densegraph.py
@@ -15,7 +15,7 @@
 from scipy.spatial import KDTree

 from deep_neurographs.utils import graph_util as gutil
-from deep_neurographs.utils import img_util, swc_util, util
+from deep_neurographs.utils import img_util, swc_util

 DELETION_RADIUS = 10

diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py
deleted file mode 100644
index ea4d5bc..0000000
--- a/src/deep_neurographs/intake.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""
-Created on Sat July 15 9:00:00 2023
-
-@author: Anna Grim
-@email: anna.grim@alleninstitute.org
-
-Builds a neurograph for neuron reconstruction.
-
-"""
-
-from deep_neurographs.neurograph import NeuroGraph
-from deep_neurographs.utils import graph_util as gutil
-from deep_neurographs.utils import img_util, swc_util
-
-MIN_SIZE = 30
-NODE_SPACING = 1
-SMOOTH_BOOL = True
-PRUNE_DEPTH = 16
-TRIM_DEPTH = 0
-
-
-class GraphBuilder:
-    """
-    Class that is used to build an instance of FragmentsGraph.
-
-    """
-
-    def __init__(
-        self,
-        anisotropy=[1.0, 1.0, 1.0],
-        min_size=MIN_SIZE,
-        node_spacing=NODE_SPACING,
-        prune_depth=PRUNE_DEPTH,
-        smooth_bool=SMOOTH_BOOL,
-        trim_depth=TRIM_DEPTH,
-    ):
-        """
-        Builds a FragmentsGraph by reading swc files stored either on the
-        cloud or local machine, then extracting the irreducible components.
-
-        Parameters
-        ----------
-        anisotropy : list[float], optional
-            Scaling factors applied to xyz coordinates to account for
-            anisotropy of microscope. The default is [1.0, 1.0, 1.0].
-        min_size : float, optional
-            Minimum path length of swc files which are stored as connected
-            components in the FragmentsGraph. The default is 30ums.
-        node_spacing : int, optional
-            Spacing (in microns) between nodes. The default is the global
-            variable "NODE_SPACING".
-        prune_depth : int, optional
-            Branches less than "prune_depth" microns are pruned if "prune" is
-            True. The default is the global variable "PRUNE_DEPTH".
-        smooth_bool : bool, optional
-            Indication of whether to smooth branches from swc files. The
-            default is the global variable "SMOOTH".
-        trim_depth : float, optional
-            Maximum path length (in microns) to trim from "branch". The
-            default is the global variable "TRIM_DEPTH".
-
-        Returns
-        -------
-        FragmentsGraph
-            FragmentsGraph generated from swc files.
-
-        """
-        self.anisotropy = anisotropy
-        self.min_size = min_size
-        self.node_spacing = node_spacing
-        self.prune_depth = prune_depth
-        self.smooth_bool = smooth_bool
-        self.trim_depth = trim_depth
-
-        self.reader = swc_util.Reader(anisotropy, min_size)
-
-    def run(
-        self, fragments_pointer, img_patch_origin=None, img_patch_shape=None
-    ):
-        """
-        Builds a FragmentsGraph by reading swc files stored either on the
-        cloud or local machine, then extracting the irreducible components.
-
-        Parameters
-        ----------
-        fragments_pointer : dict, list, str
-            Pointer to swc files used to build an instance of FragmentsGraph,
-            see "swc_util.Reader" for further documentation.
-        img_patch_origin : list[int], optional
-            An xyz coordinate which is the upper, left, front corner of the
-            image patch that contains the swc files. The default is None.
-        img_patch_shape : list[int], optional
-            Shape of the image patch which contains the swc files. The
-            default is None.
-
-        Returns
-        -------
-        FragmentsGraph
-            FragmentsGraph generated from swc files.
-
-        """
-        # Load fragments and extract irreducibles
-        self.set_img_bbox(img_patch_origin, img_patch_shape)
-        swc_dicts = self.reader.load(fragments_pointer)
-        irreducibles = gutil.get_irreducibles(
-            swc_dicts,
-            self.min_size,
-            self.img_bbox,
-            self.prune_depth,
-            self.smooth_bool,
-            self.trim_depth,
-        )
-
-        # Build FragmentsGraph
-        neurograph = NeuroGraph(node_spacing=self.node_spacing)
-        while len(irreducibles):
-            irreducible_set = irreducibles.pop()
-            neurograph.add_component(irreducible_set)
-        return neurograph
-
-    def set_img_bbox(self, img_patch_origin, img_patch_shape):
-        """
-        Sets the bounding box of an image patch as a class attribute.
-
-        Parameters
-        ----------
-        img_patch_origin : tuple[int]
-            Origin of bounding box which is assumed to be top, front, left
-            corner.
-        img_patch_shape : tuple[int]
-            Shape of bounding box.
-
-        Returns
-        -------
-        None
-
-        """
-        self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape)

diff --git a/src/deep_neurographs/run_graphtrace_pipeline.py b/src/deep_neurographs/run_graphtrace.py
similarity index 99%
rename from src/deep_neurographs/run_graphtrace_pipeline.py
rename to src/deep_neurographs/run_graphtrace.py
index 3e5e53d..4b4db4b 100644
--- a/src/deep_neurographs/run_graphtrace_pipeline.py
+++ b/src/deep_neurographs/run_graphtrace.py
@@ -33,9 +33,9 @@
 import networkx as nx

 from deep_neurographs.graph_artifact_removal import remove_doubles
-from deep_neurographs.intake import GraphBuilder
 from deep_neurographs.machine_learning.inference import InferenceEngine
 from deep_neurographs.utils import util
+from deep_neurographs.utils.graph_util import GraphLoader


 class GraphTracePipeline:
@@ -163,7 +163,7 @@ def build_graph(self, fragments_pointer):
         t0 = time()

         # Initialize Graph
-        graph_builder = GraphBuilder(
+        graph_builder = GraphLoader(
             anisotropy=self.graph_config.anisotropy,
             min_size=self.graph_config.min_size,
             node_spacing=self.graph_config.node_spacing,
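The intake module is deleted by this commit and `GraphBuilder` survives as `graph_util.GraphLoader` with the same construct-then-run interface (the class is added later in this patch). A hedged sketch of the call-site migration, with placeholder arguments:

```python
from deep_neurographs.utils.graph_util import GraphLoader

# Previously: from deep_neurographs.intake import GraphBuilder
loader = GraphLoader(min_size=30, node_spacing=1)  # placeholder settings
graph = loader.run("path/to/swc_dir")              # placeholder pointer
```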
diff --git a/src/deep_neurographs/train_graphtrace.py b/src/deep_neurographs/train_graphtrace.py
index 128160f..663ce35 100644
--- a/src/deep_neurographs/train_graphtrace.py
+++ b/src/deep_neurographs/train_graphtrace.py
@@ -8,3 +8,84 @@
 This script trains the GraphTrace inference pipeline.

 """
+
+from deep_neurographs.utils import util
+from deep_neurographs.utils.graph_util import GraphLoader
+
+
+class Trainer:
+    """
+    Class that is used to train a machine learning model that classifies
+    proposals.
+
+    """
+    def __init__(
+        self, config, model_type, output_dir=None, save_model_bool=True
+    ):
+        # Check for parameter errors
+        if save_model_bool and not output_dir:
+            raise ValueError("Must provide output_dir to save model.")
+
+        # Set class attributes
+        self.idx_to_ids = list()
+        self.model_type = model_type
+        self.output_dir = output_dir
+        self.save_model_bool = save_model_bool
+
+        # Set data structures for training examples
+        self.gt_graphs = list()
+        self.pred_graphs = list()
+        self.imgs = dict()
+
+        # Extract config settings
+        self.graph_config = config.graph_config
+        self.ml_config = config.ml_config
+        self.graph_loader = GraphLoader(
+            min_size=self.graph_config.min_size,
+            progress_bar=False,
+        )
+
+    def n_examples(self):
+        return len(self.gt_graphs)
+
+    def load_example(
+        self,
+        gt_pointer,
+        pred_pointer,
+        dataset_name,
+        example_id=None,
+        metadata_path=None,
+    ):
+        # Read metadata
+        if metadata_path:
+            origin, shape = util.read_metadata(metadata_path)
+        else:
+            origin, shape = None, None
+
+        # Load graphs
+        self.gt_graphs.append(self.graph_loader.run(gt_pointer))
+        self.pred_graphs.append(
+            self.graph_loader.run(
+                pred_pointer,
+                img_patch_origin=origin,
+                img_patch_shape=shape,
+            )
+        )
+
+        # Set example ids
+        self.idx_to_ids.append(
+            {"dataset_name": dataset_name, "example_id": example_id}
+        )
+
+    def load_img(self, img_path, dataset_name):
+        pass
+
+    def run(self):
+        pass
+
+    def generate_features(self):
+        # check that every example has an image that was loaded!
+        pass
+
+    def evaluate(self):
+        pass
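A hedged sketch of how the new `Trainer` skeleton is presumably driven; the `Config` import location, its constructor, and all paths/ids below are assumptions, and `run()`/`evaluate()` are still stubs at this point:

```python
from deep_neurographs.config import Config  # assumed location of Config
from deep_neurographs.train_graphtrace import Trainer

config = Config()  # assumed to be default-constructible
trainer = Trainer(config, "GraphNeuralNet", output_dir="models/")
trainer.load_example(
    "gt/block_000.swc",    # placeholder ground-truth pointer
    "pred/block_000.swc",  # placeholder predicted-fragments pointer
    "sample_000",          # placeholder dataset name
    example_id="block_000",
)
print(trainer.n_examples())  # -> 1
```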
diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py
index 8b41092..2d8b5d9 100644
--- a/src/deep_neurographs/utils/gnn_util.py
+++ b/src/deep_neurographs/utils/gnn_util.py
@@ -108,7 +108,6 @@ def get_batch(graph, proposals, batch_size):

     """
     batch = reset_batch()
-    cur_proposal_cnt = 0
     visited = set()
     while len(proposals) > 0 and len(batch["proposals"]) < batch_size:
         root = tuple(util.sample_once(proposals))

diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py
index 0bc17fa..b01d9ce 100644
--- a/src/deep_neurographs/utils/graph_util.py
+++ b/src/deep_neurographs/utils/graph_util.py
@@ -5,7 +5,7 @@
 @email: anna.grim@alleninstitute.org


-Routines that extract the irreducible components of a graph.
+Routines for loading fragments and building a neurograph.


 Terminology
@@ -30,13 +30,147 @@
 from tqdm import tqdm

 from deep_neurographs import geometry
+from deep_neurographs.neurograph import NeuroGraph
 from deep_neurographs.utils import img_util, swc_util, util

-
+MIN_SIZE = 30
+NODE_SPACING = 1
+SMOOTH_BOOL = True
+PRUNE_DEPTH = 16
+TRIM_DEPTH = 0
+
+
+class GraphLoader:
+    """
+    Class that is used to build an instance of FragmentsGraph.
+
+    """
+
+    def __init__(
+        self,
+        anisotropy=[1.0, 1.0, 1.0],
+        min_size=MIN_SIZE,
+        node_spacing=NODE_SPACING,
+        progress_bar=False,
+        prune_depth=PRUNE_DEPTH,
+        smooth_bool=SMOOTH_BOOL,
+        trim_depth=TRIM_DEPTH,
+    ):
+        """
+        Builds a FragmentsGraph by reading swc files stored either on the
+        cloud or local machine, then extracting the irreducible components.
+
+        Parameters
+        ----------
+        anisotropy : list[float], optional
+            Scaling factors applied to xyz coordinates to account for
+            anisotropy of microscope. The default is [1.0, 1.0, 1.0].
+        min_size : float, optional
+            Minimum path length of swc files which are stored as connected
+            components in the FragmentsGraph. The default is 30ums.
+        node_spacing : int, optional
+            Spacing (in microns) between nodes. The default is the global
+            variable "NODE_SPACING".
+        progress_bar : bool, optional
+            Indication of whether to print out a progress bar while building
+            graph. The default is False.
+        prune_depth : int, optional
+            Branches less than "prune_depth" microns are pruned if "prune" is
+            True. The default is the global variable "PRUNE_DEPTH".
+        smooth_bool : bool, optional
+            Indication of whether to smooth branches from swc files. The
+            default is the global variable "SMOOTH".
+        trim_depth : float, optional
+            Maximum path length (in microns) to trim from "branch". The
+            default is the global variable "TRIM_DEPTH".
+
+        Returns
+        -------
+        FragmentsGraph
+            FragmentsGraph generated from swc files.
+
+        """
+        self.anisotropy = anisotropy
+        self.min_size = min_size
+        self.node_spacing = node_spacing
+        self.progress_bar = progress_bar
+        self.prune_depth = prune_depth
+        self.smooth_bool = smooth_bool
+        self.trim_depth = trim_depth
+
+        self.reader = swc_util.Reader(anisotropy, min_size)
+
+    def run(
+        self, fragments_pointer, img_patch_origin=None, img_patch_shape=None
+    ):
+        """
+        Builds a FragmentsGraph by reading swc files stored either on the
+        cloud or local machine, then extracting the irreducible components.
+
+        Parameters
+        ----------
+        fragments_pointer : dict, list, str
+            Pointer to swc files used to build an instance of FragmentsGraph,
+            see "swc_util.Reader" for further documentation.
+        img_patch_origin : list[int], optional
+            An xyz coordinate which is the upper, left, front corner of the
+            image patch that contains the swc files. The default is None.
+        img_patch_shape : list[int], optional
+            Shape of the image patch which contains the swc files. The
+            default is None.
+
+        Returns
+        -------
+        FragmentsGraph
+            FragmentsGraph generated from swc files.
+
+        """
+        # Load fragments and extract irreducibles
+        self.set_img_bbox(img_patch_origin, img_patch_shape)
+        swc_dicts = self.reader.load(fragments_pointer)
+        irreducibles = get_irreducibles(
+            swc_dicts,
+            self.min_size,
+            self.img_bbox,
+            self.progress_bar,
+            self.prune_depth,
+            self.smooth_bool,
+            self.trim_depth,
+        )
+
+        # Build FragmentsGraph
+        neurograph = NeuroGraph(node_spacing=self.node_spacing)
+        while len(irreducibles):
+            irreducible_set = irreducibles.pop()
+            neurograph.add_component(irreducible_set)
+        return neurograph
+
+    def set_img_bbox(self, img_patch_origin, img_patch_shape):
+        """
+        Sets the bounding box of an image patch as a class attribute.
+
+        Parameters
+        ----------
+        img_patch_origin : tuple[int]
+            Origin of bounding box which is assumed to be top, front, left
+            corner.
+        img_patch_shape : tuple[int]
+            Shape of bounding box.
+
+        Returns
+        -------
+        None
+
+        """
+        self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape)
+
+
+# --- Graph structure extraction ---
 def get_irreducibles(
     swc_dicts,
     min_size,
     img_bbox=None,
+    progress_bar=True,
     prune_depth=16.0,
     smooth_bool=True,
     trim_depth=0.0,
@@ -78,11 +212,15 @@ def get_irreducibles(
         i += 1

     # Store results
-    with tqdm(total=len(processes), desc="Extract Graphs") as pbar:
-        irreducibles = list()
+    irreducibles = list()
+    if progress_bar:
+        with tqdm(total=len(processes), desc="Extract Graphs") as pbar:
+            for process in as_completed(processes):
+                irreducibles.extend(process.result())
+                pbar.update(1)
+    else:
         for process in as_completed(processes):
             irreducibles.extend(process.result())
-            pbar.update(1)
     return irreducibles

diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py
index 38a2986..e279700 100644
--- a/src/deep_neurographs/utils/swc_util.py
+++ b/src/deep_neurographs/utils/swc_util.py
@@ -85,7 +85,7 @@ def load(self, swc_pointer):
         if ".swc" in swc_pointer:
             return self.load_from_local_path(swc_pointer)
         if os.path.isdir(swc_pointer):
-            paths = util.list_paths(swc_pointer, ext=".swc")
+            paths = util.list_paths(swc_pointer, extension=".swc")
             return self.load_from_local_paths(paths)
         raise Exception("SWC Pointer is not Valid!")

@@ -307,7 +307,7 @@ def process_content(self, content):
             Offset of swc file.

         """
-        offset = [1.0, 1.0, 1.0]
+        offset = [0.0, 0.0, 0.0]
        for i, line in enumerate(content):
             if line.startswith("# OFFSET"):
                 offset = self.read_xyz(line.split()[2:5])

@@ -364,7 +364,7 @@ def write(path, content, color=None):
     elif type(content) is nx.Graph:
         write_graph(path, content, color=color)
     else:
-        raise ExceptionType("Unable to write {} to swc".format(type(content)))
+        raise Exception("Unable to write {} to swc".format(type(content)))


 def write_list(path, entry_list, color=None):
@@ -558,7 +558,7 @@ def set_radius(graph, i):
     """
     try:
         radius = graph[i]["radius"]
-    except:
+    except ValueError:
         radius = 1.0
     return radius

diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index c59fbfd..79d820c 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -8,7 +8,6 @@

 """

-import boto3
 import json
 import math
 import os
@@ -18,6 +17,7 @@
 from time import time
 from zipfile import ZipFile

+import boto3
 import numpy as np
 import psutil

@@ -216,33 +216,32 @@ def rmdir(path):
     shutil.rmtree(path)


-def listdir(path, ext=None):
+def listdir(path, extension=None):
     """
     Lists all files in the directory at "path". If an extension "ext" is
-    provided, then only files containing "ext" are returned.
+    provided, then only files containing "extension" are returned.

     Parameters
     ----------
     path : str
         Path to directory to be searched.
-
-    ext : str, optional
+    extension : str, optional
         Extension of file type of interest. The default is None.

     Returns
     -------
     list
-        Files in directory at "path" with extension "ext" if provided.
+        Files in directory at "path" with extension "extension" if provided.
         Otherwise, list of all files in directory.

     """
-    if ext is None:
+    if extension is None:
         return [f for f in os.listdir(path)]
     else:
-        return [f for f in os.listdir(path) if ext in f]
+        return [f for f in os.listdir(path) if f.endswith(extension)]
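The `listdir` rewrite above replaces a substring test with `str.endswith`, which is strictly tighter; a quick illustration with hypothetical filenames:

```python
files = ["a.swc", "a.swc.bak", "swc_notes.txt"]

# Old behavior: substring matching admits backups and unrelated names.
print([f for f in files if ".swc" in f])         # all three names
# New behavior: only true .swc files pass the filter.
print([f for f in files if f.endswith(".swc")])  # ['a.swc']
```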
If "keyword" is provided, then only subdirectories containing "keyword" are contained in list. @@ -251,10 +250,12 @@ def list_subdirs(path, keyword=None): ---------- path : str Path to directory containing subdirectories to be listed. - keyword : str, optional Only subdirectories containing "keyword" are contained in list that is returned. The default is None. + return_paths : bool + Indication of whether to return full path of subdirectories. The + default is False. Returns ------- @@ -263,27 +264,27 @@ def list_subdirs(path, keyword=None): """ subdirs = list() - for d in os.listdir(path): - if os.path.isdir(os.path.join(path, d)): - if keyword is None: - subdirs.append(d) - elif keyword in d: - subdirs.append(d) - subdirs.sort() - return subdirs + for subdir in os.listdir(path): + is_dir = os.path.isdir(os.path.join(path, subdir)) + is_hidden = subdir.startswith('.') + if is_dir and not is_hidden: + subdir = os.path.join(path, subdir) if return_paths else subdir + if (keyword and keyword in subdir) or not keyword: + subdirs.append(subdir) + return sorted(subdirs) -def list_paths(directory, ext=None): +def list_paths(directory, extension=None): """ - Lists all paths within "directory". + Lists all paths within "directory" that end with "extension" if provided. Parameters ---------- directory : str Directory to be searched. - ext : str, optional - If provided, only paths of files with the extension "ext" are - returned. The default is None. + extension : str, optional + If provided, only paths of files with the extension are returned. The + default is None. Returns ------- @@ -292,12 +293,12 @@ def list_paths(directory, ext=None): """ paths = list() - for f in listdir(directory, ext=ext): + for f in listdir(directory, extension=extension): paths.append(os.path.join(directory, f)) return paths -def set_path(dir_name, filename, ext): +def set_path(dir_name, filename, extension): """ Sets the path for a file in a directory. If a file with the same name exists, then this routine finds a suffix to append to the filename. @@ -308,7 +309,7 @@ def set_path(dir_name, filename, ext): Name of directory that path will be generated to point to. filename : str Name of file that path will contain. - ext : str + extension : str Extension of file. 


-def set_path(dir_name, filename, ext):
+def set_path(dir_name, filename, extension):
     """
     Sets the path for a file in a directory. If a file with the same name
     exists, then this routine finds a suffix to append to the filename.

     Parameters
     ----------
     dir_name : str
         Name of directory that path will be generated to point to.
     filename : str
         Name of file that path will contain.
-    ext : str
+    extension : str
         Extension of file.

     Returns
     -------
     str
         Path.

     """
     cnt = 0
-    ext = ext.replace(".", "")
-    path = os.path.join(dir_name, f"{filename}.{ext}")
+    extension = extension.replace(".", "")
+    path = os.path.join(dir_name, f"{filename}.{extension}")
     while os.path.exists(path):
-        path = os.path.join(dir_name, f"{filename}.{cnt}.{ext}")
+        path = os.path.join(dir_name, f"{filename}.{cnt}.{extension}")
         cnt += 1
     return path
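`set_path` deduplicates output filenames with a counter suffix; a hedged sketch of the expected behavior (directory and filenames are placeholders):

```python
from deep_neurographs.utils.util import set_path

# If "out/model.pth" does not exist, the plain path is returned.
print(set_path("out", "model", ".pth"))  # -> "out/model.pth"

# If it does exist, numeric suffixes are tried in order:
# "out/model.0.pth", "out/model.1.pth", ... until a free name is found.
```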
From ea7e9e9954ff78220764b8c15583919a0892169c Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Thu, 26 Sep 2024 00:08:02 +0000
Subject: [PATCH 3/6] feat: find gcs image path

---
 src/deep_neurographs/generate_proposals.py    | 13 ++++-
 .../machine_learning/heterograph_trainer.py   |  1 +
 .../machine_learning/inference.py             |  3 +-
 src/deep_neurographs/neurograph.py            |  5 ++
 .../{run_graphtrace.py => run_pipeline.py}    |  0
 ...{train_graphtrace.py => train_pipeline.py} |  0
 src/deep_neurographs/utils/img_util.py        | 29 ++++++++++-
 src/deep_neurographs/utils/util.py            | 50 ++++++++++++++++---
 8 files changed, 90 insertions(+), 11 deletions(-)
 rename src/deep_neurographs/{run_graphtrace.py => run_pipeline.py} (100%)
 rename src/deep_neurographs/{train_graphtrace.py => train_pipeline.py} (100%)

diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py
index 1f9c057..3a5bd79 100644
--- a/src/deep_neurographs/generate_proposals.py
+++ b/src/deep_neurographs/generate_proposals.py
@@ -25,6 +25,7 @@ def run(
     radius,
     complex_bool=False,
     long_range_bool=True,
+    progress_bar=True,
     trim_endpoints_bool=True,
 ):
     """
@@ -43,6 +44,9 @@ def run(
         Indication of whether to generate simple proposals within distance
         of "LONG_RANGE_FACTOR" * radius of leaf from leaf without any
         proposals. The default is False.
+    progress_bar : bool, optional
+        Indication of whether to print out a progress bar while generating
+        proposals. The default is True.
     trim_endpoints_bool : bool, optional
         Indication of whether to trim endpoints of branches with exactly one
         proposal. The default is True.
@@ -52,10 +56,17 @@ def run(
         None

     """
+    # Initializations
     connections = dict()
     kdtree = init_kdtree(neurograph, complex_bool)
     radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0
-    for leaf in tqdm(neurograph.leafs, desc="Proposals"):
+    if progress_bar:
+        iterable = tqdm(neurograph.leafs, desc="Proposals")
+    else:
+        iterable = neurograph.leafs
+
+    # Main
+    for leaf in iterable:
         # Generate potential proposals
         candidates = get_candidates(
             neurograph,
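The pattern above (also used in `get_irreducibles` in patch 2) wraps the iterable once instead of branching inside the loop; a generic sketch:

```python
from tqdm import tqdm

def process_items(items, progress_bar=True):
    # Wrap the iterable once; the loop body is identical either way.
    iterable = tqdm(items, desc="Processing") if progress_bar else items
    return [item ** 2 for item in iterable]  # placeholder per-item work
```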
@@ -322,6 +326,7 @@ def generate_proposals( search_radius, complex_bool=complex_bool, long_range_bool=long_range_bool, + progress_bar=progress_bar, trim_endpoints_bool=trim_endpoints_bool, ) if groundtruth_graph: diff --git a/src/deep_neurographs/run_graphtrace.py b/src/deep_neurographs/run_pipeline.py similarity index 100% rename from src/deep_neurographs/run_graphtrace.py rename to src/deep_neurographs/run_pipeline.py diff --git a/src/deep_neurographs/train_graphtrace.py b/src/deep_neurographs/train_pipeline.py similarity index 100% rename from src/deep_neurographs/train_graphtrace.py rename to src/deep_neurographs/train_pipeline.py diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 04acb0d..80f660f 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -9,11 +9,13 @@ """ from copy import deepcopy +from skimage.color import label2rgb import fastremap import numpy as np import tensorstore as ts -from skimage.color import label2rgb + +from deep_neurographs.utils import util ANISOTROPY = [0.748, 0.748, 1.0] SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"] @@ -453,3 +455,28 @@ def get_chunk_labels(path, xyz, shape, from_center=True): img = open_tensorstore(path) img = read_tensorstore(img, xyz, shape, from_center=from_center) return set(fastremap.unique(img).astype(int)) + + +def find_img_path(bucket_name, img_root, dataset_name): + """ + Find the path of a specific dataset in a GCS bucket. + + Parameters: + ---------- + bucket_name : str + Name of the GCS bucket where the images are stored. + img_root : str + Root directory path in the GCS bucket where the images are located. + dataset_name : str + Name of the dataset to be searched for within the subdirectories. + + Returns: + ------- + str + Path of the found dataset subdirectory within the specified GCS bucket. + + """ + for subdir in util.list_gcs_subdirectories(bucket_name, img_root): + if dataset_name in subdir: + return subdir + "fused.zarr/" + raise(f"Dataset not found in {bucket_name} - {img_root}") diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 79d820c..44a9ab3 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -8,18 +8,19 @@ """ -import json -import math -import os -import shutil +from google.cloud import storage from io import BytesIO from random import sample from time import time from zipfile import ZipFile import boto3 +import json +import math import numpy as np +import os import psutil +import shutil # --- dictionary utils --- @@ -348,7 +349,7 @@ def list_files_in_zip(zip_content): return zip_file.namelist() -def list_gcs_filenames(bucket, cloud_path, extension): +def list_gcs_filenames(bucket, prefix, extension): """ Lists all files in a GCS bucket with the given extension. @@ -356,7 +357,7 @@ def list_gcs_filenames(bucket, cloud_path, extension): ---------- bucket : google.cloud.client Name of bucket to be read from. - cloud_path : str + prefix : str Path to directory in "bucket". extension : str File extension of filenames to be listed. @@ -367,10 +368,45 @@ def list_gcs_filenames(bucket, cloud_path, extension): Filenames stored at "cloud" path with the given extension. 
""" - blobs = bucket.list_blobs(prefix=cloud_path) + blobs = bucket.list_blobs(prefix=prefix) return [blob.name for blob in blobs if extension in blob.name] +def list_gcs_subdirectories(bucket_name, prefix): + """ + Lists all direct subdirectories of a given prefix in a GCS bucket. + + Parameters + ---------- + bucket : str + Name of bucket to be read from. + prefix : str + Path to directory in "bucket". + + Returns + ------- + list[str] + List of direct subdirectories. + + """ + # Load blobs + storage_client = storage.Client() + blobs = storage_client.list_blobs( + bucket_name, prefix=prefix, delimiter="/" + ) + [blob.name for blob in blobs] + + # Parse directory contents + prefix_depth = len(prefix.split("/")) + subdirs = list() + for prefix in blobs.prefixes: + is_dir = prefix.endswith("/") + is_direct_subdir = len(prefix.split("/")) - 1 == prefix_depth + if is_dir and is_direct_subdir: + subdirs.append(prefix) + return subdirs + + # -- io utils -- def read_json(path): """ From 64179e142d8cf58e82d86f705fe95998d4f0a75f Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 26 Sep 2024 01:43:05 +0000 Subject: [PATCH 4/6] feat: feature generation in trainer --- src/deep_neurographs/train_pipeline.py | 31 ++++++++++++++++++++++---- src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/util.py | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index 663ce35..cc670c4 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -9,7 +9,7 @@ """ -from deep_neurographs.utils import util +from deep_neurographs.utils import img_util, util from deep_neurographs.utils.graph_util import GraphLoader @@ -77,11 +77,34 @@ def load_example( {"dataset_name": dataset_name, "example_id": example_id} ) - def load_img(self, img_path, dataset_name): - pass + def load_img(self, path, dataset_name): + if dataset_name not in self.imgs: + self.imgs[dataset_name] = img_util.open_tensorstore(path, "zarr") def run(self): - pass + self.generate_proposals() + + def generate_proposals(self): + print("dataset_name - example_id - # proposals - % accepted") + for i in range(self.n_examples()): + # Run + self.pred_graphs[i].generate_proposals( + self.graph_config.search_radius, + complex_bool=self.graph_config.complex_bool, + groundtruth_graph=self.gt_graphs[i], + long_range_bool=self.graph_config.long_range_bool, + progress_bar=False, + proposals_per_leaf=self.graph_config.proposals_per_leaf, + trim_endpoints_bool=self.graph_config.trim_endpoints_bool, + ) + + # Report results + dataset_name = self.idx_to_ids[i]["dataset_name"] + example_id = self.idx_to_ids[i]["example_id"] + n_proposals = self.pred_graphs[i].n_proposals() + n_targets = len(self.pred_graphs[i].target_edges) + p_accepts = round(n_targets / n_proposals, 4) + print(f"{dataset_name} {example_id} {n_proposals} {p_accepts}") def generate_features(self): # check that every example has an image that was loaded! 
From 64179e142d8cf58e82d86f705fe95998d4f0a75f Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Thu, 26 Sep 2024 01:43:05 +0000
Subject: [PATCH 4/6] feat: feature generation in trainer

---
 src/deep_neurographs/train_pipeline.py | 31 ++++++++++++++++++++++----
 src/deep_neurographs/utils/img_util.py |  2 +-
 src/deep_neurographs/utils/util.py     |  2 +-
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py
index 663ce35..cc670c4 100644
--- a/src/deep_neurographs/train_pipeline.py
+++ b/src/deep_neurographs/train_pipeline.py
@@ -9,7 +9,7 @@

 """

-from deep_neurographs.utils import util
+from deep_neurographs.utils import img_util, util
 from deep_neurographs.utils.graph_util import GraphLoader

@@ -77,11 +77,34 @@ def load_example(
             {"dataset_name": dataset_name, "example_id": example_id}
         )

-    def load_img(self, img_path, dataset_name):
-        pass
+    def load_img(self, path, dataset_name):
+        if dataset_name not in self.imgs:
+            self.imgs[dataset_name] = img_util.open_tensorstore(path, "zarr")

     def run(self):
-        pass
+        self.generate_proposals()
+
+    def generate_proposals(self):
+        print("dataset_name - example_id - # proposals - % accepted")
+        for i in range(self.n_examples()):
+            # Run
+            self.pred_graphs[i].generate_proposals(
+                self.graph_config.search_radius,
+                complex_bool=self.graph_config.complex_bool,
+                groundtruth_graph=self.gt_graphs[i],
+                long_range_bool=self.graph_config.long_range_bool,
+                progress_bar=False,
+                proposals_per_leaf=self.graph_config.proposals_per_leaf,
+                trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
+            )
+
+            # Report results
+            dataset_name = self.idx_to_ids[i]["dataset_name"]
+            example_id = self.idx_to_ids[i]["example_id"]
+            n_proposals = self.pred_graphs[i].n_proposals()
+            n_targets = len(self.pred_graphs[i].target_edges)
+            p_accepts = round(n_targets / n_proposals, 4)
+            print(f"{dataset_name} {example_id} {n_proposals} {p_accepts}")

     def generate_features(self):
         # check that every example has an image that was loaded!

diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py
index 80f660f..41c10c5 100644
--- a/src/deep_neurographs/utils/img_util.py
+++ b/src/deep_neurographs/utils/img_util.py
@@ -478,5 +478,5 @@ def find_img_path(bucket_name, img_root, dataset_name):
     """
     for subdir in util.list_gcs_subdirectories(bucket_name, img_root):
         if dataset_name in subdir:
-            return subdir + "fused.zarr/"
+            return subdir + "whole-brain/fused.zarr/"
     raise(f"Dataset not found in {bucket_name} - {img_root}")

diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index 44a9ab3..6e165bb 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -389,7 +389,7 @@ def list_gcs_subdirectories(bucket_name, prefix):
         List of direct subdirectories.

     """
-    # Load blobs 
+    # Load blobs
     storage_client = storage.Client()
     blobs = storage_client.list_blobs(
         bucket_name, prefix=prefix, delimiter="/"
     )
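Taken together with patch 3's `find_img_path`, the trainer can now locate and lazily open one image per dataset; a hedged sketch continuing the earlier hypothetical driver (bucket, root, and dataset names are placeholders):

```python
from deep_neurographs.utils import img_util

path = img_util.find_img_path("my-bucket", "images/", "sample_000")
trainer.load_img(path, "sample_000")  # trainer from the earlier sketch
trainer.run()  # prints one "<dataset> <example> <n_proposals> <p_accepts>" row per example
```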
From 9484330b71a9d2bb0cc266da4584f371b80794f7 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Thu, 26 Sep 2024 06:00:55 +0000
Subject: [PATCH 5/6] feat: validation sets in training

---
 src/deep_neurographs/train_pipeline.py | 110 ++++++++++++++++++++++---
 src/deep_neurographs/utils/img_util.py |   2 +-
 src/deep_neurographs/utils/ml_util.py  |  23 ++++++
 src/deep_neurographs/utils/util.py     |  10 +--
 4 files changed, 126 insertions(+), 19 deletions(-)

diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py
index cc670c4..a3efb63 100644
--- a/src/deep_neurographs/train_pipeline.py
+++ b/src/deep_neurographs/train_pipeline.py
@@ -9,7 +9,15 @@

 """

+import os
+from datetime import datetime
+from random import sample
+
+import numpy as np
+from torch.nn import BCEWithLogitsLoss
+
+from deep_neurographs.machine_learning import feature_generation
-from deep_neurographs.utils import img_util, util
+from deep_neurographs.utils import img_util, ml_util, util
 from deep_neurographs.utils.graph_util import GraphLoader

@@ -20,7 +28,14 @@ class Trainer:

     """
     def __init__(
-        self, config, model_type, output_dir=None, save_model_bool=True
+        self,
+        config,
+        model_type,
+        criterion=None,
+        output_dir=None,
+        validation_ids=None,
+        validation_split=0.15,
+        save_model_bool=True,
     ):
         # Check for parameter errors
         if save_model_bool and not output_dir:
             raise ValueError("Must provide output_dir to save model.")
@@ -31,11 +46,19 @@ def __init__(
         self.model_type = model_type
         self.output_dir = output_dir
         self.save_model_bool = save_model_bool
+        self.validation_ids = validation_ids
+        self.validation_split = validation_split

         # Set data structures for training examples
         self.gt_graphs = list()
         self.pred_graphs = list()
         self.imgs = dict()
+        self.train_dataset = list()
+        self.validation_dataset = list()
+
+        # Train parameters
+        self.criterion = criterion if criterion else BCEWithLogitsLoss()
+        self.validation_ids = validation_ids

         # Extract config settings
         self.graph_config = config.graph_config
         self.ml_config = config.ml_config
@@ -45,15 +68,36 @@ def __init__(
             progress_bar=False,
         )

+    # --- getters/setters ---
     def n_examples(self):
         return len(self.gt_graphs)

+    def n_train_examples(self):
+        return len(self.train_dataset)
+
+    def n_validation_samples(self):
+        return len(self.validation_dataset)
+
+    def set_validation_idxs(self):
+        if self.validation_ids is None:
+            k = int(self.validation_split * self.n_examples())
+            self.validation_idxs = sample(np.arange(self.n_examples), k)
+        else:
+            self.validation_idxs = list()
+            for ids in self.validation_ids:
+                for i in range(self.n_examples()):
+                    same = all([ids[k] == self.idx_to_ids[i][k] for k in ids])
+                    if same:
+                        self.validation_idxs.append(i)
+
+    # --- loaders ---
     def load_example(
         self,
         gt_pointer,
         pred_pointer,
-        dataset_name,
+        sample_id,
         example_id=None,
+        pred_id=None,
         metadata_path=None,
     ):
         # Read metadata
         if metadata_path:
             origin, shape = util.read_metadata(metadata_path)
         else:
             origin, shape = None, None

         # Load graphs
         self.gt_graphs.append(self.graph_loader.run(gt_pointer))
         self.pred_graphs.append(
             self.graph_loader.run(
                 pred_pointer,
                 img_patch_origin=origin,
                 img_patch_shape=shape,
             )
         )

         # Set example ids
         self.idx_to_ids.append(
-            {"dataset_name": dataset_name, "example_id": example_id}
+            {
+                "sample_id": sample_id,
+                "example_id": example_id,
+                "pred_id": pred_id,
+            }
         )

-    def load_img(self, path, dataset_name):
-        if dataset_name not in self.imgs:
-            self.imgs[dataset_name] = img_util.open_tensorstore(path, "zarr")
+    def load_img(self, path, sample_id):
+        if sample_id not in self.imgs:
+            self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr")

+    # --- main pipeline ---
     def run(self):
         self.generate_proposals()
+        self.generate_features()
+        self.train_model()

     def generate_proposals(self):
-        print("dataset_name - example_id - # proposals - % accepted")
+        print("sample_id - example_id - # proposals - % accepted")
         for i in range(self.n_examples()):
             # Run
             self.pred_graphs[i].generate_proposals(
                 self.graph_config.search_radius,
                 complex_bool=self.graph_config.complex_bool,
                 groundtruth_graph=self.gt_graphs[i],
                 long_range_bool=self.graph_config.long_range_bool,
                 progress_bar=False,
                 proposals_per_leaf=self.graph_config.proposals_per_leaf,
                 trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
             )

             # Report results
-            dataset_name = self.idx_to_ids[i]["dataset_name"]
+            sample_id = self.idx_to_ids[i]["sample_id"]
             example_id = self.idx_to_ids[i]["example_id"]
             n_proposals = self.pred_graphs[i].n_proposals()
             n_targets = len(self.pred_graphs[i].target_edges)
             p_accepts = round(n_targets / n_proposals, 4)
-            print(f"{dataset_name} {example_id} {n_proposals} {p_accepts}")
+            print(f"{sample_id} {example_id} {n_proposals} {p_accepts}")

     def generate_features(self):
-        # check that every example has an image that was loaded!
-        pass
+        self.set_validation_idxs()
+        for i in range(self.n_examples()):
+            # Get proposals
+            proposals_dict = {
+                "proposals": self.pred_graphs[i].list_proposals(),
+                "graph": self.pred_graphs[i].copy_graph()
+            }
+
+            # Generate features
+            sample_id = self.idx_to_ids[i]["sample_id"]
+            features = feature_generation.run(
+                self.pred_graphs[i],
+                self.imgs[sample_id],
+                self.model_type,
+                proposals_dict,
+                self.graph_config.search_radius,
+            )
+
+            # Initialize train and validation datasets
+            dataset = ml_util.init_dataset(
+                self.pred_graphs[i],
+                features,
+                self.model_type,
+                computation_graph=proposals_dict["graph"]
+            )
+            if i in self.validation_ids:
+                self.validation_dataset.append(dataset)
+            else:
+                self.train_dataset.append(dataset)

-    def evaluate(self):
+    def train_model(self):
         pass
+
+    def save_model(self, model):
+        name = self.model_type + "-" + datetime.today().strftime('%Y-%m-%d')
+        extension = ".pth" if "Net" in self.model_type else ".joblib"
+        path = os.path.join(self.output_dir, name + extension)
+        util.save_model(path, model)
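Two review notes on the hunk above: `sample(np.arange(self.n_examples), k)` passes the bound method itself (missing call parentheses), and `random.sample` expects a sequence, so a `range` is the simpler population; also, the membership test in `generate_features` reads `self.validation_ids` where `self.validation_idxs` appears intended. A hedged sketch of the random branch only:

```python
from random import sample

def set_validation_idxs(self):
    # Corrected sketch: call n_examples() and sample indices from a range.
    if self.validation_ids is None:
        k = int(self.validation_split * self.n_examples())
        self.validation_idxs = sample(range(self.n_examples()), k)
```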
diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py
index 41c10c5..c6dd8ac 100644
--- a/src/deep_neurographs/utils/img_util.py
+++ b/src/deep_neurographs/utils/img_util.py
@@ -9,11 +9,11 @@
 """

 from copy import deepcopy
-from skimage.color import label2rgb

 import fastremap
 import numpy as np
 import tensorstore as ts
+from skimage.color import label2rgb

 from deep_neurographs.utils import util

diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py
index 64eb50c..347c039 100644
--- a/src/deep_neurographs/utils/ml_util.py
+++ b/src/deep_neurographs/utils/ml_util.py
@@ -81,6 +81,29 @@ def load_model(path):
     return joblib.load(path) if ".joblib" in path else torch.load(path)


+def save_model(path, model, model_type):
+    """
+    Saves a machine learning model.
+
+    Parameters
+    ----------
+    path : str
+        Path that model parameters will be written to.
+    model : object
+        Model to be saved.
+
+    Returns
+    -------
+    None
+
+    """
+    print("Model saved!")
+    if "Net" in model_type:
+        torch.save(model, path)
+    else:
+        joblib.dump(model, path)
+
+
 # --- dataset utils ---
 def init_dataset(
     neurograph, features, model_type, computation_graph=None, sample_ids=None

diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index 6e165bb..9c4809a 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -8,19 +8,19 @@

 """

-from google.cloud import storage
+import json
+import math
+import os
+import shutil
 from io import BytesIO
 from random import sample
 from time import time
 from zipfile import ZipFile

 import boto3
-import json
-import math
 import numpy as np
-import os
 import psutil
-import shutil
+from google.cloud import storage


 # --- dictionary utils ---

From 82098ff5e86a8c05f564a5fa657a96293ecca4f5 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Thu, 26 Sep 2024 17:16:56 +0000
Subject: [PATCH 6/6] bug: hgraph forward passes with missing edge types

---
 src/deep_neurographs/config.py                |  4 ++++
 .../machine_learning/datasets.py              |  4 ++--
 .../machine_learning/feature_generation.py    |  6 ++---
 .../machine_learning/heterograph_datasets.py  | 24 ++++++++++++++++++-
 .../heterograph_feature_generation.py         | 17 +++++++++++--
 .../machine_learning/heterograph_models.py    | 20 ++++++++++------
 src/deep_neurographs/train_pipeline.py        |  6 ++---
 src/deep_neurographs/utils/gnn_util.py        | 24 +++++++++----------
 8 files changed, 74 insertions(+), 31 deletions(-)

diff --git a/src/deep_neurographs/config.py b/src/deep_neurographs/config.py
index fbb403c..3421473 100644
--- a/src/deep_neurographs/config.py
+++ b/src/deep_neurographs/config.py
@@ -97,8 +97,12 @@ class MLConfig:
     batch_size: int = 2000
     downsample_factor: int = 1
     high_threshold: float = 0.9
+    lr: float = 1e-4
     threshold: float = 0.6
     model_type: str = "GraphNeuralNet"
+    n_epochs: int = 1000
+    validation_split: float = 0.15
+    weight_decay: float = 1e-3


 class Config:

diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py
index eed5bf2..9e05a1a 100644
--- a/src/deep_neurographs/machine_learning/datasets.py
+++ b/src/deep_neurographs/machine_learning/datasets.py
@@ -295,8 +295,8 @@ def reformat(arr):

 def init_idxs(idxs):
     """
-    Adds dictionary item called "edge_to_index" which maps an edge in a
-    neurograph to an that represents the edge's position in the feature
+    Adds dictionary item called "edge_to_index" which maps a branch/proposal
+    in a neurograph to an idx that represents its position in the feature
     matrix.

     Parameters

diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py
index faec027..6ed110e 100644
--- a/src/deep_neurographs/machine_learning/feature_generation.py
+++ b/src/deep_neurographs/machine_learning/feature_generation.py
@@ -515,15 +515,13 @@ def stack_chunks(neurograph, features, shift=0):


 # -- util --
-def count_features(model_type):
+def count_features():
     """
     Counts number of features based on the "model_type".

     Parameters
     ----------
-    model_type : str
-        Indication of model to be trained. Options include: AdaBoost,
-        RandomForest, FeedForwardNet, MultiModalNet.
+    None

     Returns
     -------
     int
         Number of features.

diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py
index 5cb2569..963e2b8 100644
--- a/src/deep_neurographs/machine_learning/heterograph_datasets.py
+++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py
@@ -137,8 +137,10 @@ def __init__(

         # Edges
         self.init_edges()
+        self.check_missing_edge_type()
         self.init_edge_attrs(x_nodes)
         self.n_edge_attrs = n_edge_features(x_nodes)
+

     def init_edges(self):
         """
@@ -192,6 +194,23 @@ def init_edge_attrs(self, x_nodes):
             x_nodes, edge_type, self.idxs_branches, self.idxs_proposals
         )

+    def check_missing_edge_type(self):
+        edge_type = ("branch", "edge", "branch")
+        if len(self.data[edge_type].edge_index) == 0:
+            # Add dummy features
+            dtype = self.data["branch"].x.dtype
+            zeros = torch.zeros(2, self.n_branch_features(), dtype=dtype)
+            self.data["branch"].x = torch.cat(
+                (self.data["branch"].x, zeros), dim=0
+            )
+
+            # Update edge_index
+            n = self.data["branch"]["x"].size(0)
+            edge_index = [[n - 1, n - 2], [n - 2, n - 1]]
+            self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index)
+            self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2})
+            self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3})
+
     # -- Getters --
     def n_branch_features(self):
         """
@@ -342,7 +361,10 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map):
         for i in range(self.data[edge_type].edge_index.size(1)):
             e1, e2 = self.data[edge_type].edge_index[:, i]
             v = node_intersection(idx_map, e1, e2)
-            attrs.append(x_nodes[v])
+            if v < 0:
+                attrs.append(torch.zeros(self.n_branch_features() + 1))
+            else:
+                attrs.append(x_nodes[v])
         arrs = torch.tensor(np.array(attrs), dtype=DTYPE)
         self.data[edge_type].edge_attr = arrs

diff --git a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py
index 1ab1f57..557a9ef 100644
--- a/src/deep_neurographs/machine_learning/heterograph_feature_generation.py
+++ b/src/deep_neurographs/machine_learning/heterograph_feature_generation.py
@@ -16,10 +16,10 @@
 from deep_neurographs.machine_learning import feature_generation as feats
 from deep_neurographs.utils import img_util

-WINDOW = [5, 5, 5]
+
 N_PROFILE_PTS = 16
 NODE_PROFILE_DEPTH = 16
-
+WINDOW = [5, 5, 5]

 def generate_hgnn_features(
     neurograph, img, proposals_dict, radius, downsample_factor
@@ -458,3 +458,16 @@ def check_degenerate(voxels):
         [voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)]
     )
     return voxels
+
+
+def n_node_features():
+    return {'branch': 2, 'proposal': 34}
+
+
+def n_edge_features():
+    n_edge_features_dict = {
+        ('proposal', 'edge', 'proposal'): 3,
+        ('branch', 'edge', 'branch'): 3,
+        ('branch', 'edge', 'proposal'): 3
+    }
+    return n_edge_features_dict
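These helpers hard-code the per-type feature widths so the model can size itself (see the `HeteroGNN.__init__` change just below); a sketch of the arithmetic:

```python
import numpy as np

from deep_neurographs.machine_learning import heterograph_feature_generation

node_dict = heterograph_feature_generation.n_node_features()
# hidden_dim scales with the widest node-feature vector (34, proposals),
# mirroring the scale_hidden_dim * np.max(...) line in HeteroGNN.__init__.
scale_hidden_dim = 2
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
print(hidden_dim)  # -> 68
```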
diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py
index 38a2df7..a27cd52 100644
--- a/src/deep_neurographs/machine_learning/heterograph_models.py
+++ b/src/deep_neurographs/machine_learning/heterograph_models.py
@@ -8,6 +8,7 @@

 """

+import numpy as np
 import torch
 import torch.nn.init as init
 from torch import nn
@@ -15,6 +16,8 @@
 from torch_geometric.nn import GATv2Conv as GATConv
 from torch_geometric.nn import HEATConv, HeteroConv, Linear

+from deep_neurographs.machine_learning import heterograph_feature_generation
+
 CONV_TYPES = ["GATConv", "GCNConv"]
 DROPOUT = 0.3
 HEADS_1 = 1
@@ -29,9 +32,7 @@ class HeteroGNN(torch.nn.Module):

     def __init__(
         self,
-        node_dict,
-        edge_dict,
-        hidden_dim,
+        scale_hidden_dim=2,
         dropout=DROPOUT,
         heads_1=HEADS_1,
         heads_2=HEADS_2,
@@ -41,6 +42,11 @@ def __init__(
         """
         super().__init__()

+        # Feature vector sizes
+        node_dict = heterograph_feature_generation.n_node_features()
+        edge_dict = heterograph_feature_generation.n_edge_features()
+        hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
+
         # Linear layers
         output_dim = heads_1 * heads_2 * hidden_dim
         self.input_nodes = nn.ModuleDict(
@@ -161,9 +167,8 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):
         x_dict = self.activation(x_dict)

         # Input - Edges
-        edge_attr_dict = {
-            key: f(edge_attr_dict[key]) for key, f in self.input_edges.items()
-        }
+        for key, f in self.input_edges.items():
+            edge_attr_dict[key] = f(edge_attr_dict[key])
         edge_attr_dict = self.activation(edge_attr_dict)

         # Convolutional layers
@@ -218,7 +223,8 @@ def __init__(
             metadata=metadata,
         )
         """
-        x in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
+        x in_channels (int) – Size of each input sample, or -1 to
+          derive the size from the first input(s) to the forward method.
         x out_channels (int) – Size of each output sample.
         x num_node_types (int) – The number of node types.
         x num_edge_types (int) – The number of edge types.

diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py
index a3efb63..7fad91a 100644
--- a/src/deep_neurographs/train_pipeline.py
+++ b/src/deep_neurographs/train_pipeline.py
@@ -30,11 +30,11 @@ class Trainer:
     def __init__(
         self,
         config,
+        model,
         model_type,
         criterion=None,
         output_dir=None,
         validation_ids=None,
-        validation_split=0.15,
         save_model_bool=True,
     ):
         # Check for parameter errors
@@ -43,11 +43,11 @@ def __init__(

         # Set class attributes
         self.idx_to_ids = list()
+        self.model = model
         self.model_type = model_type
         self.output_dir = output_dir
         self.save_model_bool = save_model_bool
         self.validation_ids = validation_ids
@@ -80,7 +80,7 @@ def n_validation_samples(self):

     def set_validation_idxs(self):
         if self.validation_ids is None:
-            k = int(self.validation_split * self.n_examples())
+            k = int(self.ml_config.validation_split * self.n_examples())
             self.validation_idxs = sample(np.arange(self.n_examples), k)
         else:
             self.validation_idxs = list()

diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py
index 2d8b5d9..2019f63 100644
--- a/src/deep_neurographs/utils/gnn_util.py
+++ b/src/deep_neurographs/utils/gnn_util.py
@@ -18,6 +18,18 @@
 from deep_neurographs.utils import util


+def get_inputs(data, model_type):
+    if "Hetero" in model_type:
+        x = data.x_dict
+        edge_index = data.edge_index_dict
+        edge_attr_dict = data.edge_attr_dict
+        return x, edge_index, edge_attr_dict
+    else:
+        x = data.x
+        edge_index = data.edge_index
+        return x, edge_index
+
+
 def toCPU(tensor):
     """
     Moves tensor from GPU to CPU.
@@ -35,18 +47,6 @@ def toCPU(tensor):
     return tensor.detach().cpu().tolist()


-def get_inputs(data, model_type):
-    if "Hetero" in model_type:
-        x = data.x_dict
-        edge_index = data.edge_index_dict
-        edge_attr_dict = data.edge_attr_dict
-        return x, edge_index, edge_attr_dict
-    else:
-        x = data.x
-        edge_index = data.edge_index
-        return x, edge_index
-
-
 def to_tensor(my_list):
     """
     Converts a list to a tensor with contiguous memory.