From a049f03b4f62a7dd9c1fe46d14302978f88caece Mon Sep 17 00:00:00 2001
From: Michal Januszewski
Date: Mon, 22 Jan 2024 04:51:30 -0800
Subject: [PATCH] Cover more targets with pytype annotations and checks.

PiperOrigin-RevId: 600421971
---
 ffn/training/augmentation.py |  2 +-
 ffn/training/inputs.py       |  6 +--
 ffn/training/model.py        | 75 +++++++++++++++++++++++-------------
 3 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/ffn/training/augmentation.py b/ffn/training/augmentation.py
index 0e9f518..bb77146 100644
--- a/ffn/training/augmentation.py
+++ b/ffn/training/augmentation.py
@@ -48,7 +48,7 @@ def xy_transpose(data, decision):
   """
   with tf.name_scope('augment_xy_transpose'):
     rank = data.get_shape().ndims
-    perm = range(rank)
+    perm = list(range(rank))
     perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3]
     return tf.cond(decision,
                    lambda: tf.transpose(data, perm),
diff --git a/ffn/training/inputs.py b/ffn/training/inputs.py
index 5c359fe..4e30bde 100644
--- a/ffn/training/inputs.py
+++ b/ffn/training/inputs.py
@@ -15,11 +15,11 @@
 """Tensorflow Python ops and utilities for generating network inputs."""
 
 import re
+
+from connectomics.common import bounding_box
 import numpy as np
 import tensorflow.compat.v1 as tf
-
 from tensorflow.io import gfile
-from ..utils import bounding_box
 
 
 def create_filename_queue(coordinates_file_pattern, shuffle=True):
@@ -139,7 +139,7 @@ def _load_from_numpylike(coord, volname):
     volume = volume_map[volname.decode('ascii')]
     # Get data, including all channels if volume is 4d.
     starts = np.array(coord) - start_offset
-    slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice()
+    slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice3d()
     if volume.ndim == 4:
       slc = np.index_exp[:] + slc
     data = volume[slc]
diff --git a/ffn/training/model.py b/ffn/training/model.py
index d3b6449..6411d5c 100644
--- a/ffn/training/model.py
+++ b/ffn/training/model.py
@@ -14,7 +14,10 @@
 # ==============================================================================
 """Classes for FFN model definition."""
 
+from typing import Optional
+
 import tensorflow.compat.v1 as tf
+
 from . import optimizer
 
 
@@ -22,7 +25,7 @@ class FFNModel(object):
   """Base class for FFN models."""
 
   # Dimensionality of the model (2 or 3).
-  dim = None
+  dim: int = None
 
   ############################################################################
   # (x, y, z) tuples defining various properties of the network.
@@ -30,16 +33,16 @@ class FFNModel(object):
   # the third (z) value is ignored.
 
   # How far to move the field of view in the respective directions.
-  deltas = None
+  deltas: tuple[int, int, int] = None
 
   # Size of the input image and seed subvolumes to be used during inference.
   # This is enough information to execute a single prediction step, without
   # moving the field of view.
-  input_image_size = None
-  input_seed_size = None
+  input_image_size: tuple[int, int, int] = None
+  input_seed_size: tuple[int, int, int] = None
 
   # Size of the predicted patch as returned by the model.
-  pred_mask_size = None
+  pred_mask_size: tuple[int, int, int] = None
 
   ###########################################################################
   # TF op to compute loss optimized during training. This should include all
@@ -49,7 +52,12 @@ class FFNModel(object):
   # TF op to call to perform loss optimization on the model.
   train_op = None
 
-  def __init__(self, deltas, batch_size=None, define_global_step=True):
+  def __init__(
+      self,
+      deltas: tuple[int, int, int],
+      batch_size: Optional[int] = None,
+      define_global_step: bool = True,
+  ):
     assert self.dim is not None
 
     self.deltas = deltas
@@ -111,18 +119,21 @@ def set_input_shapes(self):
 
     Assumes input_seed_size and input_image_size are already set.
     """
-    self.input_seed.set_shape([self.batch_size] +
-                              list(self.input_seed_size[::-1]) + [1])
-    self.input_patches.set_shape([self.batch_size] +
-                                 list(self.input_image_size[::-1]) + [1])
+    self.input_seed.set_shape(
+        [self.batch_size] + list(self.input_seed_size[::-1]) + [1]
+    )
+    self.input_patches.set_shape(
+        [self.batch_size] + list(self.input_image_size[::-1]) + [1]
+    )
 
   def set_up_sigmoid_pixelwise_loss(self, logits):
     """Sets up the loss function of the model."""
     assert self.labels is not None
     assert self.loss_weights is not None
 
-    pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
-                                                         labels=self.labels)
+    pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+        logits=logits, labels=self.labels
+    )
     pixel_loss *= self.loss_weights
     self.loss = tf.reduce_mean(pixel_loss)
     tf.summary.scalar('pixel_loss', self.loss)
@@ -142,24 +153,28 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7):
         tf.logging.error('Gradient is None: %s', v.op.name)
 
     if max_gradient_entry_mag > 0.0:
-      grads_and_vars = [(tf.clip_by_value(g,
-                                          -max_gradient_entry_mag,
-                                          +max_gradient_entry_mag), v)
-                        for g, v, in grads_and_vars]
+      grads_and_vars = [
+          (
+              tf.clip_by_value(
+                  g, -max_gradient_entry_mag, +max_gradient_entry_mag
+              ),
+              v,
+          )
+          for g, v, in grads_and_vars
+      ]
 
     trainables = tf.trainable_variables()
     if trainables:
       for var in trainables:
         tf.summary.histogram(var.name.replace(':0', ''), var)
     for grad, var in grads_and_vars:
-      tf.summary.histogram(
-          'gradients/%s' % var.name.replace(':0', ''), grad)
+      tf.summary.histogram('gradients/%s' % var.name.replace(':0', ''), grad)
 
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
     with tf.control_dependencies(update_ops):
-      self.train_op = opt.apply_gradients(grads_and_vars,
-                                          global_step=self.global_step,
-                                          name='train')
+      self.train_op = opt.apply_gradients(
+          grads_and_vars, global_step=self.global_step, name='train'
+      )
 
   def show_center_slice(self, image, sigmoid=True):
     image = image[:, image.get_shape().dims[1] // 2, :, :, :]
@@ -179,11 +194,16 @@ def update_seed(self, seed, update):
     if dx == 0 and dy == 0 and dz == 0:
       seed += update
     else:
-      seed += tf.pad(update, [[0, 0],
-                              [dz // 2, dz - dz // 2],
-                              [dy // 2, dy - dy // 2],
-                              [dx // 2, dx - dx // 2],
-                              [0, 0]])
+      seed += tf.pad(
+          update,
+          [
+              [0, 0],
+              [dz // 2, dz - dz // 2],
+              [dy // 2, dy - dy // 2],
+              [dx // 2, dx - dx // 2],
+              [0, 0],
+          ],
+      )
     return seed
 
   def define_tf_graph(self):
@@ -193,4 +213,5 @@ def define_tf_graph(self):
     computing and optimizing the loss.
     """
     raise NotImplementedError(
-        'DefineTFGraph needs to be defined by a subclass.')
+        'DefineTFGraph needs to be defined by a subclass.'
+    )