Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upds #112

Merged
merged 1 commit into from
Apr 10, 2024
Merged

upds #112

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 70 additions & 32 deletions src/deep_neurographs/machine_learning/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,36 @@
MIN_INTERSECTION = 10


def init_targets(target_neurograph, pred_neurograph):
def init_targets(target_neurograph, pred_neurograph, strict=True):
"""
Initializes ground truth for edge proposals.

Parameters
----------
target_neurograph : NeuroGraph
Graph built from ground truth swc files.
pred_neurograph : NeuroGraph
Graph build from predicted swc files.
strict : bool, optional
Indication if whether target edges should be determined by using
stricter criteria that checks if proposals are reasonably well
aligned. The default is True.

Returns
-------
target_edges : set
Edge proposals that machine learning model learns to accept.

"""
# Initializations
target_edges = set()
valid_proposals = get_valid_proposals(target_neurograph, pred_neurograph)
lengths = [pred_neurograph.proposal_length(e) for e in valid_proposals]

# Add best simple edges
dists = [pred_neurograph.proposal_length(e) for e in valid_proposals]
target_edges = set()
graph = pred_neurograph.copy_graph()
for idx in np.argsort(dists):
edge = valid_proposals[idx]
for i in np.argsort(lengths):
edge = valid_proposals[i]
created_cycle, _ = gutils.creates_cycle(graph, tuple(edge))
if not created_cycle:
graph.add_edges_from([edge])
Expand All @@ -39,26 +59,18 @@ def init_targets(target_neurograph, pred_neurograph):


def get_valid_proposals(target_neurograph, pred_neurograph):
# Detect components unaligned to ground truth
invalid_proposals = set()
node_to_target = dict()
for component in nx.connected_components(pred_neurograph):
aligned, target_id = is_component_aligned(
target_neurograph, pred_neurograph, component
)
if not aligned:
i = utils.sample_singleton(component)
invalid_proposals.add(pred_neurograph.nodes[i]["swc_id"])
else:
node_to_target = upd_dict(node_to_target, component, target_id)
# Initializations
valid_proposals = list()
invalid_ids, node_to_target = unaligned_components(
target_neurograph, pred_neurograph
)

# Check whether aligned to same/adjacent target edges (i.e. valid)
valid_proposals = list()
for edge in pred_neurograph.proposals:
# Filter invalid and proposals btw different components
i, j = tuple(edge)
invalid_i = pred_neurograph.nodes[i]["swc_id"] in invalid_proposals
invalid_j = pred_neurograph.nodes[j]["swc_id"] in invalid_proposals
invalid_i = pred_neurograph.nodes[i]["swc_id"] in invalid_ids
invalid_j = pred_neurograph.nodes[j]["swc_id"] in invalid_ids
if invalid_i or invalid_j:
continue
elif node_to_target[i] != node_to_target[j]:
Expand All @@ -71,6 +83,41 @@ def get_valid_proposals(target_neurograph, pred_neurograph):
return valid_proposals


def unaligned_components(target_neurograph, pred_neurograph):
"""
Detects connected components in "pred_neurograph" that are unaligned to a
connected component in "target_neurograph".

Parameters
----------
target_neurograph : NeuroGraph
Graph built from ground truth swc files.
pred_neurograph : NeuroGraph
Graph build from predicted swc files.

Returns
-------
invalid_ids : set
IDs in ""pred_neurograph" that correspond to connected components that
are unaligned to a connected component in "target_neurograph".
node_to_target : dict
Mapping between nodes and target ids.

"""
invalid_ids = set()
node_to_target = dict()
for component in nx.connected_components(pred_neurograph):
aligned, target_id = is_component_aligned(
target_neurograph, pred_neurograph, component
)
if not aligned:
i = utils.sample_singleton(component)
invalid_ids.add(pred_neurograph.nodes[i]["swc_id"])
else:
node_to_target = upd_dict(node_to_target, component, target_id)
return invalid_ids, node_to_target


def is_component_aligned(target_neurograph, pred_neurograph, component):
"""
Determines whether the connected component defined by "node_subset" is
Expand Down Expand Up @@ -137,7 +184,6 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge):
bool
Indication of whether proposal is consistent
"""
# aligned = is_proposal_aligned(target_neurograph, pred_neurograph, edge)
consistent = is_consistent(
target_neurograph, pred_neurograph, target_id, edge
)
Expand All @@ -164,7 +210,7 @@ def is_consistent(target_neurograph, pred_neurograph, target_id, edge):
Returns
-------
bool
Indication of whether proposal is consistent
Indication of whether proposal is consistent.

"""
# Find closest edges from target_neurograph
Expand All @@ -186,16 +232,8 @@ def is_consistent(target_neurograph, pred_neurograph, target_id, edge):
xyz_j = pred_neurograph.nodes[j]["xyz"]
if is_adjacent_aligned(hat_branch_i, hat_branch_j, xyz_i, xyz_j):
return True
return False


def is_proposal_aligned(target_neurograph, pred_neurograph, edge):
xyz_0, xyz_1 = pred_neurograph.proposal_xyz(edge)
proj_dists = []
for xyz in geometry.make_line(xyz_0, xyz_1, 10):
hat_xyz = target_neurograph.get_projection(tuple(xyz))
proj_dists.append(get_dist(hat_xyz, xyz))
return True if np.mean(proj_dists) < ALIGNED_THRESHOLD else False
else:
return False


def proj_branch(target_neurograph, pred_neurograph, target_id, i):
Expand Down
30 changes: 28 additions & 2 deletions src/deep_neurographs/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def threshold_preds(preds, idx_to_edge, threshold, valid_idxs=[]):
predicted probability.

"""
print(preds)
thresholded_preds = dict()
for i, pred_i in enumerate(preds):
contained_bool = True if len(valid_idxs) == 0 else i in valid_idxs
Expand All @@ -111,7 +110,7 @@ def threshold_preds(preds, idx_to_edge, threshold, valid_idxs=[]):


def get_structure_aware_accepts(
neurograph, graph, preds, high_threshold=0.8, low_threshold=0.6
neurograph, graph, preds, high_threshold=0.9, low_threshold=0.6
):
# Add best preds
best_preds, best_probs = get_best_preds(neurograph, preds, high_threshold)
Expand Down Expand Up @@ -254,3 +253,30 @@ def save_corrections(neurograph, proposal_preds, output_dir):
xyz_i = neurograph.nodes[i]["xyz"]
xyz_j = neurograph.nodes[j]["xyz"]
swc_utils.save_edge(path, xyz_i, xyz_j, color=color, radius=3)


def save_connections(neurograph, accepted_proposals, path):
"""
Saves predicted connections between connected components in a txt file.

Parameters
----------
neurograph : NeuroGraph
Graph built from predicted swc files.
accepted_proposals : list[frozenset]
List of accepted edge proposals where each entry is a frozenset that
consists of the nodes corresponding to a predicted connection.
path : str
Path that output is written to.

Returns
-------
None

"""
with open(path, 'w') as f:
for edge in accepted_proposals:
i, j = tuple(edge)
swc_id_i = neurograph.nodes[i]["swc_id"]
swc_id_j = neurograph.nodes[j]["swc_id"]
f.write(f"{swc_id_i}, {swc_id_j}" + '\n')
2 changes: 1 addition & 1 deletion src/deep_neurographs/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def plot(data, title):
fig.update_layout(
title=title,
template="plotly_white",
plot_bgcolor="rgba(0, 0, 0, 0)",
#plot_bgcolor="rgba(0, 0, 0, 0)",
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
width=1200,
height=700,
Expand Down
Loading