From 4f04abcb4b883071799c7e3c680e3b33468cc065 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Sun, 29 Sep 2024 17:14:14 -0700 Subject: [PATCH] Refactor feature generation (#255) * refactor removed old gnn option * refactor: improved heterognn feature generation --------- Co-authored-by: anna-grim --- src/deep_neurographs/inference.py | 98 ++++----- .../machine_learning/datasets.py | 4 +- .../machine_learning/feature_generation.py | 54 +---- .../feature_generation_graphs.py | 10 +- .../machine_learning/heterograph_datasets.py | 8 +- .../machine_learning/heterograph_models.py | 7 +- src/deep_neurographs/neurograph.py | 30 ++- src/deep_neurographs/utils/gnn_util.py | 192 ++++++++---------- src/deep_neurographs/utils/graph_util.py | 26 ++- src/deep_neurographs/utils/ml_util.py | 2 +- 10 files changed, 199 insertions(+), 232 deletions(-) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 6cd3035..de94666 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -12,7 +12,6 @@ from datetime import datetime from time import time from torch.nn.functional import sigmoid -from torch.utils.data import DataLoader from tqdm import tqdm import networkx as nx @@ -66,6 +65,7 @@ def __init__( model_path, output_dir, config, + device=None, ): """ Initializes an object that executes the full GraphTrace inference @@ -88,6 +88,8 @@ def __init__( config : Config Configuration object containing parameters and settings required for the inference pipeline. + device : str, optional + ... Returns ------- @@ -105,6 +107,17 @@ def __init__( self.graph_config = config.graph_config self.ml_config = config.ml_config + # Inference engine + self.inference_engine = InferenceEngine( + self.img_path, + self.model_path, + self.ml_config.model_type, + self.graph_config.search_radius, + confidence_threshold=self.ml_config.threshold, + device=device, + downsample_factor=self.ml_config.downsample_factor, + ) + # Set output directory date = datetime.today().strftime("%Y-%m-%d") self.output_dir = f"{output_dir}/{segmentation_id}-{date}" @@ -127,11 +140,7 @@ def run(self, fragments_pointer): """ # Initializations - print("\nExperiment Details") - print("-----------------------------------------------") - print("Sample_ID:", self.sample_id) - print("Segmentation_ID:", self.segmentation_id) - print("") + self.report_experiment() self.write_metadata() t0 = time() @@ -145,15 +154,8 @@ def run(self, fragments_pointer): print(f"Total Runtime: {round(t, 4)} {unit}\n") def run_schedule(self, fragments_pointer, search_radius_schedule): - # Initializations - print("\nExperiment Details") - print("-----------------------------------------------") - print("Sample_ID:", self.sample_id) - print("Segmentation_ID:", self.segmentation_id) - print("") t0 = time() - - # Main + self.report_experiment() self.build_graph(fragments_pointer) for round_id, search_radius in enumerate(search_radius_schedule): print(f"--- Round {round_id + 1}: Radius = {search_radius} ---") @@ -258,15 +260,7 @@ def run_inference(self): print("(3) Run Inference") t0 = time() n_proposals = self.graph.n_proposals() - inference_engine = InferenceEngine( - self.img_path, - self.model_path, - self.ml_config.model_type, - self.graph_config.search_radius, - confidence_threshold=self.ml_config.threshold, - downsample_factor=self.ml_config.downsample_factor, - ) - self.graph, accepts = inference_engine.run( + self.graph, accepts = self.inference_engine.run( self.graph, self.graph.list_proposals() ) self.accepted_proposals.extend(accepts) @@ -297,6 +291,13 @@ def save_results(self, round_id=None): self.save_connections(round_id=round_id) self.write_metadata() + def report_experiment(self): + print("\nExperiment Overview") + print("-----------------------------------------------") + print("Sample_ID:", self.sample_id) + print("Segmentation_ID:", self.segmentation_id) + print("") + # --- io --- def save_connections(self, round_id=None): """ @@ -390,6 +391,7 @@ def __init__( search_radius, batch_size=BATCH_SIZE, confidence_threshold=CONFIDENCE_THRESHOLD, + device=None, downsample_factor=1, ): """ @@ -424,6 +426,7 @@ def __init__( # Set class attributes self.batch_size = batch_size self.downsample_factor = downsample_factor + self.device = "cpu" if device is None else device self.is_gnn = True if "Graph" in model_type else False self.model_type = model_type self.search_radius = search_radius @@ -433,6 +436,9 @@ def __init__( driver = "n5" if ".n5" in img_path else "zarr" self.img = img_util.open_tensorstore(img_path, driver=driver) self.model = ml_util.load_model(model_path) + if self.is_gnn: + self.model = self.model.to(self.device) + self.model.eval() def run(self, neurograph, proposals): """ @@ -470,7 +476,7 @@ def run(self, neurograph, proposals): # Predict batch = self.get_batch(neurograph, proposals) dataset = self.get_batch_dataset(neurograph, batch) - preds = self.run_model(dataset) + preds = self.predict(dataset) # Update graph batch_accepts = get_accepted_proposals( @@ -547,7 +553,7 @@ def get_batch_dataset(self, neurograph, batch): ) return dataset - def run_model(self, dataset): + def predict(self, dataset): """ Runs the model on the given dataset to generate and filter predictions. @@ -561,47 +567,29 @@ def run_model(self, dataset): ------- dict A dictionary that maps a proposal to the model's prediction (i.e. - probability). Note that this dictionary only contains proposals - whose predicted probability is greater the threshold. + probability). """ # Get predictions if self.model_type == "GraphNeuralNet": - preds = run_gnn_model(dataset.data, self.model) - elif "Net" in self.model_type: - preds = run_nn_model(dataset.data, self.model) + with torch.no_grad(): + # Get inputs + n = len(dataset.data["proposal"]["y"]) + x, edge_index, edge_attr = gnn_util.get_inputs( + dataset.data, device=self.device + ) + + # Run model + preds = sigmoid(self.model(x, edge_index, edge_attr)) + preds = toCPU(preds[0:n, 0]) else: preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1]) - # Filter preds + # Reformat prediction idxs = dataset.idxs_proposals["idx_to_edge"] return {idxs[i]: p for i, p in enumerate(preds)} -# --- run machine learning model --- -def run_nn_model(data, model): - hat_y = list() - model.eval() - with torch.no_grad(): - for batch in DataLoader(data, batch_size=32): - # Run model - hat_y_i = sigmoid(model(batch["inputs"])) - - # Postprocess - hat_y_i = np.array(hat_y_i) - hat_y.extend(hat_y_i[:, 0].tolist()) - return np.array(hat_y) - - -def run_gnn_model(data, model): - model.eval() - with torch.no_grad(): - x, edge_index, edge_attr = gnn_util.get_inputs(data) - hat_y = sigmoid(model(x, edge_index, edge_attr)) - idx = len(data["proposal"]["y"]) - return toCPU(hat_y[0:idx, 0]) - - # --- Accepting Proposals --- def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9): """ diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index 9e05a1a..5af5349 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -16,7 +16,7 @@ # Wrapper -def init(neurograph, features, model_type, sample_ids=None): +def init(neurograph, features, sample_ids=None): """ Initializes a dataset that can be used to train a machine learning model. @@ -41,7 +41,7 @@ def init(neurograph, features, model_type, sample_ids=None): """ # Extract features x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix( - neurograph, features["proposals"], model_type, sample_ids=sample_ids + neurograph, features["proposals"], sample_ids=sample_ids ) # Initialize dataset diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 48cd297..689188a 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -14,6 +14,7 @@ """ +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy from random import sample @@ -29,7 +30,7 @@ CHUNK_SIZE = [64, 64, 64] N_BRANCH_PTS = 50 -N_PROFILE_PTS = 16 +N_PROFILE_PTS = 16 # 10 N_SKEL_FEATURES = 22 @@ -115,55 +116,18 @@ def generate_features( Feature vectors. """ - features = { - "proposals": run_on_proposals( - neurograph, - img, - proposals_dict["proposals"], - radius, - downsample_factor, - ) - } - return features - - -# -- feature generation by graphical structure type -- -def run_on_proposals(neurograph, img, proposals, radius, downsample_factor): - """ - Generates feature vectors for a set of proposals in a neurograph. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - proposals : list[frozenset] - List of proposals for which features will be generated. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal and profiles) - and values are a dictionary that maps a proposal id to the - corresponding feature vector. - - """ - proposal_features = { - "skel": proposal_skeletal(neurograph, proposals, radius), + features = defaultdict(bool) + features["proposals"] = { + "skel": proposal_skeletal( + neurograph, proposals_dict["proposals"], radius + ), "profiles": proposal_profiles( - neurograph, img, proposals, downsample_factor + neurograph, img, proposals_dict["proposals"], downsample_factor ), } - return proposal_features + return features -# -- part 1: proposal feature generation -- def proposal_profiles(neurograph, img, proposals, downsample_factor): """ Generates an image intensity profile along each proposal by reading from diff --git a/src/deep_neurographs/machine_learning/feature_generation_graphs.py b/src/deep_neurographs/machine_learning/feature_generation_graphs.py index bb3bf86..da3f3b5 100644 --- a/src/deep_neurographs/machine_learning/feature_generation_graphs.py +++ b/src/deep_neurographs/machine_learning/feature_generation_graphs.py @@ -195,12 +195,11 @@ def edge_skeletal(neurograph, computation_graph): """ edge_skeletal_features = dict() for edge in neurograph.edges: - edge_skeletal_features[frozenset(edge)] = np.concatenate( - ( + edge_skeletal_features[frozenset(edge)] = np.array( + [ np.mean(neurograph.edges[edge]["radius"]), - neurograph.edge_length(edge) / 1000, - ), - axis=None, + neurograph.edges[edge]["length"] / 1000, + ], ) return edge_skeletal_features @@ -226,7 +225,6 @@ def proposal_skeletal(neurograph, proposals, radius): """ proposal_skeletal_features = dict() for proposal in proposals: - i, j = tuple(proposal) proposal_skeletal_features[proposal] = np.concatenate( ( neurograph.proposal_length(proposal), diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 8d99c31..a18994b 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -206,7 +206,7 @@ def check_missing_edge_type(self): # Update edge_index n = self.data["branch"]["x"].size(0) edge_index = [[n - 1, n - 2], [n - 2, n - 1]] - self.data[edge_type].edge_index = gnn_util.to_tensor(edge_index) + self.data[edge_type].edge_index = gnn_util.toTensor(edge_index) self.idxs_branches["idx_to_edge"][n - 1] = frozenset({-1, -2}) self.idxs_branches["idx_to_edge"][n - 2] = frozenset({-2, -3}) @@ -282,7 +282,7 @@ def proposal_to_proposal(self): v1 = self.idxs_proposals["edge_to_idx"][frozenset(e1)] v2 = self.idxs_proposals["edge_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) - return gnn_util.to_tensor(edge_index) + return gnn_util.toTensor(edge_index) def branch_to_branch(self): """ @@ -308,7 +308,7 @@ def branch_to_branch(self): v1 = self.idxs_branches["edge_to_idx"][frozenset(e1)] v2 = self.idxs_branches["edge_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) - return gnn_util.to_tensor(edge_index) + return gnn_util.toTensor(edge_index) def branch_to_proposal(self): """ @@ -338,7 +338,7 @@ def branch_to_proposal(self): if frozenset((j, k)) not in self.proposals: v2 = self.idxs_branches["edge_to_idx"][frozenset((j, k))] edge_index.extend([[v2, v1]]) - return gnn_util.to_tensor(edge_index) + return gnn_util.toTensor(edge_index) # Set Edge Attributes def set_edge_attrs(self, x_nodes, edge_type, idx_map): diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index fd9794e..7047af7 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -32,6 +32,7 @@ class HeteroGNN(torch.nn.Module): def __init__( self, + device=None, scale_hidden_dim=2, dropout=DROPOUT, heads_1=HEADS_1, @@ -50,12 +51,12 @@ def __init__( # Linear layers output_dim = heads_1 * heads_2 * hidden_dim self.input_nodes = nn.ModuleDict( - {key: nn.Linear(d, hidden_dim) for key, d in node_dict.items()} + {key: nn.Linear(d, hidden_dim, device=device) for key, d in node_dict.items()} ) self.input_edges = { - key: nn.Linear(d, hidden_dim) for key, d in edge_dict.items() + key: nn.Linear(d, hidden_dim, device=device) for key, d in edge_dict.items() } - self.output = Linear(output_dim, 1) + self.output = Linear(output_dim, 1, device=device) # Convolutional layers self.conv1 = HeteroConv( diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d01cc3c..7050e67 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -136,7 +136,8 @@ def add_component(self, irreducibles): for (i, j), attrs in irreducibles["edges"].items(): edge = (ids[i], ids[j]) idxs = util.spaced_idxs(attrs["radius"], self.node_spacing) - attrs = {key: value[idxs] for key, value in attrs.items()} + for key in ["radius", "xyz"]: + attrs[key] = attrs[key][idxs] self.__add_edge(edge, attrs, swc_id) def __add_nodes(self, irreducibles, node_type, node_ids): @@ -198,7 +199,12 @@ def __add_edge(self, edge, attrs, swc_id): """ i, j = tuple(edge) self.add_edge( - i, j, radius=attrs["radius"], xyz=attrs["xyz"], swc_id=swc_id, + i, + j, + length=attrs["length"], + radius=attrs["radius"], + xyz=attrs["xyz"], + swc_id=swc_id, ) self.xyz_to_edge.update({tuple(xyz): edge for xyz in attrs["xyz"]}) @@ -642,12 +648,28 @@ def proposal_directionals(self, proposal, window): def merge_proposal(self, proposal): i, j = tuple(proposal) - if not (self.is_soma(i) and self.is_soma(j)): - # Attributes + somas_check = not (self.is_soma(i) and self.is_soma(j)) + degrees_check = self.degree[i] == 2 and self.degree[j] == 2 + if somas_check and degrees_check: + # Dense attributes attrs = dict() for k in ["xyz", "radius"]: combine = np.vstack if k == "xyz" else np.array attrs[k] = combine([self.nodes[i][k], self.nodes[j][k]]) + + # Sparse attributes + if self.degree[i] == 1 and self.degree[j] == 1: + e_i = (i, self.leaf_neighbor(i)) + e_j = (j, self.leaf_neighbor(j)) + len_ij = self.edges[e_i]["length"] + self.edges[e_j]["length"] + attrs["length"] = len_ij + elif self.degree[i] == 2: + e_j = (j, self.leaf_neighbor(j)) + attrs["length"] = self.edges[e_i]["length"] + else: + e_i = (i, self.leaf_neighbor(i)) + attrs["length"] = self.edges[e_j]["length"] + swc_id_i = self.nodes[i]["swc_id"] swc_id_j = self.nodes[j]["swc_id"] swc_id = swc_id_i if self.is_soma(i) else swc_id_j diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index c89d006..7e845c5 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -17,74 +17,83 @@ from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import util +GNN_DEPTH = 2 -def get_inputs(data): + +# --- Tensor Operations --- +def get_inputs(data, device=None): + # Extract inputs x = data.x_dict edge_index = data.edge_index_dict - edge_attr_dict = data.edge_attr_dict - return x, edge_index, edge_attr_dict + edge_attr = data.edge_attr_dict + # Move to gpu (if applicable) + if "cuda" in device and torch.cuda.is_available(): + x = toGPU(x, device) + edge_index = toGPU(edge_index, device) + edge_attr = toGPU(edge_attr, device) + return x, edge_index, edge_attr -def toCPU(tensor): + +def toGPU(tensor_dict, device): """ - Moves tensor from GPU to CPU. + Moves dictionary of tensors from CPU to GPU. Parameters ---------- - tensor : torch.Tensor - Tensor. + tensor_dict : dict + Tensor to be moved to GPU. Returns ------- None """ - return tensor.detach().cpu().tolist() + return {k: tensor.to("cuda") for k, tensor in tensor_dict.items()} -def to_tensor(my_list): +def toCPU(tensor): """ - Converts a list to a tensor with contiguous memory. + Moves tensor from GPU to CPU. Parameters ---------- - my_list : list - List to be converted into a tensor. + tensor : torch.Tensor + Tensor to be moved to CPU. Returns ------- - torch.Tensor - Tensor. + None """ - arr = np.array(my_list, dtype=np.int64).tolist() - return torch.Tensor(arr).t().contiguous().long() + return tensor.detach().cpu().tolist() -def init_line_graph(edges): +def toTensor(my_list): """ - Initializes a line graph from a list of edges. + Converts a list to a tensor with contiguous memory. Parameters ---------- - edges : list - List of edges. + my_list : list + List to be converted into a tensor. Returns ------- - networkx.Graph - Line graph generated from a list of edges. + torch.Tensor + Tensor. """ - graph = nx.Graph() - graph.add_edges_from(edges) - return nx.line_graph(graph) + arr = np.array(my_list, dtype=np.int64).tolist() + return torch.Tensor(arr).t().contiguous().long() +# --- Batch Generation --- def get_batch(graph, proposals, batch_size): """ Gets a batch for training or inference that consist of a computation graph - and list of proposals. + and list of proposals. Note: queue contains tuples that consist of a node + id and distance from proposal. Parameters ---------- @@ -93,7 +102,7 @@ def get_batch(graph, proposals, batch_size): proposals : list Proposals to be classified as accept or reject. batch_size : int - Maximum number of nodes in the computation graph. + Maximum number of proposals in the computation graph. Returns ------- @@ -106,10 +115,10 @@ def get_batch(graph, proposals, batch_size): visited = set() while len(proposals) > 0 and len(batch["proposals"]) < batch_size: root = tuple(util.sample_once(proposals)) - queue = [root[0], root[1]] + queue = [(root[0], 0), (root[1], 0)] while len(queue) > 0: # Visit node - i = queue.pop() + i, d = queue.pop() for j in graph.neighbors(i): if (i, j) not in batch["graph"].edges: batch["graph"].add_edge(i, j) @@ -119,61 +128,18 @@ def get_batch(graph, proposals, batch_size): batch["graph"].add_edge(i, p) batch["proposals"].add(frozenset({i, p})) proposals.remove(frozenset({i, p})) - queue.append(p) + queue.append((p, 0)) visited.add(i) # Update queue if len(batch["proposals"]) < batch_size: - for j in graph.neighbors(i): - if validate_node_for_queue(graph, proposals, visited, j): - queue.append(j) + for j in [j for j in graph.neighbors(i) if j not in visited]: + d_j = min(d + 1, -len(graph.nodes[j]["proposals"])) + if d_j <= GNN_DEPTH: + queue.append((j, d + 1)) return batch -def validate_node_for_queue(graph, proposals, visited, j): - """ - Check whether node is within 3 hops from a proposal to be classified. If - so, then the node should be added to the queue in the routine "get_batch". - - Parameters - ---------- - graph : NeuroGraph - Graph that contains proposals - proposals : list - Proposals to be classified as accept or reject. - visited : set - Nodes that have been visited and added the current batch's computation - graph. - j : int - Node to be validated for being added to queue. - - Returns - ------- - bool - Indication of whether node should be added to queue. - - """ - hit_proposal = False - validate_queue = [(j, 0)] # node id, distance from j - validate_visited = set() - while len(validate_queue) > 0: - # Check if node has proposals to be classified - k, d = validate_queue.pop() - for p in graph.nodes[k]["proposals"]: - if frozenset({p, k}) in proposals: - hit_proposal = True - validate_queue = list() - break - validate_visited.add(k) - - # Update queue - if d < 3 and not hit_proposal: - for kk in graph.neighbors(k): - if kk not in visited and kk not in validate_visited: - validate_queue.append((kk, d + 1)) - return True if hit_proposal else False - - def get_train_batch(graph, proposals, batch_size): """ Gets a batch for training or inference that consist of a computation graph @@ -239,31 +205,6 @@ def get_train_batch(graph, proposals, batch_size): yield batch -def reset_batch(): - """ - Resets the current batch. - - Parameters - ---------- - None - - Returns - ------- - dict - Batch that consists of a graph and list of proposals. - - """ - return {"graph": nx.Graph(), "proposals": set()} - - -def get_node_proposal_cnt(proposals): - node_proposal_cnt = defaultdict(int) - for i, j in proposals: - node_proposal_cnt[i] += 1 - node_proposal_cnt[j] += 1 - return node_proposal_cnt - - def extract_subgraph_batch( graph, proposals, batch, batch_size, node_proposal_cnt ): @@ -293,12 +234,29 @@ def extract_subgraph_batch( if len(batch["proposals"]) >= batch_size: break - print("# proposals added in extract_subgraph:", n_proposals_added) # Yield batch graph.remove_nodes_from(remove_nodes) return batch +def reset_batch(): + """ + Resets the current batch. + + Parameters + ---------- + None + + Returns + ------- + dict + Batch that consists of a graph and list of proposals. + + """ + return {"graph": nx.Graph(), "proposals": set()} + + +# --- Miscellaneous --- def proposals_in_graph(graph, proposals): """ Lists the proposals that are edges in "graph". @@ -316,3 +274,31 @@ def proposals_in_graph(graph, proposals): Proposals that are edges in "graph". """ return set([e for e in map(frozenset, graph.edges) if e in proposals]) + + +def get_node_proposal_cnt(proposals): + node_proposal_cnt = defaultdict(int) + for i, j in proposals: + node_proposal_cnt[i] += 1 + node_proposal_cnt[j] += 1 + return node_proposal_cnt + + +def init_line_graph(edges): + """ + Initializes a line graph from a list of edges. + + Parameters + ---------- + edges : list + List of edges. + + Returns + ------- + networkx.Graph + Line graph generated from a list of edges. + + """ + graph = nx.Graph() + graph.add_edges_from(edges) + return nx.line_graph(graph) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index fa2b8ce..70cf1b9 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -365,25 +365,32 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): """ # Extract nodes leafs, junctions = get_irreducible_nodes(graph) - assert len(leafs), "No leaf nodes!" + assert len(leafs) > 0, "No leaf nodes!" if len(leafs) > 0: - source = sample(leafs, 1)[0] + source = util.sample_once(leafs) else: - source = sample(junctions, 1)[0] + source = util.sample_once(junctions) # Extract edges edges = dict() nbs = defaultdict(list) root = None for (i, j) in nx.dfs_edges(graph, source=source): - # Check if start of path is valid + # Check if starting new or continuing current path if root is None: root = i + cur_length = 0 attrs = init_edge_attrs(swc_dict, root) + else: + xyz_i = swc_dict["xyz"][swc_dict["idx"][i]] + xyz_j = swc_dict["xyz"][swc_dict["idx"][j]] + cur_length += geometry.dist(xyz_i, xyz_j) # Visit j attrs = upd_edge_attrs(swc_dict, attrs, j) if j in leafs or j in junctions: + # Check whether to smooth + attrs["length"] = cur_length attrs = to_numpy(attrs) if smooth_bool: swc_dict, edges = __smooth_branch( @@ -391,8 +398,10 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): ) else: edges[(root, j)] = attrs - nbs[root].append(j) # = util.append_dict_value(nbs, root, j) - nbs[j].append(root) # = util.append_dict_value(nbs, j, root) + + # Finish + nbs[root].append(j) + nbs[j].append(root) root = None # Output @@ -638,7 +647,8 @@ def upd_endpoint_xyz(edges, key, old_xyz, new_xyz): def init_edge_attrs(swc_dict, i): """ Initializes edge attribute dictionary with attributes from node "i" which - is an end point of the edge. + is an end point of the edge. Note: the following assertion error may be + useful: assert i in swc_dict["idx"].keys(), f"{swc_dict["swc_id"]} - {i}" Parameters ---------- @@ -654,8 +664,6 @@ def init_edge_attrs(swc_dict, i): Edge attribute dictionary. """ - swc_id = swc_dict["swc_id"] - assert i in swc_dict["idx"].keys(), f"{swc_id} - {i}" j = swc_dict["idx"][i] return {"radius": [swc_dict["radius"][j]], "xyz": [swc_dict["xyz"][j]]} diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index e0aedd0..90ebb71 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -104,7 +104,7 @@ def init_dataset( ) else: dataset = datasets.init( - neurograph, features, model_type, sample_ids=sample_ids + neurograph, features, sample_ids=sample_ids ) return dataset