Cover more targets with pytype annotations and checks.
PiperOrigin-RevId: 600421971
mjanusz authored and copybara-github committed Jan 23, 2024
1 parent 76f00a6 commit a049f03
Showing 3 changed files with 52 additions and 31 deletions.
ffn/training/augmentation.py (2 changes: 1 addition & 1 deletion)
@@ -48,7 +48,7 @@ def xy_transpose(data, decision):
   """
   with tf.name_scope('augment_xy_transpose'):
     rank = data.get_shape().ndims
-    perm = range(rank)
+    perm = list(range(rank))
     perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3]
     return tf.cond(decision,
                    lambda: tf.transpose(data, perm),
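The list() wrapper matters because in Python 3, range() returns a lazy, immutable sequence, so the in-place axis swap on the following line would raise a TypeError. A minimal standalone sketch of the difference (plain Python, no TensorFlow needed):

rank = 5
perm = range(rank)                      # Python 3: an immutable range object.
try:
  perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3]
except TypeError:
  pass                                  # range does not support item assignment.

perm = list(range(rank))                # Mutable list, as in the updated code.
perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3]
print(perm)                             # [0, 1, 3, 2, 4]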
ffn/training/inputs.py (6 changes: 3 additions & 3 deletions)
@@ -15,11 +15,11 @@
 """Tensorflow Python ops and utilities for generating network inputs."""

 import re

+from connectomics.common import bounding_box
 import numpy as np
 import tensorflow.compat.v1 as tf

 from tensorflow.io import gfile
-from ..utils import bounding_box


 def create_filename_queue(coordinates_file_pattern, shuffle=True):
@@ -139,7 +139,7 @@ def _load_from_numpylike(coord, volname):
     volume = volume_map[volname.decode('ascii')]
     # Get data, including all channels if volume is 4d.
     starts = np.array(coord) - start_offset
-    slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice()
+    slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice3d()
     if volume.ndim == 4:
       slc = np.index_exp[:] + slc
     data = volume[slc]
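For context, bounding_box now comes from the external connectomics package instead of the local ffn.utils module, and the slice is built with BoundingBox.to_slice3d() rather than to_slice(). A rough sketch of the same slicing path, assuming the connectomics package is installed (the volume and coordinates below are made up for illustration):

import numpy as np
from connectomics.common import bounding_box

volume = np.zeros((2, 64, 64, 64), dtype=np.uint8)  # hypothetical (channel, z, y, x) volume
starts = np.array([5, 6, 7])                        # start corner of the patch
shape = (16, 16, 16)                                # patch size

slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice3d()
if volume.ndim == 4:
  slc = np.index_exp[:] + slc                       # keep all channels
patch = volume[slc]
print(patch.shape)                                  # (2, 16, 16, 16)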
ffn/training/model.py (75 changes: 48 additions & 27 deletions)
@@ -14,32 +14,35 @@
 # ==============================================================================
 """Classes for FFN model definition."""

+from typing import Optional
+
 import tensorflow.compat.v1 as tf
+
 from . import optimizer


 class FFNModel(object):
   """Base class for FFN models."""

   # Dimensionality of the model (2 or 3).
-  dim = None
+  dim: int = None

   ############################################################################
   # (x, y, z) tuples defining various properties of the network.
   # Note that 3-tuples should be used even for 2D networks, in which case
   # the third (z) value is ignored.

   # How far to move the field of view in the respective directions.
-  deltas = None
+  deltas: tuple[int, int, int] = None

   # Size of the input image and seed subvolumes to be used during inference.
   # This is enough information to execute a single prediction step, without
   # moving the field of view.
-  input_image_size = None
-  input_seed_size = None
+  input_image_size: tuple[int, int, int] = None
+  input_seed_size: tuple[int, int, int] = None

   # Size of the predicted patch as returned by the model.
-  pred_mask_size = None
+  pred_mask_size: tuple[int, int, int] = None
   ###########################################################################

   # TF op to compute loss optimized during training. This should include all
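One note on the class-level annotations above: the built-in generic spelling tuple[int, int, int] is evaluated when the class body runs, so it requires Python 3.9 or newer (older interpreters would need typing.Tuple or a deferred-annotations import); the None defaults are placeholders that subclasses and the constructor fill in. A minimal sketch of the same pattern with made-up classes:

class ExampleModel:
  # Class-level defaults; subclasses override them with concrete values.
  dim: int = None
  deltas: tuple[int, int, int] = None   # built-in generic; needs Python 3.9+ at runtime


class TinyModel(ExampleModel):
  dim = 3
  deltas = (8, 8, 8)


m = TinyModel()
print(m.dim, m.deltas)   # 3 (8, 8, 8)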
@@ -49,7 +49,12 @@ class FFNModel(object):
   # TF op to call to perform loss optimization on the model.
   train_op = None

-  def __init__(self, deltas, batch_size=None, define_global_step=True):
+  def __init__(
+      self,
+      deltas: tuple[int, int, int],
+      batch_size: Optional[int] = None,
+      define_global_step: bool = True,
+  ):
     assert self.dim is not None

     self.deltas = deltas
@@ -111,18 +119,21 @@ def set_input_shapes(self):

     Assumes input_seed_size and input_image_size are already set.
     """
-    self.input_seed.set_shape([self.batch_size] +
-                              list(self.input_seed_size[::-1]) + [1])
-    self.input_patches.set_shape([self.batch_size] +
-                                 list(self.input_image_size[::-1]) + [1])
+    self.input_seed.set_shape(
+        [self.batch_size] + list(self.input_seed_size[::-1]) + [1]
+    )
+    self.input_patches.set_shape(
+        [self.batch_size] + list(self.input_image_size[::-1]) + [1]
+    )

   def set_up_sigmoid_pixelwise_loss(self, logits):
     """Sets up the loss function of the model."""
     assert self.labels is not None
     assert self.loss_weights is not None

-    pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
-                                                         labels=self.labels)
+    pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+        logits=logits, labels=self.labels
+    )
     pixel_loss *= self.loss_weights
     self.loss = tf.reduce_mean(pixel_loss)
     tf.summary.scalar('pixel_loss', self.loss)
@@ -142,24 +153,28 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7):
         tf.logging.error('Gradient is None: %s', v.op.name)

     if max_gradient_entry_mag > 0.0:
-      grads_and_vars = [(tf.clip_by_value(g,
-                                          -max_gradient_entry_mag,
-                                          +max_gradient_entry_mag), v)
-                        for g, v, in grads_and_vars]
+      grads_and_vars = [
+          (
+              tf.clip_by_value(
+                  g, -max_gradient_entry_mag, +max_gradient_entry_mag
+              ),
+              v,
+          )
+          for g, v, in grads_and_vars
+      ]

     trainables = tf.trainable_variables()
     if trainables:
       for var in trainables:
         tf.summary.histogram(var.name.replace(':0', ''), var)
       for grad, var in grads_and_vars:
-        tf.summary.histogram(
-            'gradients/%s' % var.name.replace(':0', ''), grad)
+        tf.summary.histogram('gradients/%s' % var.name.replace(':0', ''), grad)

     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
     with tf.control_dependencies(update_ops):
-      self.train_op = opt.apply_gradients(grads_and_vars,
-                                          global_step=self.global_step,
-                                          name='train')
+      self.train_op = opt.apply_gradients(
+          grads_and_vars, global_step=self.global_step, name='train'
+      )
@@ -179,11 +194,16 @@ def update_seed(self, seed, update):
     if dx == 0 and dy == 0 and dz == 0:
       seed += update
     else:
-      seed += tf.pad(update, [[0, 0],
-                              [dz // 2, dz - dz // 2],
-                              [dy // 2, dy - dy // 2],
-                              [dx // 2, dx - dx // 2],
-                              [0, 0]])
+      seed += tf.pad(
+          update,
+          [
+              [0, 0],
+              [dz // 2, dz - dz // 2],
+              [dy // 2, dy - dy // 2],
+              [dx // 2, dx - dx // 2],
+              [0, 0],
+          ],
+      )
     return seed

   def define_tf_graph(self):
@@ -193,4 +213,5 @@ def define_tf_graph(self):
     computing and optimizing the loss.
     """
     raise NotImplementedError(
-        'DefineTFGraph needs to be defined by a subclass.')
+        'DefineTFGraph needs to be defined by a subclass.'
+    )
