From 21cad5bfc7a00515a5371126f79c57d4bc303fe2 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 12 Dec 2023 22:48:55 +0000 Subject: [PATCH] add feature : optimize edge alignment --- src/deep_neurographs/evaluation.py | 2 +- src/deep_neurographs/feature_extraction.py | 11 +- src/deep_neurographs/geometry_utils.py | 249 +++++++++++++++++++-- src/deep_neurographs/intake.py | 16 +- src/deep_neurographs/neurograph.py | 93 +++----- src/deep_neurographs/swc_utils.py | 13 +- src/deep_neurographs/utils.py | 14 +- src/deep_neurographs/visualization.py | 65 ++++-- 8 files changed, 341 insertions(+), 122 deletions(-) diff --git a/src/deep_neurographs/evaluation.py b/src/deep_neurographs/evaluation.py index 44c78a6..04b49e5 100644 --- a/src/deep_neurographs/evaluation.py +++ b/src/deep_neurographs/evaluation.py @@ -56,7 +56,7 @@ def run_evaluation(neurographs, blocks, pred_edges): overall_stats_i = get_stats( neurographs[block_id], neurographs[block_id].mutable_edges, - pred_edges[block_id] + pred_edges[block_id], ) simple_stats_i = get_stats( diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 4ceffd4..d2b5834 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -17,7 +17,7 @@ CHUNK_SIZE = [64, 64, 64] HALF_CHUNK_SIZE = [CHUNK_SIZE[i] // 2 for i in range(3)] -WINDOW_SIZE = [5, 5, 5] +WINDOW = [5, 5, 5] NUM_POINTS = 10 NUM_IMG_FEATURES = NUM_POINTS @@ -85,7 +85,7 @@ def generate_img_chunk_features( labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE) # Mark path - if neurograph.optimize_proposals: + if neurograph.optimize_alignment: xyz_list = to_patch_coords(neurograph, edge, midpoint) path = geometry_utils.sample_path(xyz_list, NUM_POINTS) else: @@ -120,9 +120,8 @@ def generate_img_profile_features( path, "zarr", origin, neurograph.shape, from_center=False ) img = utils.normalize_img(img) - simple_edges = neurograph.get_simple_proposals() for edge in neurograph.mutable_edges: - if neurograph.optimize_proposals and edge in simple_edges: + if neurograph.optimize_alignment: xyz = to_img_coords(neurograph, edge) path = geometry_utils.sample_path(xyz, NUM_POINTS) else: @@ -130,9 +129,7 @@ def generate_img_profile_features( xyz_i = utils.world_to_img(neurograph, i) xyz_j = utils.world_to_img(neurograph, j) path = geometry_utils.make_line(xyz_i, xyz_j, NUM_POINTS) - features[edge] = geometry_utils.get_profile( - img, path, window_size=WINDOW_SIZE - ) + features[edge] = geometry_utils.get_profile(img, path, window=WINDOW) return features diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index a2dd7d9..c0a4547 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -10,7 +10,6 @@ # Directional Vectors def get_directional(neurograph, i, proposal_tangent, window=5): - # Compute principle axes directionals = [] d = neurograph.optimize_depth for branch in neurograph.get_branches(i): @@ -20,7 +19,6 @@ def get_directional(neurograph, i, proposal_tangent, window=5): xyz = deepcopy(branch) else: xyz = deepcopy(branch[d : window + d, :]) - # print(xyz) directionals.append(compute_tangent(xyz)) # Determine best @@ -72,9 +70,9 @@ def smooth_branch(xyz, s=None): def fit_spline(xyz, s=None): s = xyz.shape[0] / 5 if not s else xyz.shape[0] / s t = np.linspace(0, 1, xyz.shape[0]) - spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=1) - spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=1) - spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=1) + spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3) + spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3) + spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3) return spline_x, spline_y, spline_z @@ -86,8 +84,8 @@ def sample_path(path, num_points): # Image feature extraction -def get_profile(img, xyz_arr, window_size=[5, 5, 5]): - return [np.max(utils.get_chunk(img, xyz, window_size)) for xyz in xyz_arr] +def get_profile(img, xyz_arr, window=[5, 5, 5]): + return [np.max(utils.get_chunk(img, xyz, window)) for xyz in xyz_arr] def fill_path(img, path, val=-1): @@ -98,9 +96,224 @@ def fill_path(img, path, val=-1): return img -# Miscellaneous +# Proposal optimization +def optimize_alignment(neurograph, img, edge, depth=15): + """ + Optimizes alignment of edge proposal between two branches by finding + straight path with the brightest averaged image profile. + + Parameters + ---------- + neurograph : NeuroGraph + Predicted neuron reconstruction to be corrected. + img : numpy.ndarray + Image chunk that the reconstruction is contained in. + edge : frozenset + Edge proposal to be aligned. + depth : int, optional + Maximum depth checked during alignment optimization. The default value + is 15. + + Returns + ------- + numpy.ndarray, numpy.ndarray + xyz coordinates of aligned edge proposal. + + """ + if neurograph.is_simple(edge): + return optimize_simple_alignment(neurograph, img, edge, depth=depth) + else: + return optimize_complex_alignment(neurograph, img, edge, depth=depth) + + +def optimize_simple_alignment(neurograph, img, edge, depth=15): + """ + Optimizes alignment of edge proposal for simple edges. + + Parameters + ---------- + neurograph : NeuroGraph + Predicted neuron reconstruction to be corrected. + img : numpy.ndarray + Image chunk that the reconstruction is contained in. + edge : frozenset + Edge proposal to be aligned. + depth : int, optional + Maximum depth checked during alignment optimization. The default value + is 15. + + Returns + ------- + numpy.ndarray, numpy.ndarray + xyz coordinates of aligned edge proposal. + + """ + i, j = tuple(edge) + branch_i = neurograph.get_branch(i) + branch_j = neurograph.get_branch(j) + xyz_i, xyz_j, _ = align(neurograph, img, branch_i, branch_j, depth) + return xyz_i, xyz_j + + +def optimize_complex_alignment(neurograph, img, edge, depth=15): + """ + Optimizes alignment of edge proposal for complex edges. + + Parameters + ---------- + neurograph : NeuroGraph + Predicted neuron reconstruction to be corrected. + img : numpy.ndarray + Image chunk that the reconstruction is contained in. + edge : frozenset + Edge proposal to be aligned. + depth : int, optional + Maximum depth checked during alignment optimization. The default value + is 15. + + Returns + ------- + numpy.ndarray, numpy.ndarray + xyz coordinates of aligned edge proposal. + + """ + i, j = tuple(edge) + branch = neurograph.get_branch(i if neurograph.is_leaf(i) else j) + branches = neurograph.get_branches(j if neurograph.is_leaf(i) else i) + xyz_1, leaf_1, val_1 = align(neurograph, img, branch, branches[0], depth) + xyz_2, leaf_2, val_2 = align(neurograph, img, branch, branches[1], depth) + return (xyz_1, leaf_1) if val_1 > val_2 else (xyz_2, leaf_2) + + +def align(neurograph, img, branch_1, branch_2, depth): + """ + Finds straight line path between end points of "branch_1" and "branch_2" + that best captures the image signal. This path is determined by checking + the average image intensity of the line drawn from "branch_1[d_1]" and + "branch_2[d_2]" with d_1, d_2 in [0, depth]. + + Parameters + ---------- + neurograph : NeuroGraph + Predicted neuron reconstruction to be corrected. + img : numpy.ndarray + Image chunk that the reconstruction is contained in. + branch_1 : np.ndarray + Branch corresponding to some predicted neuron. This branch must be + oriented so that the end points being considered are the coordinates + in rows 0 through "depth". + branch_2 : np.ndarray + Branch corresponding to some predicted neuron. This branch must be + oriented so that the end points being considered are the coordinates + in rows 0 through "depth". + depth : int + Maximum depth of branch that is optimized over. + + Returns + ------- + best_xyz_1 : np.ndarray + Optimal xyz coordinate from "branch_1". + best_xyz_2 : np.ndarray + Optimal xyz coordinate from "branch_2". + best_score : float + Average brightness of voxels sampled along line between "best_xyz_1" + and "best_xyz_2". + + """ + best_xyz_1 = None + best_xyz_2 = None + best_score = 0 + for d_1 in range(min(depth, len(branch_1) - 1)): + xyz_1 = neurograph.to_img(branch_1[d_1]) + for d_2 in range(min(depth, len(branch_2) - 1)): + xyz_2 = neurograph.to_img(branch_2[d_2]) + line = make_line(xyz_1, xyz_2, 10) + score = np.mean(get_profile(img, line, window=[3, 3, 3])) + if score > best_score: + best_score = score + best_xyz_1 = deepcopy(xyz_1) + best_xyz_2 = deepcopy(xyz_2) + return best_xyz_1, best_xyz_2, best_score + + +def optimize_path(img, origin, xyz_1, xyz_2): + """ + Finds optimal path between "xyz_1" and "xyz_2" that best captures the + image signal. The path is determined by finding the shortest path these + points with respect the cost function f(xyz) = 1 / img[xyz]. + + Parameters + ---------- + img : np.ndarray + Image chunk that contains "start" and "end". The centroid of this img + is "origin". + origin : np.ndarray + The xyz-coordinate (in world coordinates) of "img". + xyz_1 : np.ndarray + The xyz-coordinate (in image coordinates) of the start point of the + path. + xyz_2 : np.ndarray + The xyz-coordinate (in image coordinates) of the end point of the + path. + + Returns + ------- + list[tuple[float]] + Optimal path between "xyz_1" and "xyz_2". + + """ + patch_dims = get_optimal_patch(xyz_1, xyz_2, buffer=5) + center = get_midpoint(xyz_1, xyz_2).astype(int) + img_chunk = utils.get_chunk(img, center, patch_dims) + path = shortest_path( + img_chunk, + utils.img_to_patch(xyz_1, center, patch_dims), + utils.img_to_patch(xyz_2, center, patch_dims), + ) + return transform_path(path, origin, center, patch_dims) + + def shortest_path(img, start, end): + """ + Finds shortest path between "start" and "end" with respect to the image + intensity values. + + Parameters + ---------- + img : np.ndarray + Image chunk that "start" and "end" are contained within and domain of + the shortest path. + start : np.ndarray + Start point of path. + end : np.ndarray + End point of path. + + Returns + ------- + list[tuple] + Shortest path between "start" and "end". + + """ + def is_valid_move(x, y, z): + """ + Determines whether (x, y, z) coordinate is contained in image. + + Parameters + ---------- + x : int + X-coordinate. + y : int + Y-coordinate. + z : int + Z-coordinate. + + Returns + ------- + bool + Indication of whether coordinate is contained in image. + + """ return ( 0 <= x < shape[0] and 0 <= y < shape[1] @@ -170,28 +383,28 @@ def get_optimal_patch(xyz_1, xyz_2, buffer=8): return [int(abs(xyz_1[i] - xyz_2[i])) + buffer for i in range(3)] -def compare_edges(xyx_i, xyz_j, xyz_k): - dist_ij = dist(xyx_i, xyz_j) - dist_ik = dist(xyx_i, xyz_k) - return dist_ij < dist_ik - - -def dist(x, y, metric="l2"): +# Miscellaneous +def dist(v_1, v_2, metric="l2"): """ - Computes distance between "x" and "y". + Computes distance between "v_1" and "v_2". Parameters ---------- + v_1 : np.ndarray + Vector. + v_2 : np.ndarray + Vector. Returns ------- float + Distance between "v_1" and "v_2". """ if metric == "l1": - return np.linalg.norm(np.subtract(x, y), ord=1) + return np.linalg.norm(np.subtract(v_1, v_2), ord=1) else: - return np.linalg.norm(np.subtract(x, y), ord=2) + return np.linalg.norm(np.subtract(v_1, v_2), ord=2) def make_line(xyz_1, xyz_2, num_steps): diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 0ca86ea..b49e627 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -24,7 +24,9 @@ def build_neurograph( search_radius=25.0, prune=True, prune_depth=16, - optimize_proposals=False, + optimize_depth=15, + optimize_alignment=True, + optimize_path=False, origin=None, shape=None, smooth=True, @@ -38,7 +40,9 @@ def build_neurograph( neurograph = NeuroGraph( swc_dir, img_path=img_path, - optimize_proposals=optimize_proposals, + optimize_depth=optimize_depth, + optimize_alignment=optimize_alignment, + optimize_path=optimize_path, origin=origin, shape=shape, ) @@ -71,8 +75,12 @@ def init_immutables( for path in get_paths(neurograph.path): swc_id = get_id(path) - raw_swc = swc_utils.read_swc(path) - swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy) + swc_dict = swc_utils.parse( + swc_utils.read_swc(path), + anisotropy=anisotropy, + bbox=neurograph.bbox, + img_shape=neurograph.shape, + ) if len(swc_dict["xyz"]) < size_threshold: continue if smooth: diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 8c8b6db..7866636 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -9,11 +9,9 @@ """ from copy import deepcopy -from time import time import networkx as nx import numpy as np -import plotly.graph_objects as go import tensorstore as ts from scipy.spatial import KDTree @@ -39,8 +37,9 @@ def __init__( swc_path, img_path=None, label_mask=None, - optimize_depth=8, - optimize_proposals=False, + optimize_depth=10, + optimize_alignment=False, + optimize_path=False, origin=None, shape=None, ): @@ -67,19 +66,19 @@ def __init__( self.img_path = img_path self.optimize_depth = optimize_depth - self.optimize_proposals = optimize_proposals + self.optimize_alignment = optimize_alignment + self.optimize_path = optimize_path self.simple_proposals = set() self.complex_proposals = set() + self.bbox = None + self.shape = shape if origin and shape: self.bbox = { "min": np.array(origin), "max": np.array([origin[i] + shape[i] for i in range(3)]), } self.origin = np.array(origin) - self.shape = shape - else: - self.bbox = None def init_immutable_graph(self, add_attrs=False): immutable_graph = nx.Graph() @@ -94,7 +93,7 @@ def init_immutable_graph(self, add_attrs=False): def init_predicted_graph(self): self.predicted_graph = self.init_immutable_graph() - + def init_densegraph(self): self.densegraph = DenseGraph(self.path) @@ -203,7 +202,7 @@ def generate_proposals(self, num_proposals=3, search_radius=25.0): self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz])) self.mutable_edges.add(frozenset((leaf, node))) - if self.optimize_proposals: + if self.optimize_alignment or self.optimize_path: self.run_optimization() def _get_proposals( @@ -324,62 +323,28 @@ def run_optimization(self): img = utils.get_superchunk( self.img_path, "zarr", origin, self.shape, from_center=False ) - simple_edges = self.get_simple_proposals() for edge in self.mutable_edges: - if edge in simple_edges: - self.optimize_simple_edge(img, edge) - else: - self.optimize_complex_edge(img, edge) - - def optimize_simple_edge(self, img, edge): - # Extract Branches - i, j = tuple(edge) - branch_i = self.get_branch(self.nodes[i]["xyz"]) - branch_j = self.get_branch(self.nodes[j]["xyz"]) - depth = self.optimize_depth - - # Get image patch - idx_i = min(depth, branch_i.shape[0] - 1) - idx_j = min(depth, branch_j.shape[0] - 1) - hat_xyz_i = self.to_img(branch_i[idx_i]) - hat_xyz_j = self.to_img(branch_j[idx_j]) - patch_dims = geometry_utils.get_optimal_patch(hat_xyz_i, hat_xyz_j) - center = geometry_utils.get_midpoint(hat_xyz_i, hat_xyz_j).astype(int) - img_chunk = utils.get_chunk(img, center, patch_dims) - - # Optimize - if (np.array(hat_xyz_i) < 0).any() or (np.array(hat_xyz_j) < 0).any(): - return False - path = geometry_utils.shortest_path( - img_chunk, - utils.img_to_patch(hat_xyz_i, center, patch_dims), - utils.img_to_patch(hat_xyz_j, center, patch_dims), - ) - origin = utils.apply_anisotropy(self.origin, return_int=True) - path = geometry_utils.transform_path(path, origin, center, patch_dims) - self.edges[edge]["xyz"] = np.vstack( - [branch_i[idx_i], path, branch_j[idx_j]] - ) + xyz_1, xyz_2 = geometry_utils.optimize_alignment(self, img, edge) + proposal = [self.to_world(xyz_1)] + if self.optimize_path: + path = geometry_utils.optimize_path( + img, self.origin, xyz_1, xyz_2 + ) + proposal.append(path) + proposal.append(self.to_world(xyz_2)) + self.edges[edge]["xyz"] = np.vstack(proposal) - def optimize_complex_edge(self, img, edge): - # Extract Branches - i, j = tuple(edge) - leaf = i if self.immutable_degree(i) == 1 else j - i = j if leaf == i else i - branches = self.get_branches(i) - depth = self.optimize_depth - - # Search for best anchor - #if len(branches) == 2: - - def get_branch(self, xyz_or_node): - if type(xyz_or_node) is int: + if type(xyz_or_node) == int: nb = self.get_immutable_nbs(xyz_or_node)[0] return self.orient_edge((xyz_or_node, nb), xyz_or_node) else: edge = self.xyz_to_edge[tuple(xyz_or_node)] - return deepcopy(self.edges[edge]["xyz"]) + branch = deepcopy(self.edges[edge]["xyz"]) + if not (branch[0] == xyz_or_node).all(): + return np.flip(branch, axis=0) + else: + return branch def get_branches(self, i): branches = [] @@ -490,10 +455,7 @@ def get_reconstruction(self, proposals, upd_self=False): r_i = self.nodes[i]["radius"] r_j = self.nodes[j]["radius"] reconstruction.add_edge( - i, - j, - xyz=self.edges[i, j]["xyz"], - radius=[r_i, r_j], + i, j, xyz=self.edges[i, j]["xyz"], radius=[r_i, r_j] ) return reconstruction @@ -647,6 +609,11 @@ def to_img(self, node_or_xyz): node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) return utils.to_img(node_or_xyz, shift=self.origin) + def to_world(self, node_or_xyz, shift=[0, 0, 0]): + if type(node_or_xyz) == int: + node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) + return utils.to_world(node_or_xyz, shift=-self.origin) + def to_line_graph(self): """ Converts graph to a line graph. diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 59be19c..917f60a 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -28,7 +28,9 @@ def read_swc(path): return contents -def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0], idx=False): +def parse( + raw_swc, anisotropy=[1.0, 1.0, 1.0], bbox=None, img_shape=None, idx=False +): """ Parses a raw swc file to extract the (x,y,z) coordinates and radii. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -60,12 +62,15 @@ def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0], idx=False): offset = read_xyz(parts[2:5]) if not line.startswith("#") and len(line) > 0: parts = line.split() + xyz = read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset) + if bbox: + if not utils.is_contained(bbox, img_shape, xyz): + break + swc_dict["id"].append(int(parts[0])) swc_dict["radius"].append(float(parts[-2])) swc_dict["pid"].append(int(parts[-1])) - swc_dict["xyz"].append( - read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset) - ) + swc_dict["xyz"].append(xyz) if swc_dict["id"][-1] < min_id: min_id = swc_dict["id"][-1] diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 98bc701..8b7ea9c 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -15,10 +15,8 @@ from copy import deepcopy import numpy as np -import plotly.graph_objects as go import tensorstore as ts import zarr -from plotly.subplots import make_subplots ANISOTROPY = [0.748, 0.748, 1.0] SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "zarr"] @@ -303,7 +301,7 @@ def patch_to_img(xyz, patch_centroid, patch_dims): def to_world(xyz, shift=[0, 0, 0]): - return tuple([(xyz[i] - shift[i]) * ANISOTROPY[i] for i in range(3)]) + return tuple([xyz[i] * ANISOTROPY[i] - shift[i] for i in range(3)]) def to_img(xyz, shift=[0, 0, 0]): @@ -317,6 +315,16 @@ def apply_anisotropy(xyz, return_int=False): return [xyz[i] / ANISOTROPY[i] for i in range(3)] +def is_contained(bbox, img_shape, xyz): + xyz = apply_anisotropy(xyz - bbox["min"]) + for i in range(3): + lower_bool = xyz[i] < 0 + upper_bool = xyz[i] >= img_shape[i] + if lower_bool or upper_bool: + return False + return True + + # --- miscellaneous --- def get_img_mip(img, axis=0): return np.max(img, axis=axis) diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index a04fa80..dcd7575 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -9,14 +9,38 @@ import networkx as nx import numpy as np +import plotly.colors as plc import plotly.graph_objects as go - - -def visualize_connected_components(graph): - pass - - -def visualize_immutables(graph, title="Immutable Graph"): +from plotly import tools + + +def visualize_connected_components( + graph, return_data=False, title="", vertex_threshold=50 +): + # Make plot + data = [] + colors = plc.qualitative.Bold + connected_components = nx.connected_components(graph) + cnt = 0 + while True: + try: + component = next(connected_components) + subgraph = graph.subgraph(component) + if len(subgraph.nodes) > vertex_threshold: + color = colors[cnt % len(colors)] + data.extend(plot_edges(graph, subgraph.edges, color=color)) + cnt += 1 + except StopIteration: + break + + # Output + if return_data: + return data + else: + plot(data, title) + + +def visualize_immutables(graph, title="Initial Segmentation"): data = plot_edges(graph, graph.immutable_edges) data.append(plot_nodes(graph)) plot(data, title) @@ -28,17 +52,14 @@ def visualize_proposals(graph, title="Edge Proposals"): def visualize_targets(graph, target_graph=None, title="Target Edges"): visualize_subset( - graph, - graph.target_edges, - target_graph=target_graph, - title=title, + graph, graph.target_edges, target_graph=target_graph, title=title ) -def visualize_subset(graph, edges, target_graph=None, title=""): +def visualize_subset(graph, edges, line_width=5, target_graph=None, title=""): data = plot_edges(graph, graph.immutable_edges, color="black") - data.extend(plot_edges(graph, edges)) - data.append(plot_nodes(graph)) + data.extend(plot_edges(graph, edges, line_width=line_width)) + data.append(plot_nodes(graph)) if target_graph: edges = target_graph.immutable_edges data.extend(plot_edges(target_graph, edges, color="blue")) @@ -59,9 +80,11 @@ def plot_nodes(graph): ) -def plot_edges(graph, edges, color=None): +def plot_edges(graph, edges, color=None, line_width=3.5): traces = [] - line = dict(width=4) if color is None else dict(color=color, width=3) + line = ( + dict(width=5) if color is None else dict(color=color, width=line_width) + ) for i, j in edges: trace = go.Scatter3d( x=graph.edges[(i, j)]["xyz"][:, 0], @@ -78,20 +101,18 @@ def plot_edges(graph, edges, color=None): def plot(data, title): fig = go.Figure(data=data) fig.update_layout( - plot_bgcolor="white", title=title, - scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"), - ) - fig.update_layout( + template="plotly_white", + plot_bgcolor="rgba(0, 0, 0, 0)", scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)), width=1200, - height=600, + height=800, ) fig.show() def subplot(data1, data2, title): - fig = make_subplots( + fig = tools.make_subplots( rows=1, cols=2, specs=[[{"type": "scene"}, {"type": "scene"}]] ) fig.add_trace(data1, row=1, col=1)