Skip to content

Commit

Permalink
feat: gnn-based inference
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Apr 16, 2024
1 parent 1c34588 commit e064313
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def generate_skel_features(neurograph, proposals):
n_nearby_leafs(neurograph, proposal),
get_radii(neurograph, proposal),
get_avg_radii(neurograph, proposal),
get_avg_branch_lens(neurograph, proposal),
get_directionals(neurograph, proposal, 8),
get_directionals(neurograph, proposal, 16),
get_directionals(neurograph, proposal, 32),
Expand Down
12 changes: 10 additions & 2 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def predict(
confidence_threshold=0.7,
):
# Generate features
features = feature_generation.run_on_proposals(
features = feature_generation.run(
neurograph,
model_type,
search_radius,
Expand All @@ -167,18 +167,26 @@ def predict(
dataset = ml_utils.init_dataset(neurograph, features, model_type)

# Run model
idx_to_edge = get_idxs(dataset, model_type)
proposal_probs = run_model(dataset, model, model_type)
accepts, graph = build.get_accepted_proposals(
neurograph,
graph,
proposal_probs,
dataset["idx_to_edge"],
idx_to_edge,
high_threshold=0.95,
low_threshold=confidence_threshold,
)
return accepts, graph


def get_idxs(dataset, model_type):
if "Graph" in model_type:
return dataset.idxs_proposals["idx_to_edge"]
else:
return dataset["idx_to_edge"]


# -- Whole Brain Seed-Based Inference --
def build_from_soma(
neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1
Expand Down
18 changes: 17 additions & 1 deletion src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier

from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.machine_learning import feature_generation, graph_datasets
from deep_neurographs.machine_learning.datasets import (
ImgProposalDataset,
MultiModalDataset,
Expand Down Expand Up @@ -152,6 +152,22 @@ def get_dataset(inputs, targets, model_type, transform, lengths):

def init_dataset(
neurographs, features, model_type, block_ids=None, transform=False
):
if "Graph" in model_type:
dataset = graph_datasets.init(neurographs, features)
else:
dataset = init_proposal_dataset(
neurographs,
features,
model_type,
block_ids=block_ids,
transform=transform
)
return dataset


def init_proposal_dataset(
neurographs, features, model_type, block_ids=None, transform=False
):
# Extract features
inputs, targets, idx_transforms = feature_generation.get_matrix(
Expand Down

0 comments on commit e064313

Please sign in to comment.