From 30fe182a44531f8e69a2c199c0240cbdc9cef414 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 14 Oct 2024 20:18:14 +0000 Subject: [PATCH] feat: log runtimes --- .../graph_artifact_removal.py | 2 +- src/deep_neurographs/inference.py | 86 ++++++++++++------- src/deep_neurographs/utils/swc_util.py | 1 + 3 files changed, 57 insertions(+), 32 deletions(-) diff --git a/src/deep_neurographs/graph_artifact_removal.py b/src/deep_neurographs/graph_artifact_removal.py index 1b433b5..30dd784 100644 --- a/src/deep_neurographs/graph_artifact_removal.py +++ b/src/deep_neurographs/graph_artifact_removal.py @@ -68,7 +68,7 @@ def remove_doubles(neurograph, max_size, node_spacing, output_dir=None): ) neurograph = delete(neurograph, components[idx], swc_id) deleted.add(swc_id) - print("# Doubles detected:", util.reformat_number(len(deleted))) + return len(deleted) def compute_projections(neurograph, kdtree, edge): diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index c14bddf..900ccd6 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -69,6 +69,7 @@ def __init__( device=None, is_multimodal=False, label_path=None, + log_runtimes=True, ): """ Initializes an object that executes the full GraphTrace inference @@ -93,10 +94,13 @@ def __init__( for the inference pipeline. device : str, optional ... - label_path : str, optional - Path to the segmentation assumed to be stored on a GCS bucket. is_multimodal : bool, optional ... + label_path : str, optional + Path to the segmentation assumed to be stored on a GCS bucket. The + default is None. + log_runtimes : bool, optional + Indication of whether to log runtimes. The default is True. Returns ------- @@ -105,9 +109,10 @@ def __init__( """ # Class attributes self.accepted_proposals = list() + self.log_runtimes = log_runtimes + self.model_path = model_path self.sample_id = sample_id self.segmentation_id = segmentation_id - self.model_path = model_path # Extract config settings self.graph_config = config.graph_config @@ -130,6 +135,9 @@ def __init__( date = datetime.today().strftime("%Y-%m-%d") self.output_dir = f"{output_dir}/{segmentation_id}-{date}" util.mkdir(self.output_dir, delete=True) + if self.log_runtimes: + log_path = os.path.join(self.output_dir, "runtimes.txt") + self.log_handle = open(log_path, 'a') # --- Core --- def run(self, fragments_pointer): @@ -158,8 +166,13 @@ def run(self, fragments_pointer): self.run_inference() self.save_results() + # Finish + self.report("Final Graph...") + self.report_graph() + self.report("\n") + t, unit = util.time_writer(time() - t0) - print(f"Total Runtime: {round(t, 4)} {unit}\n") + self.report(f"Total Runtime: {round(t, 4)} {unit}\n") def run_schedule( self, fragments_pointer, radius_schedule, save_all_rounds=False @@ -168,7 +181,7 @@ def run_schedule( self.report_experiment() self.build_graph(fragments_pointer) for round_id, radius in enumerate(radius_schedule): - print(f"--- Round {round_id + 1}: Radius = {radius} ---") + self.report(f"--- Round {round_id + 1}: Radius = {radius} ---") round_id += 1 self.generate_proposals(radius) self.run_inference() @@ -179,7 +192,7 @@ def run_schedule( self.save_results(round_id=round_id) t, unit = util.time_writer(time() - t0) - print(f"Total Runtime: {round(t, 4)} {unit}\n") + self.report(f"Total Runtime: {round(t, 4)} {unit}\n") def build_graph(self, fragments_pointer): """ @@ -197,7 +210,7 @@ def build_graph(self, fragments_pointer): None """ - print("(1) Building FragmentGraph") + self.report("(1) Building FragmentGraph") t0 = time() # Initialize Graph @@ -210,7 +223,10 @@ def build_graph(self, fragments_pointer): # Remove doubles (if applicable) if self.graph_config.remove_doubles_bool: - remove_doubles(self.graph, 200, self.graph_config.node_spacing) + n_doubles = remove_doubles( + self.graph, 200, self.graph_config.node_spacing + ) + self.report(f"# Doubles Detected: {n_doubles}") # Save valid labels and current graph swcs_path = os.path.join(self.output_dir, "processed-swcs.zip") @@ -219,8 +235,8 @@ def build_graph(self, fragments_pointer): self.graph.save_labels(labels_path) t, unit = util.time_writer(time() - t0) - print(f"Module Runtime: {round(t, 4)} {unit}\n") - self.print_graph_overview() + self.report_graph() + self.report(f"Module Runtime: {round(t, 4)} {unit}\n") def generate_proposals(self, radius=None): """ @@ -237,7 +253,7 @@ def generate_proposals(self, radius=None): """ # Initializations - print("(2) Generate Proposals") + self.report("(2) Generate Proposals") if radius is None: radius = self.graph_config.search_radius @@ -254,8 +270,8 @@ def generate_proposals(self, radius=None): # Report results t, unit = util.time_writer(time() - t0) - print("# Proposals:", n_proposals) - print(f"Module Runtime: {round(t, 4)} {unit}\n") + self.report(f"# Proposals: {n_proposals}") + self.report(f"Module Runtime: {round(t, 4)} {unit}\n") def run_inference(self): """ @@ -271,18 +287,18 @@ def run_inference(self): None """ - print("(3) Run Inference") + self.report("(3) Run Inference") t0 = time() n_proposals = max(self.graph.n_proposals(), 1) self.graph, accepts = self.inference_engine.run( self.graph, self.graph.list_proposals() ) self.accepted_proposals.extend(accepts) - print("# Accepted:", util.reformat_number(len(accepts))) - print("% Accepted:", round(len(accepts) / n_proposals, 4)) + self.report(f"# Accepted: {util.reformat_number(len(accepts))}") + self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}") t, unit = util.time_writer(time() - t0) - print(f"Module Runtime: {round(t, 4)} {unit}\n") + self.report(f"Module Runtime: {round(t, 4)} {unit}\n") def save_results(self, round_id=None): """ @@ -305,13 +321,6 @@ def save_results(self, round_id=None): self.save_connections(round_id=round_id) self.write_metadata() - def report_experiment(self): - print("\nExperiment Overview") - print("-----------------------------------------------") - print("Sample_ID:", self.sample_id) - print("Segmentation_ID:", self.segmentation_id) - print("") - # --- io --- def save_connections(self, round_id=None): """ @@ -364,7 +373,20 @@ def write_metadata(self): util.write_json(path, metadata) # --- Summaries --- - def print_graph_overview(self): + def report(self, txt): + print(txt) + if self.log_runtimes: + self.log_handle.write(txt) + self.log_handle.write("\n") + + def report_experiment(self): + self.report("\nExperiment Overview") + self.report("-------------------------------------------------------") + self.report(f"Sample_ID: {self.sample_id}") + self.report(f"Segmentation_ID: {self.segmentation_id}") + self.report("\n") + + def report_graph(self): """ Prints an overview of the graph's structure and memory usage. @@ -379,14 +401,16 @@ def print_graph_overview(self): """ # Compute values n_components = nx.number_connected_components(self.graph) + n_components = util.reformat_number(n_components) + n_nodes = util.reformat_number(self.graph.number_of_nodes()) + n_edges = util.reformat_number(self.graph.number_of_edges()) usage = round(util.get_memory_usage(), 2) - # Print overview - print("Graph Overview...") - print("# Connected Components:", util.reformat_number(n_components)) - print("# Nodes:", util.reformat_number(self.graph.number_of_nodes())) - print("# Edges:", util.reformat_number(self.graph.number_of_edges())) - print(f"Memory Consumption: {usage} GBs\n") + # Report + self.report(f"# Connected Components: {n_components}") + self.report(f"# Nodes: {n_nodes}") + self.report(f"# Edges: {n_edges}") + self.report(f"Memory Consumption: {usage} GBs") class InferenceEngine: diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index 8c54492..6f545e7 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -177,6 +177,7 @@ def load_from_gcs(self, gcs_dict): processes.append( executor.submit(self.load_from_cloud_zip, zip_content) ) + break # Store results swc_dicts = list()