-
Notifications
You must be signed in to change notification settings - Fork 6
/
subpixel.py
103 lines (90 loc) · 3.89 KB
/
subpixel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import division
from keras import backend as K
from keras.layers import Conv2D
import numpy as np
import tensorflow as tf
def icnr_weights(init=tf.glorot_normal_initializer(), scale=2, shape=(3, 3, 32, 4), dtype=tf.float32):
    """Evaluate an ICNR-initialised kernel and return its value.

    Builds the ``ICNR`` initializer op in the default graph and runs it once
    in a short-lived session.

    Args:
        init: initializer used for the sub-kernels. Note that the default
            instance is created once, when this module is imported.
        scale: sub-pixel upscaling factor passed to ``ICNR``.
        shape: kernel shape. Dims 0-1 are the spatial kernel size; dim 3 is
            divided by ``scale ** 2`` to size the sub-kernel.
        dtype: dtype requested from the initializer.

    Returns:
        Whatever ``Session.run`` returns for the kernel tensor (presumably a
        NumPy array).
    """
    # Use the session as a context manager so it is closed and its resources
    # are released as soon as the value has been fetched.
    with tf.Session() as sess:
        return sess.run(ICNR(init, scale=scale)(shape=shape, dtype=dtype))
class ICNR:
    """ICNR initializer for checkerboard-artifact-free sub-pixel convolution.

    Ref:
        [1] Andrew Aitken et al., "Checkerboard artifact free sub-pixel
            convolution", https://arxiv.org/pdf/1707.02937.pdf

    For ``scale > 1``, the process has three steps:

    1. A sub-kernel with ``shape[3] // scale**2`` output channels is drawn
       from ``initializer``.
    2. Each tap is replicated into a ``scale x scale`` patch by
       nearest-neighbour resizing.
    3. The patches are folded back into depth with ``space_to_depth``.

    Output channel ``(dy * scale + dx) * C + c``, where ``C = shape[3] //
    scale**2``, therefore holds sub-kernel channel ``c``. This is the
    ``space_to_depth`` layout, i.e. the one ``depth_to_space`` consumes.
    NOTE(review): confirm this matches the channel ordering of whatever
    phase-shift layer the kernel is used with.

    Args:
        initializer: initializer used for the sub-kernels (orthogonal,
            glorot uniform, etc.). For ``scale > 1`` it is called as
            ``initializer(shape, dtype, partition_info)``.
        scale: scale factor of the sub-pixel convolution.
    """

    def __init__(self, initializer, scale=1):
        self.scale = scale
        self.initializer = initializer

    def __call__(self, shape, dtype, partition_info=None):
        """Return an initial kernel tensor of the requested ``shape``.

        Raises:
            ValueError: if ``scale > 1`` and ``shape[3]`` is not a multiple of
                ``scale ** 2``. Otherwise the tiled kernel could not have
                ``shape[3]`` channels.
        """
        shape = list(shape)
        if self.scale == 1:
            # Nothing to tile. Delegate to the wrapped initializer, forwarding
            # dtype so the kernel has the requested dtype (as in the tiled path).
            return self.initializer(shape, dtype=dtype)

        block = self.scale ** 2
        out_channels = int(shape[3])
        if out_channels % block:
            raise ValueError(
                "ICNR: output channels (%d) must be a multiple of scale**2 (%d)"
                % (out_channels, block))

        new_shape = shape[:3] + [out_channels // block]
        x = self.initializer(new_shape, dtype, partition_info)
        # Assumes the Conv2D kernel layout (kh, kw, in, out).
        # Move dim 2 to the front so it acts as the batch of "images" to resize.
        x = tf.transpose(x, perm=[2, 0, 1, 3])
        # Nearest-neighbour upsampling copies every tap into a scale x scale patch.
        x = tf.image.resize_nearest_neighbor(
            x, size=(shape[0] * self.scale, shape[1] * self.scale))
        # Fold each patch back into channels: spatial size returns to shape[0:2].
        x = tf.space_to_depth(x, block_size=self.scale)
        # Restore the original axis order.
        x = tf.transpose(x, perm=[1, 2, 0, 3])
        return x
class Subpixel(Conv2D):
    """Sub-pixel convolution: a ``Conv2D`` followed by a phase shift.

    The underlying convolution produces ``r * r * filters`` channels. The phase
    shift then spreads them into ``r x r`` spatial blocks, so the layer outputs
    ``filters`` channels at ``r`` times the convolution's spatial size.

    Channel-to-pixel mapping implemented by ``_phase_shift``::

        out[n, y * r + q, x * r + p, k] = conv[n, y, x, k * r * r + p * r + q]

    NOTE(review): this differs from the ``tf.nn.depth_to_space`` layout, so a
    kernel initializer built for that layout may not line up with it.
    NOTE(review): the phase shift treats the last axis as channels, so the
    layer presumably only supports ``data_format='channels_last'``.

    Args:
        filters: number of output channels after the phase shift.
        kernel_size: forwarded to ``Conv2D``.
        r: upscaling factor.
        padding, data_format, strides, activation, use_bias,
        kernel_initializer, bias_initializer, kernel_regularizer,
        bias_regularizer, activity_regularizer, kernel_constraint,
        bias_constraint, **kwargs: forwarded unchanged to ``Conv2D``.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 r,
                 padding='valid',
                 data_format=None,
                 strides=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(Subpixel, self).__init__(
            filters=r * r * filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        self.r = r

    def _phase_shift(self, x):
        """Rearrange ``r*r*C`` channels into ``r x r`` pixel blocks of ``C`` channels.

        Implemented as reshape -> permute -> reshape, so it works with an
        unknown batch size and with spatial sizes only known at run time.
        """
        r = self.r
        _, a, b, c = x.get_shape().as_list()
        dynamic = K.shape(x)
        bsize = dynamic[0]  # the batch dim is usually undefined statically
        # Prefer static sizes so the result keeps a static shape; otherwise
        # fall back to the run-time sizes.
        a = dynamic[1] if a is None else a
        b = dynamic[2] if b is None else b
        out_c = c // (r * r)
        # Split channel k*r*r + p*r + q into axes (k, p, q).
        y = K.reshape(x, [bsize, a, b, out_c, r, r])     # n, a, b, C, p, q
        y = K.permute_dimensions(y, (0, 1, 5, 2, 4, 3))  # n, a, q, b, p, C
        # Merge (a, q) into rows a*r + q and (b, p) into columns b*r + p.
        return K.reshape(y, [bsize, a * r, b * r, out_c])

    def call(self, inputs):
        return self._phase_shift(super(Subpixel, self).call(inputs))

    def compute_output_shape(self, input_shape):
        unshifted = super(Subpixel, self).compute_output_shape(input_shape)
        r = self.r

        def _upscale(dim):
            # Unknown spatial sizes stay unknown instead of raising TypeError.
            return None if dim is None else r * dim

        return (unshifted[0], _upscale(unshifted[1]), _upscale(unshifted[2]),
                int(unshifted[3] // (r * r)))

    def get_config(self):
        # Bypasses Conv2D.get_config and starts from the next class in the MRO.
        config = super(Conv2D, self).get_config()
        # Drop keys that are not named arguments of Subpixel.__init__.
        # NOTE(review): a non-default dilation_rate passed through **kwargs is
        # lost on serialisation because of this pop.
        config.pop('rank')
        config.pop('dilation_rate')
        # The parent stores r*r*filters. Divide by r*r (note the parentheses) so
        # from_config rebuilds a layer with the same number of channels.
        config['filters'] = int(config['filters'] // (self.r * self.r))
        config['r'] = self.r
        return config