Skip to content

Commit

Permalink
remove batch.id
Browse files Browse the repository at this point in the history
replaced in tensorflow predict node debug statments with the request. This better indicates
the roi being predicted on.
replaced in snapshot node with an internal counter
  • Loading branch information
pattonw committed Jan 3, 2024
1 parent 446331e commit ed445f1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
10 changes: 0 additions & 10 deletions gunpowder/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,7 @@ class Batch(Freezable):
Contains all graphs that have been requested for this batch.
"""

__next_id = multiprocessing.Value("L")

@staticmethod
def get_next_id():
with Batch.__next_id.get_lock():
next_id = Batch.__next_id.value
Batch.__next_id.value += 1
return next_id

def __init__(self):
self.id = Batch.get_next_id()
self.profiling_stats = ProfilingStats()
self.arrays = {}
self.graphs = {}
Expand Down
6 changes: 4 additions & 2 deletions gunpowder/nodes/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self,
dataset_names,
output_dir="snapshots",
output_filename="{id}.zarr",
output_filename="{iteration}.zarr",
every=1,
additional_request=None,
compression_type=None,
Expand All @@ -97,6 +97,7 @@ def __init__(
self.dataset_dtypes = dataset_dtypes

self.mode = "w"
self.id = 0

def write_if(self, batch):
"""To be implemented in subclasses.
Expand Down Expand Up @@ -157,6 +158,7 @@ def prepare(self, request):
return deps

def process(self, batch, request):
self.id += 1
if self.record_snapshot and self.write_if(batch):
try:
os.makedirs(self.output_dir)
Expand All @@ -166,7 +168,7 @@ def process(self, batch, request):
snapshot_name = os.path.join(
self.output_dir,
self.output_filename.format(
id=str(batch.id).zfill(8), iteration=int(batch.iteration or 0)
id=str(self.id).zfill(8), iteration=int(batch.iteration or self.id)
),
)
logger.info("saving to %s" % snapshot_name)
Expand Down
6 changes: 3 additions & 3 deletions gunpowder/tensorflow/nodes/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def predict(self, batch, request):
break

if can_skip:
logger.info("Skipping batch %i (all inputs are 0)" % batch.id)
logger.info(f"Skipping batch for request: {request} (all inputs are 0)")

for name, array_key in self.outputs.items():
shape = self.shared_output_arrays[name].shape
Expand All @@ -124,7 +124,7 @@ def predict(self, batch, request):

return

logger.debug("predicting in batch %i", batch.id)
logger.debug(f"predicting for request: {request}")

output_tensors = self.__collect_outputs(request)
input_data = self.__collect_provided_inputs(batch)
Expand Down Expand Up @@ -160,7 +160,7 @@ def predict(self, batch, request):
spec.roi = request[array_key].roi
batch.arrays[array_key] = Array(output_data[array_key], spec)

logger.debug("predicted in batch %i", batch.id)
logger.debug("predicted")

def __predict(self):
"""The background predict process."""
Expand Down

0 comments on commit ed445f1

Please sign in to comment.