Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def generate_skel_features(neurograph, proposals):
n_nearby_leafs(neurograph, proposal),
get_radii(neurograph, proposal),
get_avg_radii(neurograph, proposal),
get_avg_branch_lens(neurograph, proposal),
get_directionals(neurograph, proposal, 8),
get_directionals(neurograph, proposal, 16),
get_directionals(neurograph, proposal, 32),
Expand Down
12 changes: 10 additions & 2 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def predict(
confidence_threshold=0.7,
):
# Generate features
features = feature_generation.run_on_proposals(
features = feature_generation.run(
neurograph,
model_type,
search_radius,
Expand All @@ -167,18 +167,26 @@ def predict(
dataset = ml_utils.init_dataset(neurograph, features, model_type)

# Run model
idx_to_edge = get_idxs(dataset, model_type)
proposal_probs = run_model(dataset, model, model_type)
accepts, graph = build.get_accepted_proposals(
neurograph,
graph,
proposal_probs,
dataset["idx_to_edge"],
idx_to_edge,
high_threshold=0.95,
low_threshold=confidence_threshold,
)
return accepts, graph


def get_idxs(dataset, model_type):
if "Graph" in model_type:
return dataset.idxs_proposals["idx_to_edge"]
else:
return dataset["idx_to_edge"]


# -- Whole Brain Seed-Based Inference --
def build_from_soma(
neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1
Expand Down
18 changes: 17 additions & 1 deletion src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier

from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.machine_learning import feature_generation, graph_datasets
from deep_neurographs.machine_learning.datasets import (
ImgProposalDataset,
MultiModalDataset,
Expand Down Expand Up @@ -152,6 +152,22 @@ def get_dataset(inputs, targets, model_type, transform, lengths):

def init_dataset(
neurographs, features, model_type, block_ids=None, transform=False
):
if "Graph" in model_type:
dataset = graph_datasets.init(neurographs, features)
else:
dataset = init_proposal_dataset(
neurographs,
features,
model_type,
block_ids=block_ids,
transform=transform
)
return dataset


def init_proposal_dataset(
neurographs, features, model_type, block_ids=None, transform=False
):
# Extract features
inputs, targets, idx_transforms = feature_generation.get_matrix(
Expand Down