Skip to content

Commit

Permalink
Refactor feature generation (#255)
Browse files Browse the repository at this point in the history
* refactor: removed old gnn option

* refactor: improved heterognn feature generation

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Sep 30, 2024
1 parent b08c9e5 commit 4f04abc
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 232 deletions.
98 changes: 43 additions & 55 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from datetime import datetime
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader
from tqdm import tqdm

import networkx as nx
Expand Down Expand Up @@ -66,6 +65,7 @@ def __init__(
model_path,
output_dir,
config,
device=None,
):
"""
Initializes an object that executes the full GraphTrace inference
Expand All @@ -88,6 +88,8 @@ def __init__(
config : Config
Configuration object containing parameters and settings required
for the inference pipeline.
device : str, optional
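Name of the device (e.g. "cpu" or "cuda") that the model is run on.
This value is forwarded to the InferenceEngine, which falls back to
"cpu" when it is None. The default is None.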
Returns
-------
Expand All @@ -105,6 +107,17 @@ def __init__(
self.graph_config = config.graph_config
self.ml_config = config.ml_config

# Inference engine
self.inference_engine = InferenceEngine(
self.img_path,
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
confidence_threshold=self.ml_config.threshold,
device=device,
downsample_factor=self.ml_config.downsample_factor,
)

# Set output directory
date = datetime.today().strftime("%Y-%m-%d")
self.output_dir = f"{output_dir}/{segmentation_id}-{date}"
Expand All @@ -127,11 +140,7 @@ def run(self, fragments_pointer):
"""
# Initializations
print("\nExperiment Details")
print("-----------------------------------------------")
print("Sample_ID:", self.sample_id)
print("Segmentation_ID:", self.segmentation_id)
print("")
self.report_experiment()
self.write_metadata()
t0 = time()

Expand All @@ -145,15 +154,8 @@ def run(self, fragments_pointer):
print(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(self, fragments_pointer, search_radius_schedule):
# Initializations
print("\nExperiment Details")
print("-----------------------------------------------")
print("Sample_ID:", self.sample_id)
print("Segmentation_ID:", self.segmentation_id)
print("")
t0 = time()

# Main
self.report_experiment()
self.build_graph(fragments_pointer)
for round_id, search_radius in enumerate(search_radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {search_radius} ---")
Expand Down Expand Up @@ -258,15 +260,7 @@ def run_inference(self):
print("(3) Run Inference")
t0 = time()
n_proposals = self.graph.n_proposals()
inference_engine = InferenceEngine(
self.img_path,
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
confidence_threshold=self.ml_config.threshold,
downsample_factor=self.ml_config.downsample_factor,
)
self.graph, accepts = inference_engine.run(
self.graph, accepts = self.inference_engine.run(
self.graph, self.graph.list_proposals()
)
self.accepted_proposals.extend(accepts)
Expand Down Expand Up @@ -297,6 +291,13 @@ def save_results(self, round_id=None):
self.save_connections(round_id=round_id)
self.write_metadata()

def report_experiment(self):
    """
    Prints a short overview of the experiment (i.e. the sample and
    segmentation ids) to stdout.

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    print("\nExperiment Overview")
    print("-----------------------------------------------")
    print("Sample_ID:", self.sample_id)
    print("Segmentation_ID:", self.segmentation_id)
    print("")

# --- io ---
def save_connections(self, round_id=None):
"""
Expand Down Expand Up @@ -390,6 +391,7 @@ def __init__(
search_radius,
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=None,
downsample_factor=1,
):
"""
Expand Down Expand Up @@ -424,6 +426,7 @@ def __init__(
# Set class attributes
self.batch_size = batch_size
self.downsample_factor = downsample_factor
self.device = "cpu" if device is None else device
self.is_gnn = True if "Graph" in model_type else False
self.model_type = model_type
self.search_radius = search_radius
Expand All @@ -433,6 +436,9 @@ def __init__(
driver = "n5" if ".n5" in img_path else "zarr"
self.img = img_util.open_tensorstore(img_path, driver=driver)
self.model = ml_util.load_model(model_path)
if self.is_gnn:
self.model = self.model.to(self.device)
self.model.eval()

def run(self, neurograph, proposals):
"""
Expand Down Expand Up @@ -470,7 +476,7 @@ def run(self, neurograph, proposals):
# Predict
batch = self.get_batch(neurograph, proposals)
dataset = self.get_batch_dataset(neurograph, batch)
preds = self.run_model(dataset)
preds = self.predict(dataset)

# Update graph
batch_accepts = get_accepted_proposals(
Expand Down Expand Up @@ -547,7 +553,7 @@ def get_batch_dataset(self, neurograph, batch):
)
return dataset

def run_model(self, dataset):
def predict(self, dataset):
"""
Runs the model on the given dataset to generate and filter
predictions.
Expand All @@ -561,47 +567,29 @@ def run_model(self, dataset):
-------
dict
A dictionary that maps a proposal to the model's prediction (i.e.
probability). Note that this dictionary only contains proposals
whose predicted probability is greater than the threshold.
probability).
"""
# Get predictions
if self.model_type == "GraphNeuralNet":
preds = run_gnn_model(dataset.data, self.model)
elif "Net" in self.model_type:
preds = run_nn_model(dataset.data, self.model)
with torch.no_grad():
# Get inputs
n = len(dataset.data["proposal"]["y"])
x, edge_index, edge_attr = gnn_util.get_inputs(
dataset.data, device=self.device
)

# Run model
preds = sigmoid(self.model(x, edge_index, edge_attr))
preds = toCPU(preds[0:n, 0])
else:
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

# Filter preds
# Reformat prediction
idxs = dataset.idxs_proposals["idx_to_edge"]
return {idxs[i]: p for i, p in enumerate(preds)}


# --- run machine learning model ---
def run_nn_model(data, model, batch_size=32):
    """
    Runs a neural network over "data" in mini-batches and returns the
    sigmoid of the first output column for every example.

    Parameters
    ----------
    data : torch.utils.data.Dataset
        Dataset that is iterated with a DataLoader. Each collated batch must
        be indexable with the key "inputs".
    model : torch.nn.Module
        Model to be evaluated. It is switched to eval mode and must return a
        2D tensor, since only column 0 of the output is read.
    batch_size : int, optional
        Number of examples passed to the model per forward pass. The default
        is 32.

    Returns
    -------
    numpy.ndarray
        1D array containing one prediction per example, in dataset order.

    """
    hat_y = list()
    model.eval()
    with torch.no_grad():
        for batch in DataLoader(data, batch_size=batch_size):
            # Run model
            hat_y_i = sigmoid(model(batch["inputs"]))

            # Postprocess: "tolist" copies the values to host memory, so this
            # also works when the model output is not a CPU tensor (calling
            # np.array on such a tensor raises a TypeError)
            hat_y.extend(hat_y_i[:, 0].tolist())
    return np.array(hat_y)


def run_gnn_model(data, model):
    """
    Runs a graph neural network on "data" and returns the predictions for
    the proposals.

    Parameters
    ----------
    data : object
        Graph data that "gnn_util.get_inputs" can unpack into node features,
        an edge index and edge attributes. It must support the lookup
        data["proposal"]["y"].
    model : torch.nn.Module
        Model to be evaluated. It is switched to eval mode and called with
        (x, edge_index, edge_attr).

    Returns
    -------
    object
        Result of "toCPU" applied to column 0 of the first n rows of the
        sigmoid output, where n = len(data["proposal"]["y"]).

    """
    model.eval()
    with torch.no_grad():
        x, edge_index, edge_attr = gnn_util.get_inputs(data)
        hat_y = sigmoid(model(x, edge_index, edge_attr))

        # Only the first "idx" rows are kept
        # NOTE(review): presumably the proposals occupy the leading rows of
        # the model output -- confirm against gnn_util.get_inputs
        idx = len(data["proposal"]["y"])
    return toCPU(hat_y[0:idx, 0])


# --- Accepting Proposals ---
def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


# Wrapper
def init(neurograph, features, model_type, sample_ids=None):
def init(neurograph, features, sample_ids=None):
"""
Initializes a dataset that can be used to train a machine learning model.
Expand All @@ -41,7 +41,7 @@ def init(neurograph, features, model_type, sample_ids=None):
"""
# Extract features
x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix(
neurograph, features["proposals"], model_type, sample_ids=sample_ids
neurograph, features["proposals"], sample_ids=sample_ids
)

# Initialize dataset
Expand Down
54 changes: 9 additions & 45 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from random import sample
Expand All @@ -29,7 +30,7 @@

CHUNK_SIZE = [64, 64, 64]
N_BRANCH_PTS = 50
N_PROFILE_PTS = 16
N_PROFILE_PTS = 16 # 10
N_SKEL_FEATURES = 22


Expand Down Expand Up @@ -115,55 +116,18 @@ def generate_features(
Feature vectors.
"""
features = {
"proposals": run_on_proposals(
neurograph,
img,
proposals_dict["proposals"],
radius,
downsample_factor,
)
}
return features


# -- feature generation by graphical structure type --
def run_on_proposals(neurograph, img, proposals, radius, downsample_factor):
"""
Generates feature vectors for a set of proposals in a neurograph.
Parameters
----------
neurograph : NeuroGraph
Graph that "proposals" belong to.
img : tensorstore.Tensorstore
Image stored in a GCS bucket.
proposals : list[frozenset]
List of proposals for which features will be generated.
radius : float
Search radius used to generate proposals.
downsample_factor : int
Downsampling factor that accounts for which level in the image pyramid
the voxel coordinates must index into.
Returns
-------
dict
Dictionary whose keys are feature types (i.e. skeletal and profiles)
and values are a dictionary that maps a proposal id to the
corresponding feature vector.
"""
proposal_features = {
"skel": proposal_skeletal(neurograph, proposals, radius),
features = defaultdict(bool)
features["proposals"] = {
"skel": proposal_skeletal(
neurograph, proposals_dict["proposals"], radius
),
"profiles": proposal_profiles(
neurograph, img, proposals, downsample_factor
neurograph, img, proposals_dict["proposals"], downsample_factor
),
}
return proposal_features
return features


# -- part 1: proposal feature generation --
def proposal_profiles(neurograph, img, proposals, downsample_factor):
"""
Generates an image intensity profile along each proposal by reading from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,11 @@ def edge_skeletal(neurograph, computation_graph):
"""
edge_skeletal_features = dict()
for edge in neurograph.edges:
edge_skeletal_features[frozenset(edge)] = np.concatenate(
(
edge_skeletal_features[frozenset(edge)] = np.array(
[
np.mean(neurograph.edges[edge]["radius"]),
neurograph.edge_length(edge) / 1000,
),
axis=None,
neurograph.edges[edge]["length"] / 1000,
],
)
return edge_skeletal_features

Expand All @@ -226,7 +225,6 @@ def proposal_skeletal(neurograph, proposals, radius):
"""
proposal_skeletal_features = dict()
for proposal in proposals:
i, j = tuple(proposal)
proposal_skeletal_features[proposal] = np.concatenate(
(
neurograph.proposal_length(proposal),
Expand Down
8 changes: 4 additions & 4 deletions src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def check_missing_edge_type(self):
# Update edge_index
n = self.data["branch"]["x"].size(0)
edge_index = [[n - 1, n - 2], [n - 2, n - 1]]
self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index)
self.data[edge_type].edge_index = gnn_util.toTensor(edge_index)
self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2})
self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3})

Expand Down Expand Up @@ -282,7 +282,7 @@ def proposal_to_proposal(self):
v1 = self.idxs_proposals["edge_to_idx"][frozenset(e1)]
v2 = self.idxs_proposals["edge_to_idx"][frozenset(e2)]
edge_index.extend([[v1, v2], [v2, v1]])
return gnn_util.to_tensor(edge_index)
return gnn_util.toTensor(edge_index)

def branch_to_branch(self):
"""
Expand All @@ -308,7 +308,7 @@ def branch_to_branch(self):
v1 = self.idxs_branches["edge_to_idx"][frozenset(e1)]
v2 = self.idxs_branches["edge_to_idx"][frozenset(e2)]
edge_index.extend([[v1, v2], [v2, v1]])
return gnn_util.to_tensor(edge_index)
return gnn_util.toTensor(edge_index)

def branch_to_proposal(self):
"""
Expand Down Expand Up @@ -338,7 +338,7 @@ def branch_to_proposal(self):
if frozenset((j, k)) not in self.proposals:
v2 = self.idxs_branches["edge_to_idx"][frozenset((j, k))]
edge_index.extend([[v2, v1]])
return gnn_util.to_tensor(edge_index)
return gnn_util.toTensor(edge_index)

# Set Edge Attributes
def set_edge_attrs(self, x_nodes, edge_type, idx_map):
Expand Down
7 changes: 4 additions & 3 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class HeteroGNN(torch.nn.Module):

def __init__(
self,
device=None,
scale_hidden_dim=2,
dropout=DROPOUT,
heads_1=HEADS_1,
Expand All @@ -50,12 +51,12 @@ def __init__(
# Linear layers
output_dim = heads_1 * heads_2 * hidden_dim
self.input_nodes = nn.ModuleDict(
{key: nn.Linear(d, hidden_dim) for key, d in node_dict.items()}
{key: nn.Linear(d, hidden_dim, device=device) for key, d in node_dict.items()}
)
self.input_edges = {
key: nn.Linear(d, hidden_dim) for key, d in edge_dict.items()
key: nn.Linear(d, hidden_dim, device=device) for key, d in edge_dict.items()
}
self.output = Linear(output_dim, 1)
self.output = Linear(output_dim, 1, device=device)

# Convolutional layers
self.conv1 = HeteroConv(
Expand Down
Loading

0 comments on commit 4f04abc

Please sign in to comment.