Skip to content

Commit

Permalink
feat: log runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 14, 2024
1 parent 5d16c74 commit 30fe182
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/graph_artifact_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def remove_doubles(neurograph, max_size, node_spacing, output_dir=None):
)
neurograph = delete(neurograph, components[idx], swc_id)
deleted.add(swc_id)
print("# Doubles detected:", util.reformat_number(len(deleted)))
return len(deleted)


def compute_projections(neurograph, kdtree, edge):
Expand Down
86 changes: 55 additions & 31 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
device=None,
is_multimodal=False,
label_path=None,
log_runtimes=True,
):
"""
Initializes an object that executes the full GraphTrace inference
Expand All @@ -93,10 +94,13 @@ def __init__(
for the inference pipeline.
device : str, optional
...
label_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket.
is_multimodal : bool, optional
...
label_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket. The
default is None.
log_runtimes : bool, optional
Indication of whether to log runtimes. The default is True.
Returns
-------
Expand All @@ -105,9 +109,10 @@ def __init__(
"""
# Class attributes
self.accepted_proposals = list()
self.log_runtimes = log_runtimes
self.model_path = model_path
self.sample_id = sample_id
self.segmentation_id = segmentation_id
self.model_path = model_path

# Extract config settings
self.graph_config = config.graph_config
Expand All @@ -130,6 +135,9 @@ def __init__(
date = datetime.today().strftime("%Y-%m-%d")
self.output_dir = f"{output_dir}/{segmentation_id}-{date}"
util.mkdir(self.output_dir, delete=True)
if self.log_runtimes:
log_path = os.path.join(self.output_dir, "runtimes.txt")
self.log_handle = open(log_path, 'a')

# --- Core ---
def run(self, fragments_pointer):
Expand Down Expand Up @@ -158,8 +166,13 @@ def run(self, fragments_pointer):
self.run_inference()
self.save_results()

# Finish
self.report("Final Graph...")
self.report_graph()
self.report("\n")

t, unit = util.time_writer(time() - t0)
print(f"Total Runtime: {round(t, 4)} {unit}\n")
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(
self, fragments_pointer, radius_schedule, save_all_rounds=False
Expand All @@ -168,7 +181,7 @@ def run_schedule(
self.report_experiment()
self.build_graph(fragments_pointer)
for round_id, radius in enumerate(radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {radius} ---")
self.report(f"--- Round {round_id + 1}: Radius = {radius} ---")
round_id += 1
self.generate_proposals(radius)
self.run_inference()
Expand All @@ -179,7 +192,7 @@ def run_schedule(
self.save_results(round_id=round_id)

t, unit = util.time_writer(time() - t0)
print(f"Total Runtime: {round(t, 4)} {unit}\n")
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

def build_graph(self, fragments_pointer):
"""
Expand All @@ -197,7 +210,7 @@ def build_graph(self, fragments_pointer):
None
"""
print("(1) Building FragmentGraph")
self.report("(1) Building FragmentGraph")
t0 = time()

# Initialize Graph
Expand All @@ -210,7 +223,10 @@ def build_graph(self, fragments_pointer):

# Remove doubles (if applicable)
if self.graph_config.remove_doubles_bool:
remove_doubles(self.graph, 200, self.graph_config.node_spacing)
n_doubles = remove_doubles(
self.graph, 200, self.graph_config.node_spacing
)
self.report(f"# Doubles Detected: {n_doubles}")

# Save valid labels and current graph
swcs_path = os.path.join(self.output_dir, "processed-swcs.zip")
Expand All @@ -219,8 +235,8 @@ def build_graph(self, fragments_pointer):
self.graph.save_labels(labels_path)

t, unit = util.time_writer(time() - t0)
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.print_graph_overview()
self.report_graph()
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")

def generate_proposals(self, radius=None):
"""
Expand All @@ -237,7 +253,7 @@ def generate_proposals(self, radius=None):
"""
# Initializations
print("(2) Generate Proposals")
self.report("(2) Generate Proposals")
if radius is None:
radius = self.graph_config.search_radius

Expand All @@ -254,8 +270,8 @@ def generate_proposals(self, radius=None):

# Report results
t, unit = util.time_writer(time() - t0)
print("# Proposals:", n_proposals)
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.report(f"# Proposals: {n_proposals}")
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")

def run_inference(self):
"""
Expand All @@ -271,18 +287,18 @@ def run_inference(self):
None
"""
print("(3) Run Inference")
self.report("(3) Run Inference")
t0 = time()
n_proposals = max(self.graph.n_proposals(), 1)
self.graph, accepts = self.inference_engine.run(
self.graph, self.graph.list_proposals()
)
self.accepted_proposals.extend(accepts)
print("# Accepted:", util.reformat_number(len(accepts)))
print("% Accepted:", round(len(accepts) / n_proposals, 4))
self.report(f"# Accepted: {util.reformat_number(len(accepts))}")
self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}")

t, unit = util.time_writer(time() - t0)
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")

def save_results(self, round_id=None):
"""
Expand All @@ -305,13 +321,6 @@ def save_results(self, round_id=None):
self.save_connections(round_id=round_id)
self.write_metadata()

def report_experiment(self):
print("\nExperiment Overview")
print("-----------------------------------------------")
print("Sample_ID:", self.sample_id)
print("Segmentation_ID:", self.segmentation_id)
print("")

# --- io ---
def save_connections(self, round_id=None):
"""
Expand Down Expand Up @@ -364,7 +373,20 @@ def write_metadata(self):
util.write_json(path, metadata)

# --- Summaries ---
def print_graph_overview(self):
def report(self, txt):
    """
    Prints the given message to stdout and, if runtime logging is enabled,
    appends it to the runtime log file.

    Parameters
    ----------
    txt : str
        Message to report.

    Returns
    -------
    None

    """
    print(txt)
    if self.log_runtimes:
        # Write message plus newline in one call, then flush so the log
        # file reflects progress even if the run is interrupted — the
        # handle is opened once in __init__ and is not visibly closed,
        # so unflushed buffered writes could otherwise be lost.
        # NOTE(review): assumes self.log_handle was created in __init__;
        # confirm log_runtimes=True always coincides with an output_dir.
        self.log_handle.write(f"{txt}\n")
        self.log_handle.flush()

def report_experiment(self):
    """
    Reports a brief overview of the experiment, namely the sample and
    segmentation identifiers, via self.report.

    Returns
    -------
    None

    """
    overview = (
        "\nExperiment Overview",
        "-------------------------------------------------------",
        f"Sample_ID: {self.sample_id}",
        f"Segmentation_ID: {self.segmentation_id}",
        "\n",
    )
    for line in overview:
        self.report(line)

def report_graph(self):
"""
Prints an overview of the graph's structure and memory usage.
Expand All @@ -379,14 +401,16 @@ def print_graph_overview(self):
"""
# Compute values
n_components = nx.number_connected_components(self.graph)
n_components = util.reformat_number(n_components)
n_nodes = util.reformat_number(self.graph.number_of_nodes())
n_edges = util.reformat_number(self.graph.number_of_edges())
usage = round(util.get_memory_usage(), 2)

# Print overview
print("Graph Overview...")
print("# Connected Components:", util.reformat_number(n_components))
print("# Nodes:", util.reformat_number(self.graph.number_of_nodes()))
print("# Edges:", util.reformat_number(self.graph.number_of_edges()))
print(f"Memory Consumption: {usage} GBs\n")
# Report
self.report(f"# Connected Components: {n_components}")
self.report(f"# Nodes: {n_nodes}")
self.report(f"# Edges: {n_edges}")
self.report(f"Memory Consumption: {usage} GBs")


class InferenceEngine:
Expand Down
1 change: 1 addition & 0 deletions src/deep_neurographs/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def load_from_gcs(self, gcs_dict):
processes.append(
executor.submit(self.load_from_cloud_zip, zip_content)
)
break

# Store results
swc_dicts = list()
Expand Down

0 comments on commit 30fe182

Please sign in to comment.