Skip to content

Commit

Permalink
lint: black (#113)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Apr 10, 2024
1 parent f3f77bf commit c0d4a0e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 17 deletions.
5 changes: 5 additions & 0 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def __getitem__(self, idx):
return {"inputs": inputs, "targets": self.targets[idx]}


class ProposalGraphDataset(Dataset):
def __init__(self, neurograph, inputs, labels):
pass


# Augmentation
class AugmentImages:
"""
Expand Down
22 changes: 13 additions & 9 deletions src/deep_neurographs/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def get_structure_aware_accepts(
good_preds.append(edge)
good_probs.append(prob)

more_accepts, graph = check_cycles_sequential(graph, good_preds, good_probs)
accepts.extend(more_accepts)
more_accepts, graph = check_cycles_sequential(
graph, good_preds, good_probs
)
accepts.extend(more_accepts)
return accepts, graph


Expand Down Expand Up @@ -216,8 +218,6 @@ def get_best_preds(neurograph, preds, threshold):


def fuse_branches(neurograph, edges):
simple_cnt = 0
complex_cnt = 0
for edge in edges:
neurograph.merge_proposal(edge)
return neurograph
Expand All @@ -229,10 +229,14 @@ def save_prediction(neurograph, accepted_proposals, output_dir):
corrections_dir = os.path.join(output_dir, "corrections")
utils.mkdir(output_dir, delete=True)
utils.mkdir(corrections_dir, delete=True)

connections_path = os.path.join(output_dir, "connections.txt")
reconstruction.save_prediction(output_neurograph, accepted_proposals, output_dir)
utils.save_connection(pred_neurograph, accepted_proposals, connections_path)
save_prediction(
neurograph, accepted_proposals, output_dir
)
utils.save_connection(
neurograph, accepted_proposals, connections_path
)

# Write Result
neurograph.to_swc(output_dir)
Expand Down Expand Up @@ -271,9 +275,9 @@ def save_connections(neurograph, accepted_proposals, path):
None
"""
with open(path, 'w') as f:
with open(path, "w") as f:
for edge in accepted_proposals:
i, j = tuple(edge)
swc_id_i = neurograph.nodes[i]["swc_id"]
swc_id_j = neurograph.nodes[j]["swc_id"]
f.write(f"{swc_id_i}, {swc_id_j}" + '\n')
f.write(f"{swc_id_i}, {swc_id_j}" + "\n")
20 changes: 13 additions & 7 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,7 @@ def write_list(path, entry_list, color=None):
else:
f.write("# id, type, z, y, x, r, pid")
for i, entry in enumerate(entry_list):
f.write("\n")
for item in entry:
f.write(str(item) + " ")
f.write("\n" + entry)


def write_dict(path, swc_dict, color=None):
Expand All @@ -266,10 +264,10 @@ def write_graph(path, graph, color=None):
List of swc file entries to be written.
"""
node_to_idx = dict()
node_to_idx = {-1: -1}
for i, j in nx.dfs_edges(graph):
# Initialize entry list
if len(node_to_idx) < 1:
if len(node_to_idx) == 1:
entry, node_to_idx = make_entry(graph, i, -1, node_to_idx)
entry_list = [entry]

Expand Down Expand Up @@ -361,13 +359,21 @@ def make_entry(graph, i, parent, node_to_idx):
...
"""
r = graph[i]["radius"]
r = set_radius(graph, i)
x, y, z = tuple(graph.nodes[i]["xyz"])
node_to_idx[i] = len(node_to_idx) + 1
node_to_idx[i] = len(node_to_idx)
entry = f"{node_to_idx[i]} 2 {x} {y} {z} {r} {node_to_idx[parent]}"
return entry, node_to_idx


def set_radius(graph, i):
try:
radius = graph[i]["radius"]
return radius
except:
return 1


def make_simple_entry(node, parent, xyz, radius=8):
"""
Makes an entry to be written in an swc file.
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def plot(data, title):
fig.update_layout(
title=title,
template="plotly_white",
#plot_bgcolor="rgba(0, 0, 0, 0)",
# plot_bgcolor="rgba(0, 0, 0, 0)",
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
width=1200,
height=700,
Expand Down

0 comments on commit c0d4a0e

Please sign in to comment.