Skip to content

Commit

Permalink
Feat img features (#134)
Browse files Browse the repository at this point in the history
* feat: generates img profiles for node nbhd

* bug: flake8 bugs

* refactor: black

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored May 9, 2024
1 parent ced5d75 commit db397e8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 62 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def optimize_simple_alignment(neurograph, img, edge, depth=15):
"""
i, j = tuple(edge)
branch_i = neurograph.get_branches(i, ignore_reducibles=True)[0]
branch_j = neurograph.get_branches(j, ignore_reducibles=True,)[0]
branch_j = neurograph.get_branches(j, ignore_reducibles=True)[0]
d_i, d_j, _ = align(neurograph, img, branch_i, branch_j, depth)
return branch_i[d_i], branch_j[d_j]

Expand Down
114 changes: 56 additions & 58 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def generate_proposal_profiles(neurograph, proposals, img):
# Generate coordinates
coords = dict()
for i, proposal in enumerate(proposals):
coords[proposal] = get_profile_coords(neurograph, proposal)
coords[proposal] = get_proposal_profile_coords(neurograph, proposal)

# Generate profiles
with ThreadPoolExecutor() as executor:
Expand Down Expand Up @@ -287,7 +287,7 @@ def generate_node_profiles(neurograph, img):

coords[i] = {
"bbox": get_node_bbox(neurograph, path),
"path": geometry.sample_curve(path, N_PROFILE_PTS)
"path": geometry.sample_curve(path, N_PROFILE_PTS),
}

# Generate profiles
Expand All @@ -304,39 +304,29 @@ def generate_node_profiles(neurograph, img):
return profiles


def get_leaf_profile_coords(neurograph, i):
j = list(neurograph.neighbors(i))[0]
return get_profile_path(neurograph.orient_edge((i, j), i, key="xyz"))


def get_junction_profile_coords(neurograph, i):
# Get branches
nbs = list(neurograph.neighbors(i))
xyz_list_1 = neurograph.orient_edge((i, nbs[0]), i, key="xyz")
xyz_list_2 = neurograph.orient_edge((i, nbs[1]), i, key="xyz")

# Get profile paths
path_1 = get_profile_path(xyz_list_1)
path_2 = get_profile_path(xyz_list_2)
return np.vstack([np.flip(path_1, axis=0), path_2])

def get_profile(img, coords, thread_id):
"""
Gets the image intensity profile for a given proposal.
def get_profile_path(xyz_list):
path_length = 0
for i in range(1, len(xyz_list)):
if i > 0:
path_length += geometry.dist(xyz_list[i - 1], xyz_list[i])
if path_length >= NODE_PROFILE_DEPTH:
break
return xyz_list[0:i, :]
Parameters
----------
img : tensorstore.TensorStore
Image to be queried.
coords : dict
...
thread_id : hashable
...
Returns
-------
thread_id : hashable
...
list[int]
Image intensity profile.
def get_node_bbox(neurograph, coords):
bbox = {
"start": np.floor(np.min(coords, axis=0)).astype(int),
"end": np.ceil(np.max(coords, axis=0)).astype(int),
}
return bbox
"""
chunk = utils.read_tensorstore_bbox(img, coords["bbox"])
return thread_id, [chunk[tuple(xyz)] for xyz in coords["path"]]


def get_proposal_profile_coords(neurograph, proposal):
Expand Down Expand Up @@ -372,30 +362,40 @@ def get_proposal_profile_coords(neurograph, proposal):
}
return coords


def get_leaf_profile_coords(neurograph, i):
j = list(neurograph.neighbors(i))[0]
return get_profile_path(neurograph.orient_edge((i, j), i, key="xyz"))


def get_profile(img, coords, thread_id):
"""
Gets the image intensity profile for a given proposal.
def get_junction_profile_coords(neurograph, i):
# Get branches
nbs = list(neurograph.neighbors(i))
xyz_list_1 = neurograph.orient_edge((i, nbs[0]), i, key="xyz")
xyz_list_2 = neurograph.orient_edge((i, nbs[1]), i, key="xyz")

Parameters
----------
img : tensorstore.TensorStore
Image to be queried.
coords : dict
...
thread_id : hashable
...
# Get profile paths
path_1 = get_profile_path(xyz_list_1)
path_2 = get_profile_path(xyz_list_2)
return np.vstack([np.flip(path_1, axis=0), path_2])

Returns
-------
thread_id : hashable
...
list[int]
Image intensity profile.

"""
chunk = utils.read_tensorstore_bbox(img, coords["bbox"])
return thread_id, [chunk[tuple(xyz)] for xyz in coords["path"]]
def get_profile_path(xyz_list):
path_length = 0
for i in range(1, len(xyz_list)):
if i > 0:
path_length += geometry.dist(xyz_list[i - 1], xyz_list[i])
if path_length >= NODE_PROFILE_DEPTH:
break
return xyz_list[0:i, :]


def get_node_bbox(neurograph, coords):
bbox = {
"start": np.floor(np.min(coords, axis=0)).astype(int),
"end": np.ceil(np.max(coords, axis=0)).astype(int),
}
return bbox


def generate_skel_features(neurograph, proposals, search_radius):
Expand All @@ -404,7 +404,7 @@ def generate_skel_features(neurograph, proposals, search_radius):
i, j = tuple(proposal)
features[proposal] = np.concatenate(
(
1, # edge type
1, # edge type
neurograph.proposal_length(proposal),
neurograph.degree[i],
neurograph.degree[j],
Expand Down Expand Up @@ -512,14 +512,11 @@ def generate_branch_features(neurograph):
for edge in neurograph.edges:
i, j = tuple(edge)
features[frozenset(edge)] = np.concatenate(
(
-1, # edge type
np.zeros((32))
),
axis=None,
(-1, np.zeros((32))), axis=None # edge type
)
return features


"""
0,
neurograph.degree[i],
Expand All @@ -532,6 +529,7 @@ def generate_branch_features(neurograph):
np.zeros((N_PROFILE_PTS + 2)),
"""


def compute_curvature(neurograph, edge):
kappa = curvature(neurograph.edges[edge]["xyz"])
n_pts = len(kappa)
Expand Down
11 changes: 8 additions & 3 deletions src/deep_neurographs/machine_learning/graph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def run_on_graphs(self, datasets, augment=False):
y, hat_y = [], []
self.model.train()
for graph_id in train_ids:
y_i, hat_y_i = self.train(datasets[graph_id], epoch, augment=augment)
y_i, hat_y_i = self.train(
datasets[graph_id], epoch, augment=augment
)
y.extend(toCPU(y_i))
hat_y.extend(toCPU(hat_y_i))
self.compute_metrics(y, hat_y, "train", epoch)
Expand Down Expand Up @@ -196,7 +198,7 @@ def train(self, dataset, epoch, augment=False):

def augment(self, dataset):
augmented_dataset = rescale_data(dataset, self.scaling_factor)
#augmented_data = proposal_dropout(dataset, self.max_proposal_dropout)
# augmented_data = proposal_dropout(dataset, self.max_proposal_dropout)
return augmented_dataset

def forward(self, data):
Expand Down Expand Up @@ -441,9 +443,12 @@ def proposal_dropout(data, max_proposal_dropout):
for edge in remove_edges:
reversed_edge = [edge[1], edge[0]]
edges_to_remove = torch.tensor([edge, reversed_edge], dtype=torch.long)
edges_mask = torch.all(data.data.edge_index.T == edges_to_remove[:, None], dim=2).any(dim=0)
edges_mask = torch.all(
data.data.edge_index.T == edges_to_remove[:, None], dim=2
).any(dim=0)
data.data.edge_index = data.data.edge_index[:, ~edges_mask]
return data


def count_proposals(dataset):
return dataset.data.y.size(0)

0 comments on commit db397e8

Please sign in to comment.