black format docs and examples
pattonw committed Dec 19, 2023
1 parent 625cb03 commit 35fcd43
Showing 4 changed files with 111 additions and 145 deletions.
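
The commit itself does not record the exact command used, but a plausible way to reproduce this kind of formatting-only change, assuming Black is installed, is:

    black docs/ examples/

Running black --check --diff on the same paths would preview the reformatting without rewriting any files.
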
docs/source/conf.py (2 changes: 1 addition & 1 deletion)

@@ -54,7 +54,7 @@
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
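
The only change in conf.py is Black's string normalization (single quotes rewritten as double quotes). If the single-quoted style were preferred, that behaviour can be switched off; a hedged sketch of such an invocation, not used in this commit:

    black --skip-string-normalization docs/source/conf.py
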
examples/cremi/mknet.py (54 changes: 22 additions & 32 deletions)

@@ -2,29 +2,26 @@
 import tensorflow as tf
 import json
 
-def create_network(input_shape, name):
 
+def create_network(input_shape, name):
     tf.reset_default_graph()
 
     # create a placeholder for the 3D raw input tensor
     raw = tf.placeholder(tf.float32, shape=input_shape)
 
     # create a U-Net
     raw_batched = tf.reshape(raw, (1, 1) + input_shape)
-    unet_output = unet(raw_batched, 6, 4, [[1,3,3],[1,3,3],[1,3,3]])
+    unet_output = unet(raw_batched, 6, 4, [[1, 3, 3], [1, 3, 3], [1, 3, 3]])
 
     # add a convolution layer to create 3 output maps representing affinities
     # in z, y, and x
     pred_affs_batched = conv_pass(
-        unet_output,
-        kernel_size=1,
-        num_fmaps=3,
-        num_repetitions=1,
-        activation='sigmoid')
+        unet_output, kernel_size=1, num_fmaps=3, num_repetitions=1, activation="sigmoid"
+    )
 
     # get the shape of the output
     output_shape_batched = pred_affs_batched.get_shape().as_list()
-    output_shape = output_shape_batched[1:] # strip the batch dimension
+    output_shape = output_shape_batched[1:]  # strip the batch dimension
 
     # the 4D output tensor (3, depth, height, width)
     pred_affs = tf.reshape(pred_affs_batched, output_shape)
@@ -33,46 +30,39 @@ def create_network(input_shape, name):
     gt_affs = tf.placeholder(tf.float32, shape=output_shape)
 
     # create a placeholder for per-voxel loss weights
-    loss_weights = tf.placeholder(
-        tf.float32,
-        shape=output_shape)
+    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
 
     # compute the loss as the weighted mean squared error between the
     # predicted and the ground-truth affinities
-    loss = tf.losses.mean_squared_error(
-        gt_affs,
-        pred_affs,
-        loss_weights)
+    loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights)
 
     # use the Adam optimizer to minimize the loss
     opt = tf.train.AdamOptimizer(
-        learning_rate=0.5e-4,
-        beta1=0.95,
-        beta2=0.999,
-        epsilon=1e-8)
+        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
+    )
     optimizer = opt.minimize(loss)
 
     # store the network in a meta-graph file
-    tf.train.export_meta_graph(filename=name + '.meta')
+    tf.train.export_meta_graph(filename=name + ".meta")
 
     # store network configuration for use in train and predict scripts
     config = {
-        'raw': raw.name,
-        'pred_affs': pred_affs.name,
-        'gt_affs': gt_affs.name,
-        'loss_weights': loss_weights.name,
-        'loss': loss.name,
-        'optimizer': optimizer.name,
-        'input_shape': input_shape,
-        'output_shape': output_shape[1:]
+        "raw": raw.name,
+        "pred_affs": pred_affs.name,
+        "gt_affs": gt_affs.name,
+        "loss_weights": loss_weights.name,
+        "loss": loss.name,
+        "optimizer": optimizer.name,
+        "input_shape": input_shape,
+        "output_shape": output_shape[1:],
     }
-    with open(name + '_config.json', 'w') as f:
+    with open(name + "_config.json", "w") as f:
         json.dump(config, f)
 
-if __name__ == "__main__":
 
+if __name__ == "__main__":
     # create a network for training
-    create_network((84, 268, 268), 'train_net')
+    create_network((84, 268, 268), "train_net")
 
     # create a larger network for faster prediction
-    create_network((120, 322, 322), 'test_net')
+    create_network((120, 322, 322), "test_net")
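
mknet.py writes a small JSON side file next to each exported meta-graph so that the train and predict scripts can look up tensor names and shapes. A minimal sketch of reading that file, assuming mknet.py has already been run in the current directory:

    import json

    # load the configuration written by create_network(..., "train_net")
    with open("train_net_config.json", "r") as f:
        net_config = json.load(f)

    # JSON stores the shapes as lists, e.g. [84, 268, 268] for the training input
    print(net_config["input_shape"])
    print(net_config["output_shape"])

This is the same pattern predict.py below uses with test_net_config.json.
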
examples/cremi/predict.py (63 changes: 28 additions & 35 deletions)

@@ -2,29 +2,29 @@
 import gunpowder as gp
 import json
 
-def predict(iteration):
 
+def predict(iteration):
     ##################
     # DECLARE ARRAYS #
     ##################
 
     # raw intensities
-    raw = gp.ArrayKey('RAW')
+    raw = gp.ArrayKey("RAW")
 
     # the predicted affinities
-    pred_affs = gp.ArrayKey('PRED_AFFS')
+    pred_affs = gp.ArrayKey("PRED_AFFS")
 
     ####################
     # DECLARE REQUESTS #
     ####################
 
-    with open('test_net_config.json', 'r') as f:
+    with open("test_net_config.json", "r") as f:
         net_config = json.load(f)
 
     # get the input and output size in world units (nm, in this case)
     voxel_size = gp.Coordinate((40, 4, 4))
-    input_size = gp.Coordinate(net_config['input_shape'])*voxel_size
-    output_size = gp.Coordinate(net_config['output_shape'])*voxel_size
+    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
+    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size
     context = input_size - output_size
 
     # formulate the request for what a batch should contain
@@ -37,52 +37,44 @@ def predict(iteration):
     #############################
 
     source = gp.Hdf5Source(
-        'sample_A_padded_20160501.hdf',
-        datasets = {
-            raw: 'volumes/raw'
-        })
+        "sample_A_padded_20160501.hdf", datasets={raw: "volumes/raw"}
+    )
 
     # get the ROI provided for raw (we need it later to calculate the ROI in
     # which we can make predictions)
     with gp.build(source):
         raw_roi = source.spec[raw].roi
 
     pipeline = (
-
         # read from HDF5 file
-        source +
-
+        source
+        +
         # convert raw to float in [0, 1]
-        gp.Normalize(raw) +
-
+        gp.Normalize(raw)
+        +
         # perform one training iteration for each passing batch (here we use
         # the tensor names earlier stored in train_net.config)
         gp.tensorflow.Predict(
-            graph='test_net.meta',
-            checkpoint='train_net_checkpoint_%d'%iteration,
-            inputs={
-                net_config['raw']: raw
-            },
-            outputs={
-                net_config['pred_affs']: pred_affs
-            },
-            array_specs={
-                pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))
-            }) +
-
+            graph="test_net.meta",
+            checkpoint="train_net_checkpoint_%d" % iteration,
+            inputs={net_config["raw"]: raw},
+            outputs={net_config["pred_affs"]: pred_affs},
+            array_specs={pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))},
+        )
+        +
         # store all passing batches in the same HDF5 file
         gp.Hdf5Write(
             {
-                raw: '/volumes/raw',
-                pred_affs: '/volumes/pred_affs',
+                raw: "/volumes/raw",
+                pred_affs: "/volumes/pred_affs",
             },
-            output_filename='predictions_sample_A.hdf',
-            compression_type='gzip'
-        ) +
-
+            output_filename="predictions_sample_A.hdf",
+            compression_type="gzip",
+        )
+        +
         # show a summary of time spend in each node every 10 iterations
-        gp.PrintProfilingStats(every=10) +
-
+        gp.PrintProfilingStats(every=10)
+        +
         # iterate over the whole dataset in a scanning fashion, emitting
         # requests that match the size of the network
         gp.Scan(reference=request)
@@ -93,5 +85,6 @@ def predict(iteration):
     # without keeping the complete dataset in memory
     pipeline.request_batch(gp.BatchRequest())
 
+
 if __name__ == "__main__":
     predict(200000)
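
After predict.py has run, the raw volume and the predicted affinities live in the same HDF5 output file. A minimal sketch for inspecting it, assuming h5py is installed and predictions_sample_A.hdf has been written by the pipeline above:

    import h5py

    # open the file written by gp.Hdf5Write and check the stored volumes
    with h5py.File("predictions_sample_A.hdf", "r") as f:
        print(f["volumes/raw"].shape)        # raw intensities
        print(f["volumes/pred_affs"].shape)  # affinities in z, y, and x
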