diff --git a/subpixel.py b/subpixel.py index 3b9a1cf..bd12057 100644 --- a/subpixel.py +++ b/subpixel.py @@ -8,17 +8,17 @@ def _phase_shift(I, r): bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim X = tf.reshape(I, (bsize, a, b, r, r)) X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1 - X = tf.split(1, a, X) # a, [bsize, b, r, r] - X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r - X = tf.split(1, b, X) # b, [bsize, a*r, r] - X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r + X = tf.split(X, a, 1) # a, [bsize, b, r, r] + X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2) # bsize, b, a*r, r + X = tf.split(X, b, 1) # b, [bsize, a*r, r] + X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2) # bsize, a*r, b*r return tf.reshape(X, (bsize, a*r, b*r, 1)) def PS(X, r, color=False): if color: - Xc = tf.split(3, 3, X) - X = tf.concat(3, [_phase_shift(x, r) for x in Xc]) + Xc = tf.split(X, 3, 3) + X = tf.concat([_phase_shift(x, r) for x in Xc], 3) else: X = _phase_shift(X, r) return X