Merge branch 'main' into refactor-img-windows
anna-grim authored Oct 11, 2024
2 parents 827967d + 8389bab commit a745898
Showing 6 changed files with 106 additions and 193 deletions.
3 changes: 2 additions & 1 deletion src/deep_neurographs/config.py
@@ -93,10 +93,11 @@ class MLConfig:
     batch_size: int = 2000
     downsample_factor: int = 1
     high_threshold: float = 0.9
-    lr: float = 1e-3
+    lr: float = 1e-4
     threshold: float = 0.6
     model_type: str = "GraphNeuralNet"
     n_epochs: int = 1000
+    use_img_embedding: bool = False
     validation_split: float = 0.15
     weight_decay: float = 1e-3
 
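For orientation, here is a minimal sketch of how the updated MLConfig defaults might be consumed after this change. The field names and default values mirror the hunk above; the dataclass-style construction and the override values are illustrative assumptions, not code from this commit.

```python
from deep_neurographs.config import MLConfig

# Defaults after this commit: lower learning rate plus the new embedding flag.
ml_config = MLConfig()
assert ml_config.lr == 1e-4
assert ml_config.use_img_embedding is False

# Hypothetical override for an experiment that turns image embeddings on,
# assuming MLConfig accepts keyword overrides like a plain dataclass.
ml_config = MLConfig(use_img_embedding=True, batch_size=1000)
```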
75 changes: 41 additions & 34 deletions src/deep_neurographs/inference.py
@@ -20,10 +20,12 @@
 from tqdm import tqdm
 
 from deep_neurographs.graph_artifact_removal import remove_doubles
-from deep_neurographs.machine_learning import feature_generation
+from deep_neurographs.machine_learning.feature_generation import (
+    FeatureGenerator,
+)
 from deep_neurographs.utils import gnn_util
 from deep_neurographs.utils import graph_util as gutil
-from deep_neurographs.utils import img_util, ml_util, util
+from deep_neurographs.utils import ml_util, util
 from deep_neurographs.utils.gnn_util import toCPU
 from deep_neurographs.utils.graph_util import GraphLoader
 
@@ -65,6 +67,8 @@ def __init__(
         output_dir,
         config,
         device=None,
+        label_path=None,
+        use_img_embedding=False,
     ):
         """
         Initializes an object that executes the full GraphTrace inference
@@ -79,7 +83,7 @@ def __init__(
             Identifier for the predicted segmentation to be processed by the
             inference pipeline.
         img_path : str
-            Path to the raw image of whole brain stored on a GCS bucket.
+            Path to the raw image assumed to be stored in a GCS bucket.
         model_path : str
             Path to machine learning model parameters.
         output_dir : str
@@ -89,6 +93,10 @@
             for the inference pipeline.
         device : str, optional
             ...
+        label_path : str, optional
+            Path to the segmentation assumed to be stored on a GCS bucket.
+        use_img_embedding : bool, optional
+            ...
 
         Returns
         -------
@@ -99,7 +107,6 @@ def __init__(
         self.accepted_proposals = list()
         self.sample_id = sample_id
         self.segmentation_id = segmentation_id
-        self.img_path = img_path
         self.model_path = model_path
 
         # Extract config settings
@@ -108,13 +115,15 @@
 
         # Inference engine
         self.inference_engine = InferenceEngine(
-            self.img_path,
+            img_path,
             self.model_path,
             self.ml_config.model_type,
             self.graph_config.search_radius,
             confidence_threshold=self.ml_config.threshold,
             device=device,
             downsample_factor=self.ml_config.downsample_factor,
+            label_path=label_path,
+            use_img_embedding=use_img_embedding,
         )
 
         # Set output directory
@@ -158,10 +167,10 @@ def run_schedule(
         t0 = time()
         self.report_experiment()
         self.build_graph(fragments_pointer)
-        for round_id, search_radius in enumerate(search_radius_schedule):
-            print(f"--- Round {round_id + 1}: Radius = {search_radius} ---")
+        for round_id, radius in enumerate(radius_schedule):
+            print(f"--- Round {round_id + 1}: Radius = {radius} ---")
             round_id += 1
-            self.generate_proposals(search_radius)
+            self.generate_proposals(radius)
             self.run_inference()
             if save_all_rounds:
                 self.save_results(round_id=round_id)
@@ -213,7 +222,7 @@ def build_graph(self, fragments_pointer):
         print(f"Module Runtime: {round(t, 4)} {unit}\n")
         self.print_graph_overview()
 
-    def generate_proposals(self, search_radius=None):
+    def generate_proposals(self, radius=None):
         """
         Generates proposals for the fragment graph based on the specified
         configuration.
@@ -229,13 +238,13 @@
         """
         # Initializations
         print("(2) Generate Proposals")
-        if search_radius is None:
-            search_radius = self.graph_config.search_radius
+        if radius is None:
+            radius = self.graph_config.radius
 
         # Main
         t0 = time()
         self.graph.generate_proposals(
-            search_radius,
+            radius,
             complex_bool=self.graph_config.complex_bool,
             long_range_bool=self.graph_config.long_range_bool,
             proposals_per_leaf=self.graph_config.proposals_per_leaf,
@@ -392,11 +401,13 @@ def __init__(
         img_path,
         model_path,
         model_type,
-        search_radius,
+        radius,
         batch_size=BATCH_SIZE,
         confidence_threshold=CONFIDENCE_THRESHOLD,
         device=None,
         downsample_factor=1,
+        label_path=None,
+        use_img_embedding=False
     ):
         """
         Initializes an inference engine by loading images and setting class
@@ -410,7 +421,7 @@ def __init__(
             Path to machine learning model parameters.
         model_type : str
             Type of machine learning model used to perform inference.
-        search_radius : float
+        radius : float
             Search radius used to generate proposals.
         batch_size : int, optional
             Number of proposals to generate features and classify per batch.
@@ -429,16 +440,20 @@
         """
         # Set class attributes
         self.batch_size = batch_size
-        self.downsample_factor = downsample_factor
         self.device = "cpu" if device is None else device
         self.is_gnn = True if "Graph" in model_type else False
-        self.model_type = model_type
-        self.search_radius = search_radius
+        self.radius = radius
         self.threshold = confidence_threshold
 
-        # Load image and model
-        driver = "n5" if ".n5" in img_path else "zarr"
-        self.img = img_util.open_tensorstore(img_path, driver=driver)
+        # Features
+        self.feature_generator = FeatureGenerator(
+            img_path,
+            downsample_factor,
+            label_path=label_path,
+            use_img_embedding=use_img_embedding
+        )
+
+        # Model
         self.model = ml_util.load_model(model_path)
         if self.is_gnn:
             self.model = self.model.to(self.device)
@@ -532,22 +547,14 @@ def get_batch_dataset(self, neurograph, batch):
         ...
         """
         # Generate features
-        features = feature_generation.run(
-            neurograph,
-            self.img,
-            self.model_type,
-            batch,
-            self.search_radius,
-            downsample_factor=self.downsample_factor,
-        )
-
-        # Initialize dataset
+        t0 = time()
+        features = self.feature_generator.run(neurograph, batch, self.radius)
+        print("Feature Generation:", time() - t0)
         computation_graph = batch["graph"] if type(batch) is dict else None
         dataset = ml_util.init_dataset(
             neurograph,
             features,
-            self.model_type,
+            self.is_gnn,
             computation_graph=computation_graph,
         )
         return dataset
@@ -570,7 +577,7 @@ def predict(self, dataset):
         """
         # Get predictions
-        if self.model_type == "GraphNeuralNet":
+        if self.is_gnn:
             with torch.no_grad():
                 # Get inputs
                 n = len(dataset.data["proposal"]["y"])
@@ -585,7 +592,7 @@
             preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])
 
         # Reformat prediction
-        idxs = dataset.idxs_proposals["idx_to_edge"]
+        idxs = dataset.idxs_proposals["idx_to_id"]
         return {idxs[i]: p for i, p in enumerate(preds)}
 
 
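To summarize the refactor above, here is a hedged sketch of how the inference entry point might be wired up after this commit. The label_path and use_img_embedding keywords, the radius naming, and the internal hand-off to FeatureGenerator follow the hunks shown; the placeholder paths, the config and fragments_pointer objects, and the keyword-argument usage for run_schedule are assumptions for illustration only.

```python
from deep_neurographs.inference import InferencePipeline

config = ...              # placeholder: object bundling GraphConfig and MLConfig settings
fragments_pointer = ...   # placeholder: location of the predicted fragments to load

# Illustrative construction; the paths are placeholders, not real buckets.
pipeline = InferencePipeline(
    sample_id="sample-000",
    segmentation_id="seg-000",
    img_path="gs://bucket/raw_image.zarr",        # raw image, assumed GCS-backed
    model_path="/path/to/model.pth",
    output_dir="/path/to/output",
    config=config,
    label_path="gs://bucket/segmentation.zarr",   # new: forwarded to FeatureGenerator
    use_img_embedding=False,                      # new: toggles image-embedding features
)

# Each round generates proposals at the given radius and runs inference;
# the radius_schedule and save_all_rounds names follow the loop in run_schedule.
pipeline.run_schedule(fragments_pointer, radius_schedule=[20, 30], save_all_rounds=True)
```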
29 changes: 29 additions & 0 deletions src/deep_neurographs/machine_learning/archived/features.txt
@@ -0,0 +1,29 @@
"""
Created on Sat May 9 11:00:00 2024

@author: Anna Grim
@email: anna.grim@alleninstitute.org

Archived routines for feature generation.

"""

def compute_curvature(neurograph, edge):
kappa = curvature(neurograph.edges[edge]["xyz"])
n_pts = len(kappa)
if n_pts <= N_BRANCH_PTS:
sampled_kappa = np.zeros((N_BRANCH_PTS))
sampled_kappa[0:n_pts] = kappa
else:
idxs = np.linspace(0, n_pts - 1, N_BRANCH_PTS).astype(int)
sampled_kappa = kappa[idxs]
return np.array(sampled_kappa)


def curvature(xyz_list):
a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1)
b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1)
c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1)
s = 0.5 * (a + b + c)
delta = np.sqrt(s * (s - a) * (s - b) * (s - c))
return 4 * delta / (a * b * c)
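The archived curvature routine computes the Menger (circumscribed-circle) curvature at each interior point of a polyline: 4 * area / (a * b * c), where a, b, c are the side lengths of the triangle formed by three consecutive points and the area comes from Heron's formula; compute_curvature then pads or resamples that vector to a fixed length N_BRANCH_PTS defined elsewhere in the original module. A small self-contained check of the formula is sketched below; the numpy import and the test points are illustrative and not part of the archived file.

```python
import numpy as np

def curvature(xyz_list):
    # Menger curvature for each triple of consecutive points; equal to the
    # reciprocal of the circumradius of the triangle they form.
    a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1)
    b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1)
    c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1)
    s = 0.5 * (a + b + c)
    delta = np.sqrt(s * (s - a) * (s - b) * (s - c))  # Heron's formula
    return 4 * delta / (a * b * c)

# Points sampled on a circle of radius 5: every interior point should
# recover a curvature of roughly 1 / 5 = 0.2.
theta = np.linspace(0, np.pi, 20)
xyz = np.column_stack([5 * np.cos(theta), 5 * np.sin(theta), np.zeros_like(theta)])
print(curvature(xyz))  # ~0.2 everywhere
```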