From e064313b4cf10ac317dc7cdc7071898d04e889bd Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 16 Apr 2024 22:40:58 +0000 Subject: [PATCH] feat: gnn-based inference --- .../machine_learning/feature_generation.py | 1 - .../machine_learning/inference.py | 12 ++++++++++-- .../machine_learning/ml_utils.py | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 14692bd..4523995 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -335,7 +335,6 @@ def generate_skel_features(neurograph, proposals): n_nearby_leafs(neurograph, proposal), get_radii(neurograph, proposal), get_avg_radii(neurograph, proposal), - get_avg_branch_lens(neurograph, proposal), get_directionals(neurograph, proposal, 8), get_directionals(neurograph, proposal, 16), get_directionals(neurograph, proposal, 32), diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index bf6efb5..5c1776b 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -156,7 +156,7 @@ def predict( confidence_threshold=0.7, ): # Generate features - features = feature_generation.run_on_proposals( + features = feature_generation.run( neurograph, model_type, search_radius, @@ -167,18 +167,26 @@ def predict( dataset = ml_utils.init_dataset(neurograph, features, model_type) # Run model + idx_to_edge = get_idxs(dataset, model_type) proposal_probs = run_model(dataset, model, model_type) accepts, graph = build.get_accepted_proposals( neurograph, graph, proposal_probs, - dataset["idx_to_edge"], + idx_to_edge, high_threshold=0.95, low_threshold=confidence_threshold, ) return accepts, graph +def get_idxs(dataset, model_type): + if "Graph" in model_type: + return dataset.idxs_proposals["idx_to_edge"] + else: + return dataset["idx_to_edge"] + + # -- Whole Brain Seed-Based Inference -- def build_from_soma( neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1 diff --git a/src/deep_neurographs/machine_learning/ml_utils.py b/src/deep_neurographs/machine_learning/ml_utils.py index 7c2f1bc..76b7332 100644 --- a/src/deep_neurographs/machine_learning/ml_utils.py +++ b/src/deep_neurographs/machine_learning/ml_utils.py @@ -15,7 +15,7 @@ import torch from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning import feature_generation, graph_datasets from deep_neurographs.machine_learning.datasets import ( ImgProposalDataset, MultiModalDataset, @@ -152,6 +152,22 @@ def get_dataset(inputs, targets, model_type, transform, lengths): def init_dataset( neurographs, features, model_type, block_ids=None, transform=False +): + if "Graph" in model_type: + dataset = graph_datasets.init(neurographs, features) + else: + dataset = init_proposal_dataset( + neurographs, + features, + model_type, + block_ids=block_ids, + transform=transform + ) + return dataset + + +def init_proposal_dataset( + neurographs, features, model_type, block_ids=None, transform=False ): # Extract features inputs, targets, idx_transforms = feature_generation.get_matrix(