Skip to content

Commit

Permalink
feat: added additional skeleton features (#36)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Jan 18, 2024
1 parent d4dabd5 commit a3a513a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
45 changes: 37 additions & 8 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,35 @@ def generate_mutable_skel_features(neurograph):
for edge in neurograph.mutable_edges:
i, j = tuple(edge)
radius_i, radius_j = get_radii(neurograph, edge)
dot1, dot2, dot3 = get_directionals(neurograph, edge, 8)
ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 16)
avg_radius_i, avg_radius_j = get_avg_radii(neurograph, edge)
avg_len_i, avg_len_j = get_avg_branch_len(neurograph, edge)
dot4_1, dot4_2, dot4_3 = get_directionals(neurograph, edge, 4)
dot8_1, dot8_2, dot8_3 = get_directionals(neurograph, edge, 8)
dot16_1, dot16_2, dot16_3 = get_directionals(neurograph, edge, 16)
dot32_1, dot32_2, dot_32_3 = get_directionals(neurograph, edge, 32)
features[edge] = np.concatenate(
(
neurograph.compute_length(edge),
neurograph.immutable_degree(i),
neurograph.immutable_degree(j),
radius_i,
radius_j,
dot1,
dot2,
dot3,
ddot1,
ddot2,
ddot3,
#avg_radius_i,
#avg_radius_j,
#avg_len_i,
#avg_len_j,
#dot4_1,
#dot4_2,
#dot4_3,
dot8_1,
dot8_2,
dot8_3,
dot16_1,
dot16_2,
dot16_3,
dot32_1,
dot32_2,
dot32_3,
),
axis=None,
)
Expand All @@ -218,6 +232,21 @@ def get_directionals(neurograph, edge, window):
return inner_product_1, inner_product_2, inner_product_3


def get_avg_radii(neurograph, edge):
i, j = tuple(edge)
radii_i = neurograph.get_branches(i, key="radius")
radii_j = neurograph.get_branches(j, key="radius")
return get_avg_radii(radii_i), get_avg_radii(radii_j)


def get_avg_radius(radii_list):
avg = 0
for radii in radii_list:
end = min(16, len(radii) - 1)
avg += np.mean(radii[0:end]) / len(radii_list)
return avg


def get_radii(neurograph, edge):
i, j = tuple(edge)
radius_i = neurograph.nodes[i]["radius"]
Expand Down
17 changes: 8 additions & 9 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,10 @@ def run_optimization(self):
proposal = [self.to_world(xyz_1), self.to_world(xyz_2)]
self.edges[edge]["xyz"] = np.vstack(proposal)

def get_branch(self, xyz_or_node):
def get_branch(self, xyz_or_node, key="xyz"):
if type(xyz_or_node) == int:
nb = self.get_immutable_nbs(xyz_or_node)[0]
return self.orient_edge((xyz_or_node, nb), xyz_or_node)
return self.orient_edge((xyz_or_node, nb), xyz_or_node, key=key)
else:
edge = self.xyz_to_edge[tuple(xyz_or_node)]
branch = deepcopy(self.edges[edge]["xyz"])
Expand All @@ -320,18 +320,17 @@ def get_branch(self, xyz_or_node):
else:
return branch

def get_branches(self, i):
def get_branches(self, i, key="xyz"):
branches = []
for j in self.neighbors(i):
if frozenset((i, j)) in self.immutable_edges:
branches.append(self.orient_edge((i, j), i))
for j in self.get_immutable_nbs(i):
branches.append(self.orient_edge((i, j), i, key=key))
return branches

def orient_edge(self, edge, i):
def orient_edge(self, edge, i, key="xyz"):
if (self.edges[edge]["xyz"][0, :] == self.nodes[i]["xyz"]).all():
return self.edges[edge]["xyz"]
return self.edges[edge][key]
else:
return np.flip(self.edges[edge]["xyz"], axis=0)
return np.flip(self.edges[edge][key], axis=0)

# --- Ground Truth Generation ---
def init_targets(self, target_neurograph):
Expand Down

0 comments on commit a3a513a

Please sign in to comment.