-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathD2S&S2D.py
39 lines (30 loc) · 1.41 KB
/
D2S&S2D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn
from torch.nn.modules.pixelshuffle import PixelShuffle, PixelUnshuffle
'''
Paper: `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`
'''
# Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
class DepthToSpace(nn.Module):
    """Move blocks of channels into the spatial dimensions.

    Takes a ``(N, C, H, W)`` tensor and returns a
    ``(N, C // bs**2, H * bs, W * bs)`` tensor, where ``bs`` is the block
    size. The channel axis is read as ``(row offset, column offset,
    output channel)``, so for output channel ``c``::

        out[n, c, h * bs + i, w * bs + j] == x[n, (i * bs + j) * C_out + c, h, w]

    Args:
        block_size: Side length ``bs`` of the square block each input
            pixel is expanded into; stored as ``self.bs``.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Return ``x`` rearranged from depth to space.

        Args:
            x: 4-D tensor laid out as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, C // bs**2, H * bs, W * bs)``.
        """
        block = self.bs
        batch, channels, height, width = x.shape
        out_channels = channels // (block * block)

        # Split the channel axis: (N, bs, bs, C // bs^2, H, W).
        split = x.view(batch, block, block, out_channels, height, width)
        # Pair each block offset with its spatial axis:
        # (N, C // bs^2, H, bs, W, bs). contiguous() is required so the
        # final view can merge the paired axes.
        interleaved = split.permute(0, 3, 4, 1, 5, 2).contiguous()
        # Merge (H, bs) and (W, bs): (N, C // bs^2, H * bs, W * bs).
        return interleaved.view(
            batch, out_channels, height * block, width * block
        )
class SpaceToDepth(nn.Module):
    """Fold spatial blocks into the channel dimension.

    Expects the following shape: Batch, Channel, Height, Width. A
    ``(N, C, H, W)`` tensor becomes ``(N, C * bs**2, H // bs, W // bs)``,
    where ``bs`` is the block size. The output channel axis is ordered as
    ``(row offset, column offset, input channel)``::

        out[n, (i * bs + j) * C + c, h, w] == x[n, c, h * bs + i, w * bs + j]

    Args:
        block_size: Side length ``bs`` of the square spatial block folded
            into the channels; stored as ``self.bs``.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Return ``x`` rearranged from space to depth.

        Args:
            x: 4-D tensor laid out as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, C * bs**2, H // bs, W // bs)``.
        """
        block = self.bs
        batch, channels, height, width = x.shape
        out_height = height // block
        out_width = width // block

        # Cut both spatial axes into tiles: (N, C, H//bs, bs, W//bs, bs).
        tiles = x.view(batch, channels, out_height, block, out_width, block)
        # Bring the two in-tile offsets in front of the channel axis:
        # (N, bs, bs, C, H//bs, W//bs). contiguous() is required so the
        # final view can merge them with the channels.
        stacked = tiles.permute(0, 3, 5, 1, 2, 4).contiguous()
        # Merge (bs, bs, C) into one axis: (N, C * bs^2, H//bs, W//bs).
        return stacked.view(
            batch, channels * block * block, out_height, out_width
        )
# Ready-made module instances built on PyTorch's own pixel (un)shuffle
# layers, both with a fixed factor of 2.
# NOTE(review): for inputs with more than one channel these built-ins appear
# to order channels differently from the DepthToSpace / SpaceToDepth classes
# in this file -- confirm before treating them as interchangeable.
depth2space = PixelShuffle(2)
space2depth = PixelUnshuffle(2)