From 4916a65dcbd64ca3af3558a8b6a13e5cf345188e Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 12 Apr 2024 01:38:37 +0000 Subject: [PATCH] feat: graph data sets --- .../machine_learning/datasets.py | 14 +- .../machine_learning/feature_generation.py | 227 +++++++++++------- .../machine_learning/graph_datasets.py | 195 +++++++++++++++ .../machine_learning/ml_utils.py | 14 +- src/deep_neurographs/reconstruction.py | 8 +- 5 files changed, 354 insertions(+), 104 deletions(-) create mode 100644 src/deep_neurographs/machine_learning/graph_datasets.py diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index 9755697..c7607e7 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -4,7 +4,7 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Custom datasets for deep learning models. +Custom datasets for training deep learning models. """ @@ -230,11 +230,6 @@ def __getitem__(self, idx): return {"inputs": inputs, "targets": self.targets[idx]} -class ProposalGraphDataset(Dataset): - def __init__(self, neurograph, inputs, labels): - pass - - # Augmentation class AugmentImages: """ @@ -284,6 +279,13 @@ def run(self, arr): return self.transform(arr) +def get_lengths(neurograph): + lengths = [] + for edge in neurograph.proposals.keys(): + lengths.append(neurograph.proposal_length(edge)) + return lengths + + # -- utils -- def reformat(arr): """ diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index d78637f..168fa8e 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -26,6 +26,7 @@ CHUNK_SIZE = [64, 64, 64] WINDOW = [5, 5, 5] +N_BRANCH_PTS = 50 N_PROFILE_PTS = 10 N_SKEL_FEATURES = 19 SUPPORTED_MODELS = [ @@ -34,13 +35,16 @@ "FeedForwardNet", "ConvNet", "MultiModalNet", + "GraphNeuralNet", ] # -- Wrappers -- -def run(neurograph, model_type, img_path, labels_path=None, proposals=None): +def run_on_proposals( + neurograph, model_type, img_path, labels_path=None, proposals=None +): """ - Generates feature vectors for every edge proposal in a neurograph. + Generates feature vectors for every proposal in a neurograph. Parameters ---------- @@ -56,7 +60,7 @@ def run(neurograph, model_type, img_path, labels_path=None, proposals=None): Path to predicted segmentation stored in a GCS bucket. The default is None. proposals : list[frozenset], optional - List of edge proposals for which features will be generated. The + List of proposals for which features will be generated. The default is None. Returns @@ -88,12 +92,32 @@ def run(neurograph, model_type, img_path, labels_path=None, proposals=None): return features -# -- Edge feature extraction -- +def run_on_branches(neurograph, branches): + """ + Generates feature vectors for every edge in a neurograph. + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a directory of swcs generated from a + predicted segmentation. + + Returns + ------- + features : dict + Dictionary where each key-value pair corresponds to a type of feature + vector and the numerical vector. 
+ + """ + return {"skel": generate_branch_features(neurograph, branches)} + + +# -- Proposal Feature Extraction -- def generate_img_chunks(neurograph, proposals, img, labels): """ - Generates an image chunk for each edge proposal such that the centroid of - the image chunk is the midpoint of the edge proposal. Image chunks contain - two channels: raw image and predicted segmentation. + Generates an image chunk for each proposal such that the centroid of the + image chunk is the midpoint of the proposal. Image chunks contain two + channels: raw image and predicted segmentation. Parameters ---------- @@ -105,33 +129,33 @@ def generate_img_chunks(neurograph, proposals, img, labels): labels : tensorstore.TensorStore Predicted segmentation mask stored in a GCS bucket. proposals : list[frozenset], optional - List of edge proposals for which features will be generated. The + List of proposals for which features will be generated. The default is None. Returns ------- features : dict - Dictonary such that each pair is the edge id and image chunk. + Dictonary such that each pair is the proposal id and image chunk. """ with ThreadPoolExecutor() as executor: # Assign Threads threads = [None] * len(proposals) - for t, edge in enumerate(proposals): - xyz_0, xyz_1 = neurograph.proposal_xyz(edge) + for t, proposal in enumerate(proposals): + xyz_0, xyz_1 = neurograph.proposal_xyz(proposal) coord_0 = utils.to_img(xyz_0) coord_1 = utils.to_img(xyz_1) threads[t] = executor.submit( - get_img_chunk, img, labels, coord_0, coord_1, edge + get_img_chunk, img, labels, coord_0, coord_1, proposal ) # Save result chunks = dict() profiles = dict() for thread in as_completed(threads): - edge, chunk, profile = thread.result() - chunks[edge] = chunk - profiles[edge] = profile + proposal, chunk, profile = thread.result() + chunks[proposal] = chunk + profiles[proposal] = profile return chunks, profiles @@ -166,8 +190,8 @@ def get_img_chunk(img, labels, coord_0, coord_1, thread_id=None): def generate_img_profiles(neurograph, proposals, img): """ - Generates an image intensity profile along each edge proposal by reading - from an image on the cloud. + Generates an image intensity profile along each proposal by reading from + an image on the cloud. Parameters ---------- @@ -175,21 +199,21 @@ def generate_img_profiles(neurograph, proposals, img): NeuroGraph generated from a directory of swcs generated from a predicted segmentation. proposals : list[frozenset] - List of edge proposals for which features will be generated. + List of proposals for which features will be generated. img : tensorstore.TensorStore Image stored in a GCS bucket. Returns ------- features : dict - Dictonary such that each pair is the edge id and image intensity + Dictonary such that each pair is the proposal id and image intensity profile. """ # Generate coordinates coords = dict() - for i, edge in enumerate(proposals): - coords[edge] = get_profile_coords(neurograph, edge) + for i, proposal in enumerate(proposals): + coords[proposal] = get_profile_coords(neurograph, proposal) # Generate profiles img_profiles = dict() @@ -204,7 +228,7 @@ def generate_img_profiles(neurograph, proposals, img): return img_profiles -def get_profile_coords(neurograph, edge): +def get_profile_coords(neurograph, proposal): """ Gets coordinates needed to compute an image intensity profile. @@ -213,8 +237,8 @@ def get_profile_coords(neurograph, edge): neurograph : NeuroGarph NeuroGraph generated from a directory of swcs generated from a predicted segmentation. 
- edge : frozenset - Edge proposal that image intensity profile will be generated for. + proposal : frozenset + Proposal that image intensity profile will be generated for. Returns ------- @@ -223,7 +247,7 @@ def get_profile_coords(neurograph, edge): """ # Compute coordinates - xyz_0, xyz_1 = neurograph.proposal_xyz(edge) + xyz_0, xyz_1 = neurograph.proposal_xyz(proposal) coord_0 = utils.to_img(xyz_0) coord_1 = utils.to_img(xyz_1) @@ -237,70 +261,70 @@ def get_profile_coords(neurograph, edge): return coords -def get_profile(img, edge, coords): +def get_profile(img, proposal, coords): """ - Gets the image intensity profile for a given edge proposal. + Gets the image intensity profile for a given proposal. Parameters ---------- img : tensorstore.TensorStore Image to be queried. - edge : frozenset - Edge proposal that image profile corresponds to. + proposal : frozenset + Proposal that image profile corresponds to. Returns ------- - edge : frozenset - Edge proposal that image profile corresponds to. + proposal : frozenset + Proposal that image profile corresponds to. list[int] Image intensity profile. """ chunk = utils.read_tensorstore_bbox(img, coords["bbox"]) line = geometry.make_line(coords["start"], coords["end"], N_PROFILE_PTS) - return edge, [chunk[tuple(xyz)] for xyz in line] + return proposal, [chunk[tuple(xyz)] for xyz in line] def generate_skel_features(neurograph, proposals): features = dict() - for edge in proposals: - i, j = tuple(edge) - features[edge] = np.concatenate( + for proposal in proposals: + i, j = tuple(proposal) + features[proposal] = np.concatenate( ( - neurograph.proposal_length(edge), + neurograph.proposal_length(proposal), neurograph.degree[i], neurograph.degree[j], - get_radii(neurograph, edge), - get_avg_radii(neurograph, edge), - get_directionals(neurograph, edge, 8), - get_directionals(neurograph, edge, 16), - get_directionals(neurograph, edge, 32), - get_directionals(neurograph, edge, 64), + get_radii(neurograph, proposal), + get_avg_radii(neurograph, proposal), + get_directionals(neurograph, proposal, 8), + get_directionals(neurograph, proposal, 16), + get_directionals(neurograph, proposal, 32), + get_directionals(neurograph, proposal, 64), ), axis=None, ) return features -def get_directionals(neurograph, edge, window_size): +def get_directionals(neurograph, proposal, window_size): # Compute tangent vectors - i, j = tuple(edge) - edge_direction = geometry.compute_tangent( - neurograph.proposals[edge]["xyz"] + i, j = tuple(proposal) + proposal_direction = geometry.compute_tangent( + neurograph.proposals[proposal]["xyz"] ) - origin = neurograph.proposal_midpoint(edge) + origin = neurograph.proposal_midpoint(proposal) direction_i = geometry.get_directional(neurograph, i, origin, window_size) direction_j = geometry.get_directional(neurograph, j, origin, window_size) # Compute features - inner_product_1 = abs(np.dot(edge_direction, direction_i)) - inner_product_2 = abs(np.dot(edge_direction, direction_j)) + inner_product_1 = abs(np.dot(proposal_direction, direction_i)) + inner_product_2 = abs(np.dot(proposal_direction, direction_j)) inner_product_3 = np.dot(direction_i, direction_j) return np.array([inner_product_1, inner_product_2, inner_product_3]) -def get_avg_radii(neurograph, edge): - i, j = tuple(edge) +def get_avg_radii(neurograph, proposal): + i, j = tuple(proposal) radii_i = neurograph.get_branches(i, key="radius") radii_j = neurograph.get_branches(j, key="radius") return np.array([get_avg_radius(radii_i), get_avg_radius(radii_j)]) @@ -314,8 +338,8 @@ 
def get_avg_radius(radii_list):
     return avg
 
 
-def get_avg_branch_lens(neurograph, edge):
-    i, j = tuple(edge)
+def get_avg_branch_lens(neurograph, proposal):
+    i, j = tuple(proposal)
     branches_i = neurograph.get_branches(i, key="xyz")
     branches_j = neurograph.get_branches(j, key="xyz")
     return np.array([get_branch_len(branches_i), get_branch_len(branches_j)])
@@ -328,13 +352,55 @@ def get_branch_len(branch_list):
     return branch_len
 
 
-def get_radii(neurograph, edge):
-    i, j = tuple(edge)
+def get_radii(neurograph, proposal):
+    i, j = tuple(proposal)
     radius_i = neurograph.nodes[i]["radius"]
     radius_j = neurograph.nodes[j]["radius"]
     return np.array([radius_i, radius_j])
 
 
+def avg_branch_radii(neurograph, edge):
+    return np.array([np.mean(neurograph.edges[edge]["radius"])])
+
+
+# -- Edge Feature Generation --
+def generate_branch_features(neurograph, edges):
+    features = dict()
+    for (i, j) in edges:
+        edge = frozenset((i, j))
+        # Branch feature vector: number of points, mean radius, curvature
+        features[edge] = np.concatenate(
+            (
+                np.array([len(neurograph.edges[i, j]["xyz"])]),
+                avg_branch_radii(neurograph, edge),
+                compute_curvature(neurograph, edge),
+            ),
+            axis=None,
+        )
+    return features
+
+
+def compute_curvature(neurograph, edge):
+    kappa = curvature(neurograph.edges[edge]["xyz"])
+    n_pts = len(kappa)
+    if n_pts <= N_BRANCH_PTS:
+        sampled_kappa = np.zeros((N_BRANCH_PTS))
+        sampled_kappa[0:n_pts] = kappa
+    else:
+        idxs = np.linspace(0, n_pts - 1, N_BRANCH_PTS).astype(int)
+        sampled_kappa = kappa[idxs]
+    return np.array(sampled_kappa)
+
+
+def curvature(xyz_list):
+    a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1)
+    b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1)
+    c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1)
+    s = 0.5 * (a + b + c)
+    delta = np.sqrt(s * (s - a) * (s - b) * (s - c))
+    return 4 * delta / (a * b * c)
+
+
 # -- Build feature matrix
 def get_feature_matrix(neurographs, features, model_type, block_ids=None):
     assert model_type in SUPPORTED_MODELS, "Error! 
model_type not supported" @@ -350,27 +416,25 @@ def __multiblock_feature_matrix(neurographs, features, blocks, model_type): # Initialize X = None y = None - - block_to_idxs = dict() - idx_to_edge = dict() + idx_transforms = {"block_to_idxs": dict(), "idx_to_edge": dict()} # Feature extraction for block_id in blocks: if neurographs[block_id].n_proposals() == 0: - block_to_idxs[block_id] = set() + idx_transforms["block_to_idxs"][block_id] = set() continue idx_shift = 0 if X is None else X.shape[0] if model_type == "MultiModalNet": - X_i, x_i, y_i, idxs_i, idx_to_edge_i = get_multimodal_features( + X_i, x_i, y_i, idx_transforms_i = get_multimodal_features( neurographs[block_id], features[block_id], shift=idx_shift ) elif model_type == "ConvNet": - X_i, y_i, idxs_i, idx_to_edge_i = stack_img_chunks( + X_i, y_i, idx_transforms_i = stack_img_chunks( neurographs[block_id], features[block_id], shift=idx_shift ) else: - X_i, y_i, idxs_i, idx_to_edge_i = get_feature_vectors( + X_i, y_i, idx_transforms_i = get_feature_vectors( neurographs[block_id], features[block_id], shift=idx_shift ) @@ -387,13 +451,15 @@ def __multiblock_feature_matrix(neurographs, features, blocks, model_type): x = np.concatenate((x, x_i), axis=0) # Update dicts - block_to_idxs[block_id] = idxs_i - idx_to_edge.update(idx_to_edge_i) + idx_transforms["block_to_idxs"][block_id] = idx_transforms_i[ + "block_to_idxs" + ] + idx_transforms["idx_to_edge"].update(idx_transforms_i["idx_to_edge"]) if model_type == "MultiModalNet": X = {"imgs": X, "features": x} - return X, y, block_to_idxs, idx_to_edge + return X, y, idx_transforms def __feature_matrix(neurographs, features, model_type): @@ -409,18 +475,17 @@ def get_feature_vectors(neurograph, features, shift=0): # Initialize features = combine_features(features) key = sample(list(features.keys()), 1)[0] - X = np.zeros((neurograph.n_proposals(), len(features[key]))) - y = np.zeros((neurograph.n_proposals())) + X = np.zeros((len(features.keys()), len(features[key]))) + y = np.zeros((len(features.keys()))) + idx_transforms = {"block_to_idxs": set(), "idx_to_edge": dict()} # Build - idxs = set() - idx_to_edge = dict() for i, edge in enumerate(features.keys()): X[i, :] = features[edge] y[i] = 1 if edge in neurograph.target_edges else 0 - idxs.add(i + shift) - idx_to_edge[i + shift] = edge - return X, y, idxs, idx_to_edge + idx_transforms["block_to_idxs"].add(i + shift) + idx_transforms["idx_to_edge"][i + shift] = edge + return X, y, idx_transforms def get_multimodal_features(neurograph, features, shift=0): @@ -429,35 +494,33 @@ def get_multimodal_features(neurograph, features, shift=0): X = np.zeros(((n_edges, 2) + tuple(CHUNK_SIZE))) x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_PTS)) y = np.zeros((n_edges)) + idx_transforms = {"block_to_idxs": set(), "idx_to_edge": dict()} # Build - idxs = set() - idx_to_edge = dict() for i, edge in enumerate(features["img_chunks"].keys()): X[i, :] = features["img_chunks"][edge] x[i, :] = np.concatenate( (features["skel"][edge], features["img_profile"][edge]) ) y[i] = 1 if edge in neurograph.target_edges else 0 - idxs.add(i + shift) - idx_to_edge[i + shift] = edge - return X, x, y, idxs, idx_to_edge + idx_transforms["block_to_idxs"].add(i + shift) + idx_transforms["idx_to_edge"][i + shift] = edge + return X, x, y, idx_transforms def stack_img_chunks(neurograph, features, shift=0): # Initialize X = np.zeros(((neurograph.n_proposals(), 2) + tuple(CHUNK_SIZE))) y = np.zeros((neurograph.n_proposals())) + idx_transforms = {"block_to_idxs": set(), 
"idx_to_edge": dict()} # Build - idxs = set() - idx_to_edge = dict() for i, edge in enumerate(features["img_chunks"].keys()): X[i, :] = features["img_chunks"][edge] y[i] = 1 if edge in neurograph.target_edges else 0 - idxs.add(i + shift) - idx_to_edge[i + shift] = edge - return X, y, idxs, idx_to_edge + idx_transforms["block_to_idxs"].add(i + shift) + idx_transforms["idx_to_edge"][i + shift] = edge + return X, y, idx_transforms # -- Utils -- diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py new file mode 100644 index 0000000..521fc14 --- /dev/null +++ b/src/deep_neurographs/machine_learning/graph_datasets.py @@ -0,0 +1,195 @@ +""" +Created on Sat April 11 15:30:00 2023 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Custom datasets for training graph neural networks. + +""" + +import networkx as nx +import numpy as np +import torch +from torch.utils.data import Dataset +from torch_geometric.data import Data as GraphData +from torch_geometric.data import HeteroData as HeteroGraphData + +from deep_neurographs.machine_learning import feature_generation + + +# Wrapper +def init(neurograph, branch_features, proposal_features, heterogeneous=False): + """ + Initializes a dataset that can be used to train a graph neural network. + + Parameters + ---------- + + """ + # Extract features + x_branches, _, idxs_branches = feature_generation.get_feature_matrix( + neurograph, branch_features, "GraphNeuralNet" + ) + x_proposals, y_proposals, idxs_proposals = feature_generation.get_feature_matrix( + neurograph, proposal_features, "GraphNeuralNet" + ) + + # Initialize data + if heterogeneous: + data, idxs_branches, idxs_proposals = HeteroGraphDataset( + neurograph, x_branches, x_proposals, idxs_branches, idxs_proposals + ) + else: + graph_dataset = GraphDataset( + neurograph, x_branches, x_proposals, idxs_branches, idxs_proposals + ) + + # Store dataset + dataset = { + "dataset": graph_dataset, + "idxs_branches": idxs_branches, + "idxs_proposals": idxs_proposals, + } + return dataset + + +# Datasets +class GraphDataset: + def __init__( + self, + neurograph, + x_branches, + x_proposals, + idxs_branches, + idxs_proposals, + ): + # Combine feature matrices + x = torch.tensor(np.vstack([x_proposals, x_branches])) + idxs_branches = upd_idxs(idxs_branches, x_proposals.shape[0]) + self.idxs_branches = add_edge_to_idx(idxs_branches) + self.idxs_proposals = add_edge_to_idx(idxs_proposals) + + # Initialize data + edge_index = init_edge_index(neurograph, idxs_branches, idxs_proposals) + self.data = GraphData(x=x, edge_index=edge_index) + + +class HeteroGraphDataset: + def __init__( + self, + neurograph, + x_branches, + x_proposals, + y_proposals, + idxs_branches, + idxs_proposals, + ): + # Update idxs + idxs_branches = add_edge_to_idx(idxs_branches) + idxs_proposals = add_edge_to_idx(idxs_proposals) + + # Init dataset + data = HeteroGraphData() + data["branch"].x = x_branches + data["proposal"].x = x_proposals + data["proposal", "to", "proposal"] = None + data["proposal", "to", "branch"] = None + data["branch", "to", "branch"] = None + + +# -- utils -- +def upd_idxs(idxs, shift): + """ + Updates index transform dictionary "idxs" by shifting each index by + "shift". + + idxs : dict + ... + shift : int + ... + + Returns + ------- + idxs : dict + Updated index transform dictinoary. 
+ + """ + idxs["block_to_idxs"] = upd_set(idxs["block_to_idxs"], shift) + idxs["idx_to_edge"] = upd_dict(idxs["idx_to_edge"], shift) + return idxs + + +def upd_set(my_set, shift): + shifted_set = set() + for element in my_set: + shifted_set.add(element + shift) + return shifted_set + + +def upd_dict(my_dict, shift): + shifted_dict = dict() + for key, value in my_dict.items(): + shifted_dict[key + shift] = value + return shifted_dict + + +def add_edge_to_idx(idxs): + idxs["edge_to_idx"] = dict() + for idx, edge in idxs["idx_to_edge"].items(): + idxs["edge_to_idx"][edge] = idx + return idxs + + +def init_edge_index(neurograph, idxs_branches, idxs_proposals): + # Initializations + branches_line_graph = nx.line_graph(neurograph) + proposals_line_graph = init_proposals_line_graph(neurograph) + + # Compute edges + edge_index = branch_to_branch(branches_line_graph, idxs_branches) + edge_index.extend( + proposal_to_proposal(proposals_line_graph, idxs_proposals) + ) + edge_index.extend( + branch_to_proposal(neurograph, idxs_branches, idxs_proposals) + ) + return edge_index + + +def init_proposals_line_graph(neurograph): + proposals_graph = nx.Graph() + proposals_graph.add_edges_from(list(neurograph.proposals.keys())) + return nx.line_graph(proposals_graph) + + +def branch_to_branch(branches_line_graph, idxs_branches): + edge_index = [] + for e1, e2 in branches_line_graph.edges: + v1 = idxs_branches["edge_to_idx"][frozenset(e1)] + v2 = idxs_branches["edge_to_idx"][frozenset(e2)] + edge_index.extend([[v1, v2], [v2, v1]]) + return edge_index + + +def proposal_to_proposal(proposals_line_graph, idxs_proposals): + edge_index = [] + for e1, e2 in proposals_line_graph.edges: + v1 = idxs_proposals["edge_to_idx"][frozenset(e1)] + v2 = idxs_proposals["edge_to_idx"][frozenset(e2)] + edge_index.extend([[v1, v2], [v2, v1]]) + return edge_index + + +def branch_to_proposal(neurograph, idxs_branches, idxs_proposals): + edge_index = [] + for e in neurograph.proposals.keys(): + i, j = tuple(e) + v1 = idxs_proposals["edge_to_idx"][frozenset(e)] + for k in neurograph.neighbors(i): + v2 = idxs_branches["edge_to_idx"][frozenset((i, k))] + edge_index.extend([[v1, v2], [v2, v1]]) + for k in neurograph.neighbors(j): + v2 = idxs_branches["edge_to_idx"][frozenset((j, k))] + edge_index.extend([[v1, v2], [v2, v1]]) + return edge_index diff --git a/src/deep_neurographs/machine_learning/ml_utils.py b/src/deep_neurographs/machine_learning/ml_utils.py index 2f13326..f76d5e5 100644 --- a/src/deep_neurographs/machine_learning/ml_utils.py +++ b/src/deep_neurographs/machine_learning/ml_utils.py @@ -33,6 +33,7 @@ "FeedForwardNet", "ConvNet", "MultiModalNet", + "GraphNeuralNet", ] @@ -153,7 +154,7 @@ def init_dataset( neurographs, features, model_type, block_ids=None, transform=False ): # Extract features - inputs, targets, block_to_idx, idx_to_edge = feature_generation.get_feature_matrix( + inputs, targets, idx_transforms = feature_generation.get_feature_matrix( neurographs, features, model_type, block_ids=block_ids ) lens = [] @@ -163,14 +164,7 @@ def init_dataset( dataset = { "dataset": get_dataset(inputs, targets, model_type, transform, lens), - "block_to_idxs": block_to_idx, - "idx_to_edge": idx_to_edge, + "block_to_idxs": idx_transforms["block_to_idxs"], + "idx_to_edge": idx_transforms["idx_to_edge"], } return dataset - - -def get_lengths(neurograph): - lengths = [] - for edge in neurograph.proposals.keys(): - lengths.append(neurograph.proposal_length(edge)) - return lengths diff --git a/src/deep_neurographs/reconstruction.py 
b/src/deep_neurographs/reconstruction.py index a2c7c11..c9ff1e3 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -231,12 +231,8 @@ def save_prediction(neurograph, accepted_proposals, output_dir): utils.mkdir(corrections_dir, delete=True) connections_path = os.path.join(output_dir, "connections.txt") - save_prediction( - neurograph, accepted_proposals, output_dir - ) - utils.save_connection( - neurograph, accepted_proposals, connections_path - ) + save_prediction(neurograph, accepted_proposals, output_dir) + utils.save_connection(neurograph, accepted_proposals, connections_path) # Write Result neurograph.to_swc(output_dir)
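
Note on the new branch features: compute_curvature samples the pointwise curvature of a branch at N_BRANCH_PTS = 50 positions, so each branch feature vector ends up with 1 + 1 + 50 = 52 entries (point count, mean radius, curvature samples). The curvature helper is the Menger (circumradius) curvature of each consecutive point triple, 4 * area / (a * b * c), with the triangle area from Heron's formula. A standalone sanity check of that formula, using hypothetical test points rather than anything from a NeuroGraph:

import numpy as np

def curvature(xyz_list):
    # Menger curvature of each consecutive triple: 4 * area / (a * b * c),
    # which equals 1 / circumradius of the triple.
    a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1)
    b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1)
    c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1)
    s = 0.5 * (a + b + c)
    delta = np.sqrt(s * (s - a) * (s - b) * (s - c))
    return 4 * delta / (a * b * c)

# Points sampled on a circle of radius 5 should all have curvature 1 / 5.
theta = np.linspace(0, np.pi, 20)
xyz = np.column_stack([5 * np.cos(theta), 5 * np.sin(theta), np.zeros(20)])
assert np.allclose(curvature(xyz), 0.2)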
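
The edge_index construction in graph_datasets.py treats both branches and proposals as nodes of the GNN graph: nx.line_graph turns each branch (an edge of the NeuroGraph) into a node and connects branches that share an endpoint, init_proposals_line_graph does the same for proposals, and branch_to_proposal links each proposal to the branches incident to its two endpoints. A toy illustration of the line-graph step on a hypothetical path graph:

import networkx as nx

# The path 0-1-2 has edges (0, 1) and (1, 2); in the line graph those edges
# become nodes and are adjacent because they share node 1.
graph = nx.path_graph(3)
line_graph = nx.line_graph(graph)
print(list(line_graph.nodes))  # e.g. [(0, 1), (1, 2)]
print(list(line_graph.edges))  # e.g. [((0, 1), (1, 2))]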
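
init_edge_index returns a plain Python list of [v1, v2] pairs, while PyTorch Geometric's Data object expects edge_index as a LongTensor of shape [2, num_edges]. A small conversion along these lines (not part of the patch, just a sketch of one way to wire it up) may be needed before GraphData(x=x, edge_index=...) is consumed by a model:

import torch

def to_edge_index_tensor(edge_list):
    # [[v1, v2], ...] -> LongTensor of shape [2, num_edges], the layout
    # torch_geometric uses for message passing.
    return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

# e.g. inside GraphDataset.__init__:
# edge_index = init_edge_index(neurograph, idxs_branches, idxs_proposals)
# self.data = GraphData(x=x, edge_index=to_edge_index_tensor(edge_index))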
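
Putting the pieces together, the intended flow appears to be: generate proposal features with run_on_proposals, branch features with run_on_branches, then build the GNN dataset with graph_datasets.init, whose return value exposes the torch_geometric Data object under dataset["dataset"].data. A minimal sketch of that wiring; the exact arguments (in particular passing neurograph.edges as the branches and leaving heterogeneous=False) are assumptions, not something this patch pins down:

from deep_neurographs.machine_learning import feature_generation, graph_datasets

def build_gnn_dataset(neurograph, img_path):
    # Proposal-level features; exact contents for "GraphNeuralNet" depend on
    # model_type handling inside run_on_proposals, which this diff elides.
    proposal_features = feature_generation.run_on_proposals(
        neurograph, "GraphNeuralNet", img_path,
        proposals=list(neurograph.proposals.keys()),
    )
    # Branch-level features: point count, mean radius, curvature samples.
    branch_features = feature_generation.run_on_branches(
        neurograph, list(neurograph.edges)
    )
    return graph_datasets.init(neurograph, branch_features, proposal_features)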