diff --git a/prednet.py b/prednet.py index b5a0208..ab1f2b6 100755 --- a/prednet.py +++ b/prednet.py @@ -165,7 +165,7 @@ def get_initial_state(self, x): initial_state = K.reshape(initial_state, output_shp) initial_states += [initial_state] - if K._BACKEND == 'theano': + if K.backend() == 'theano': from theano import tensor as T # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension. # In our case, this is a problem when training on grayscale images, and the below line fixes it.