Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions keras_squeezenet/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def fire_module(x, fire_id, squeeze=16, expand=64):
channel_axis = 1
else:
channel_axis = 3

x = Convolution2D(squeeze, (1, 1), padding='valid', name=s_id + sq1x1)(x)
x = Activation('relu', name=s_id + relu + sq1x1)(x)

Expand All @@ -47,13 +47,13 @@ def SqueezeNet(include_top=True, weights='imagenet',
classes=1000):
"""Instantiates the SqueezeNet architecture.
"""

if weights not in {'imagenet', None}:
raise ValueError('The `weights` argument should be either '
'`None` (random initialization) or `imagenet` '
'(pre-training on ImageNet).')

if weights == 'imagenet' and classes != 1000:
if weights == 'imagenet' and classes != 1000 and include_top==True:
Copy link

@jaiprasadreddy jaiprasadreddy Apr 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rcmalli @JesperChristensen89
if weights == 'imagenet' and classes != 1000 and include_top:

raise ValueError('If using `weights` as imagenet with `include_top`'
' as true, `classes` should be 1000')

Expand Down Expand Up @@ -89,11 +89,11 @@ def SqueezeNet(include_top=True, weights='imagenet',
x = fire_module(x, fire_id=7, squeeze=48, expand=192)
x = fire_module(x, fire_id=8, squeeze=64, expand=256)
x = fire_module(x, fire_id=9, squeeze=64, expand=256)

if include_top:
# It's not obvious where to cut the network...
# It's not obvious where to cut the network...
# Could do the 8th or 9th layer... some work recommends cutting earlier layers.

x = Dropout(0.5, name='drop9')(x)

x = Convolution2D(classes, (1, 1), padding='valid', name='conv10')(x)
Expand Down Expand Up @@ -129,7 +129,7 @@ def SqueezeNet(include_top=True, weights='imagenet',
weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models')

model.load_weights(weights_path)
if K.backend() == 'theano':
layer_utils.convert_all_kernels_in_model(model)
Expand Down