Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions aiutils/examples/resnet_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Demonstrates use of pretrained ResNet model available here:
# https://github.com/ry/tensorflow-resnet

from aiutils.tftools import placeholder_management
from aiutils.vis import image, image_io
import aiutils.vis.pretrained_models.resnet.model as model
import os
import numpy as np

import tensorflow as tf


if __name__=='__main__':
im_h, im_w = (224, 224)

# Path to the image to apply resnet on
image_path = './aiutils/examples/telephone.jpg'

# Checkpoint file to restore parameters from
model_dir = '/home/nfs/tgupta6/data/Resnet'
ckpt_filename = os.path.join(model_dir, 'ResNet-L50.ckpt')

# Graph construction
graph = tf.Graph()
with graph.as_default():
plh = placeholder_management.PlaceholderManager()
plh.add_placeholder(
'images',
tf.float32,
shape=[None,im_h,im_w,3])

resnet_model = model.ResnetInference(
plh['images'],
num_blocks = [3, 4, 6, 3],
)

# Create feed dict
im = image_io.imread(image_path)
im = image.imresize(im, output_size=(im_h, im_w)).astype(np.float32)
inputs = {
'images': im.reshape(1, im_h, im_w, 3)
}
feed_dict = plh.get_feed_dict(inputs)

# Restore model and get top-5 class predictions
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.8
sess = tf.Session(config=config, graph=graph)

resnet_model.restore_pretrained_model(sess, ckpt_filename)
logits = resnet_model.get_logits().eval(feed_dict,sess)
resnet_model.imagenet_class_prediction(logits[0,:])

sess.close()
58 changes: 58 additions & 0 deletions aiutils/examples/resnet_fully_conv_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Demonstrates use of pretrained ResNet model available here:
# https://github.com/ry/tensorflow-resnet

from aiutils.tftools import placeholder_management
from aiutils.vis import image, image_io
import aiutils.vis.pretrained_models.resnet.model_fully_conv as model
import os
import numpy as np

from skimage.color import label2rgb
import tensorflow as tf


if __name__=='__main__':
im_h, im_w = (512, 512)

# Path to the image to apply resnet on
image_path = './aiutils/examples/telephone.jpg'

# Checkpoint file to restore parameters from
model_dir = '/home/nfs/tgupta6/data/Resnet'
ckpt_filename = os.path.join(model_dir, 'ResNet-L50.ckpt')

# Graph construction
graph = tf.Graph()
with graph.as_default():
plh = placeholder_management.PlaceholderManager()
plh.add_placeholder(
'images',
tf.float32,
shape=[None,im_h,im_w,3])

resnet_model = model.ResnetFullyConvInference(
plh['images'],
num_blocks = [3, 4, 6, 3]
)

# Create feed dict
im = image_io.imread(image_path)
im = image.imresize(im, output_size=(im_h, im_w)).astype(np.float32)
inputs = {
'images': im.reshape(1, im_h, im_w, 3)
}
feed_dict = plh.get_feed_dict(inputs)

# Restore model and get top-5 class predictions
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 1.0
sess = tf.Session(config=config, graph=graph)

resnet_model.restore_pretrained_model(sess, ckpt_filename)
logits = resnet_model.logits.eval(feed_dict,sess)
pred = np.argmax(logits, 3)
colored_pred = label2rgb(pred[0,:,:])
image_io.imwrite(np.uint8(colored_pred*255), './aiutils/examples/telephone_pred.jpg')

sess.close()
Binary file added aiutils/examples/telephone.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added aiutils/examples/telephone_pred.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion aiutils/tftools/var_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def print_var_list(var_list, name='Variables'):
print name + ': \n' + '[' + ', '.join([var.name for var in var_list]) + ']'
print name + ': \n' + '[' + ',\n '.join([var.name for var in var_list]) + ']'


def collect_name(var_name, graph=None, var_type=tf.GraphKeys.VARIABLES):
Expand Down
4 changes: 2 additions & 2 deletions aiutils/vis/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def imresize(np_im, method='bilinear', **kwargs):
if 'output_size' in kwargs:
h, w = kwargs['output_size']
elif 'scale' in kwargs:
h = scale * im_h
w = scale * im_w
h = kwargs['scale'] * im_h
w = kwargs['scale'] * im_w
else:
assert_str = "Variable argument must be one of {'output_size','scale'}"
assert (False), assert_str
Expand Down
Empty file.
Empty file.
82 changes: 82 additions & 0 deletions aiutils/vis/pretrained_models/resnet/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# This is a variable scope aware configuation object for TensorFlow

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

class Config:
def __init__(self):
root = self.Scope('')
for k, v in FLAGS.__dict__['__flags'].iteritems():
root[k] = v
self.stack = [ root ]

def _pop_stale(self):
var_scope_name = tf.get_variable_scope().name
top = self.stack[0]
while not top.contains(var_scope_name):
# We aren't in this scope anymore
self.stack.pop(0)
top = self.stack[0]

def __getitem__(self, name):
self._pop_stale()
# Recursively extract value
for i in range(len(self.stack)):
cs = self.stack[i]
if name in cs:
return cs[name]

raise KeyError(name)

def __setitem__(self, name, value):
self._pop_stale()
top = self.stack[0]
var_scope_name = tf.get_variable_scope().name
assert top.contains(var_scope_name)

if top.name != var_scope_name:
top = self.Scope(var_scope_name)
self.stack.insert(0, top)

top[name] = value

class Scope(dict):
def __init__(self, name):
self.name = name

def contains(self, var_scope_name):
return var_scope_name.startswith(self.name)



# Test
if __name__ == '__main__':

def assert_raises(exception, fn):
try:
fn()
except exception:
pass
else:
assert False, "Expected exception"

c = Config()

c['hello'] = 1
assert c['hello'] == 1

with tf.variable_scope('foo'):
c['bar'] = 2
assert c['bar'] == 2
assert c['hello'] == 1

with tf.variable_scope('meow'):
c['dog'] = 3
assert c['dog'] == 3
assert c['bar'] == 2
assert c['hello'] == 1

assert_raises(KeyError, lambda: c['dog'])
assert c['bar'] == 2
assert c['hello'] == 1
Loading