diff --git a/models/ecbsr1d/demo.ipynb b/models/ecbsr1d/demo.ipynb new file mode 100644 index 0000000..42d471f --- /dev/null +++ b/models/ecbsr1d/demo.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 一、F.conv2d() 与 forward的等价性\n", + "\n", + "常用的卷积参数为 inp,oup,kernel_size,stride,padding\n", + "F.conv2d(inp, weight, bias, stride)\n", + "padding部分需要自己手动填充\n", + "2" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([64, 64, 3, 1])\n", + "torch.Size([64, 28, 28])\n" + ] + } + ], + "source": [ + "conv0 = torch.nn.Conv2d(64, 64, kernel_size=(3,1), stride=(1,0), padding=(1,0))\n", + "k0 = conv0.weight\n", + "b0 = conv.bias\n", + "inp = torch.randn(64, 28, 28)\n", + "out1 = conv1(inp)\n", + "# out2 = F.conv2d(input=inp, weight=k0, bias=b0, stride=1)\n", + "print(k0.shape)\n", + "print(out1.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'y0' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/user3/code/SimpleIR/models/ecbsr copy/demo.ipynb Cell 4'\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m y0 \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mpad(y0, (\u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m), \u001b[39m'\u001b[39m\u001b[39mconstant\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m0\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'y0' 
is not defined" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([64, 28, 28])\n", + "torch.Size([64, 64, 3, 3])\n", + "torch.Size([64, 28, 30])\n" + ] + } + ], + "source": [ + "conv0= torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)\n", + "k0 = conv0.weight\n", + "b0 = conv0.bias\n", + "inp = torch.randn(64, 28, 28)\n", + "out = conv0(inp)\n", + "# out2 = F.conv2d(input=inp, weight=k0, bias=b0, stride=1)\n", + "print(out.shape)\n", + "print(k0.shape)\n", + "out = F.pad(out, (1, 1, 0, 0), 'constant', 0)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)\n", + "self.k0 = conv0.weight\n", + "self.b0 = conv0.bias\n", + "\n", + "# init scale & bias\n", + "scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3\n", + "self.scale = nn.Parameter(scale)\n", + "# bias = 0.0\n", + "# bias = [bias for c in range(self.out_planes)]\n", + "# bias = torch.FloatTensor(bias)\n", + "bias = torch.randn(self.out_planes) * 1e-3\n", + "bias = torch.reshape(bias, (self.out_planes,))\n", + "self.bias = nn.Parameter(bias)\n", + "# init mask\n", + "self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)\n", + "for i in range(self.out_planes):\n", + " self.mask[i, 0, 0, 0] = 1.0\n", + " self.mask[i, 0, 1, 0] = 2.0\n", + " self.mask[i, 0, 2, 0] = 1.0\n", + " self.mask[i, 0, 0, 2] = -1.0\n", + " self.mask[i, 0, 1, 2] = -2.0\n", + " self.mask[i, 0, 2, 2] = -1.0\n", + "self.mask = nn.Parameter(data=self.mask, requires_grad=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
# ---------------------------------------------------------------------------
# models/ecbsr1d/ecb1d_block.py
# 1-D (axis-wise) edge-oriented convolution blocks, reconstructed verbatim
# from the collapsed patch; only comments/docstrings added.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F



class SeqConv1d(nn.Module):
    """Re-parameterizable sequential branch built from 1-D (axis-wise) convs.

    Supported ``seq_type`` values:
      * ``'conv1x1-conv3x1'`` / ``'conv1x1-conv1x3'``: 1x1 expansion conv
        followed by a vertical (3x1) or horizontal (1x3) conv.
      * ``'conv1x1-sobelx'`` / ``'conv1x1-sobely'``: 1x1 conv followed by a
        fixed 1-D Sobel-like difference filter with a learnable per-channel
        scale and bias (the mask itself is frozen).

    Weights ``k0``/``b0`` (and ``k1``/``b1``) are lifted from throwaway
    ``Conv2d`` layers so they keep PyTorch's default initialization and are
    registered as parameters of this module.

    NOTE(review): unlike ``SeqConv3x3`` in ecbsr_block.py there is no final
    ``else: raise`` — an unsupported seq_type only fails later in ``forward``
    with a NameError on ``y1``; confirm whether that is intended.
    """

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier=1, with_bn=False):
        super(SeqConv1d, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.with_bn = with_bn

        if self.with_bn:
            self.bn = nn.BatchNorm2d(num_features=out_planes)

        if self.type == 'conv1x1-conv3x1':
            # mid_planes controls the 1x1 expansion width.
            # NOTE(review): callers pass depth_multiplier=-1 for the sobel
            # branches, where mid_planes is unused — confirm.
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # vertical 3x1 conv, deliberately unpadded: padding is done
            # explicitly (with the 1x1 bias) in forward().
            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=(3,1))
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        if self.type == 'conv1x1-conv1x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # horizontal 1x3 conv, unpadded for the same reason as above.
            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=(1,3))
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=(1,1), padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale & bias (small random values, one scale per channel)
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes,))
            self.bias = nn.Parameter(bias)
            # init mask: horizontal difference kernel [2, 0, -2] per channel
            # (depthwise; applied with groups=out_planes in forward).
            self.mask = torch.zeros((self.out_planes, 1, 1, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 2.0
                self.mask[i, 0, 0, 2] = -2.0

            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=(1,1), padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale & bias
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes,))
            self.bias = nn.Parameter(bias)
            # init mask: vertical difference kernel [2, 0, -2]^T per channel.
            self.mask = torch.zeros((self.out_planes, 1, 3, 1), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 2.0
                self.mask[i, 0, 2, 0] = -2.0

            self.mask = nn.Parameter(data=self.mask, requires_grad=False)


    def forward(self, x):
        """Apply the two-stage branch.

        The intermediate 1x1 output is padded along the axis of the second
        conv, and the padded border is overwritten with the broadcast 1x1
        bias ``b0`` — this makes the unpadded second conv equivalent to a
        single fused, zero-padded conv at re-parameterization time.
        """
        if self.type == 'conv1x1-conv1x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias (left/right columns only)
            y0 = F.pad(y0, (1, 1, 0, 0), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-1x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        elif self.type == 'conv1x1-conv3x1':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias (top/bottom rows only)
            y0 = F.pad(y0, (0, 0, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        elif self.type == 'conv1x1-sobelx':
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 0, 0), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # depthwise scaled-sobel conv
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)

        elif self.type == 'conv1x1-sobely':
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (0, 0, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            # depthwise scaled-sobel conv
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)

        if self.with_bn:
            y1 = self.bn(y1)
        return y1

class ECB1d(nn.Module):
    """Edge-oriented conv block restricted to one axis.

    Sums a learnable 1-D conv branch and a fixed-mask 1-D Sobel branch along
    either the x axis ('conv1x1-conv1x3' + sobelx) or the y axis
    ('conv1x1-conv3x1' + sobely).

    NOTE(review): ``self.conv3x3`` (built when ``with_bn``) and ``self.act``
    are constructed but never used in ``forward``; ``act_type`` is therefore
    effectively ignored — confirm this is intentional.
    """

    def __init__(self, inp_planes, out_planes, type='x_axis', depth_multiplier=2,act_type='prelu', with_idt = False, with_bn = False):
        super(ECB1d, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn
        self.type = type

        # identity shortcut only possible when channel counts match
        if with_idt and (self.inp_planes == self.out_planes):
            self.with_idt = True
        else:
            self.with_idt = False

        if with_bn:
            self.conv3x3 = nn.Sequential(
                nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1),
                nn.BatchNorm2d(self.out_planes)
            )
        if self.type == 'x_axis':
            self.conv1x1_1x3 = SeqConv1d('conv1x1-conv1x3', self.inp_planes, self.out_planes, self.depth_multiplier, self.with_bn)
            self.conv1x1_sbx = SeqConv1d('conv1x1-sobelx', self.inp_planes, self.out_planes, -1, self.with_bn)
        else:
            self.conv1x1_3x1 = SeqConv1d('conv1x1-conv3x1', self.inp_planes, self.out_planes, self.depth_multiplier, self.with_bn)
            self.conv1x1_sby = SeqConv1d('conv1x1-sobely', self.inp_planes, self.out_planes, -1, self.with_bn)

        self.act = nn.LeakyReLU(0.1)


    def forward(self, x):
        # Sum of the two parallel branches for the selected axis; no
        # activation and no identity shortcut are applied here.
        if self.type == 'x_axis':
            y = self.conv1x1_1x3(x) + \
                self.conv1x1_sbx(x)

        else:
            y = self.conv1x1_3x1(x) + \
                self.conv1x1_sby(x)

        return y


class ECB1d_conv(nn.Module):
    """Separable ECB: an x-axis ECB1d followed by a y-axis ECB1d."""

    def __init__(self, inp_planes, out_planes, depth_multiplier=2,act_type='prelu', with_idt = False, with_bn = False):
        super(ECB1d_conv, self).__init__()
        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn

        # NOTE(review): depth_multiplier/act_type/with_idt/with_bn are stored
        # but not forwarded to the ECB1d sub-blocks (defaults used) — confirm.
        self.conv_x = ECB1d(self.inp_planes, self.out_planes, type='x_axis')
        self.conv_y = ECB1d(self.inp_planes, self.out_planes, type='y_axis')

    def forward(self, x):
        oup = self.conv_x(x)
        oup = self.conv_y(oup)

        return oup

class LKA_noatt(nn.Module):
    """Large-kernel depthwise conv stack (LKA) without the attention multiply.

    5x5 depthwise -> 7x7 dilated depthwise (dilation 3) -> 1x1 pointwise.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)


    def forward(self, x):
        # NOTE(review): ``u`` is unused — presumably left over from the
        # original LKA which returns ``u * attn``; confirm.
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        return attn

class ECB_all(nn.Module):
    """Separable ECB1d pair, optionally followed by an LKA_noatt stack
    (only when input and output channel counts match)."""

    def __init__(self, inp_planes, out_planes, depth_multiplier=2,act_type='prelu', with_idt = False, with_bn = False):
        super(ECB_all, self).__init__()
        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn

        if self.inp_planes == self.out_planes:
            self.lka = LKA_noatt(self.inp_planes)

        self.ecb1d = ECB1d_conv(self.inp_planes, self.out_planes)



    def forward(self, x):
        oup = self.ecb1d(x)
        if self.inp_planes == self.out_planes:
            oup = self.lka(oup)
        return oup

class ECB_all_test(nn.Module):
    """Plain-conv control variant of ECB_all: the two ECB1d branches are
    replaced by single padded 1x3 and 3x1 convolutions (used for profiling
    the fused/inference-time shape of the network)."""

    def __init__(self, inp_planes, out_planes, depth_multiplier=2,act_type='prelu', with_idt = False, with_bn = False):
        super(ECB_all_test, self).__init__()
        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn

        if self.inp_planes == self.out_planes:
            self.lka = LKA_noatt(self.inp_planes)

        self.ecb1d_x = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=(1,3), padding=(0,1))
        self.ecb1d_y = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=(3,1), padding=(1,0))



    def forward(self, x):
        oup = self.ecb1d_x(x)
        oup = self.ecb1d_y(oup)
        if self.inp_planes == self.out_planes:
            oup = self.lka(oup)
        return oup



if __name__ == '__main__':
    # Smoke test: all block variants must preserve the 1x64x28x28 shape.
    x0 = torch.randn(1, 64, 28, 28)
    con1d_x = SeqConv1d('conv1x1-sobelx', 64, 64)
    con1d_y = SeqConv1d('conv1x1-sobely', 64, 64)
    conv1d = ECB1d_conv(64, 64)
    conv_atten = ECB_all(64, 64)
    conv_test = ECB_all_test(64, 64)

    y0 = con1d_x(x0)
    print(y0.shape)

    y0 = con1d_y(x0)
    print(y0.shape)

    y0 = conv1d(x0)
    print(y0.shape)

    y0 = conv_atten(x0)
    print(y0.shape)

    y0 = conv_test(x0)
    print(y0.shape)
# ---------------------------------------------------------------------------
# models/ecbsr1d/ecbsr1d_network.py
# ECBSR-style super-resolution network using the 1-D ECB blocks; supports
# eager-mode quantization (QuantStub/DeQuantStub/FloatFunctional) and
# re-parameterized fusion for deployment. Reconstructed verbatim from the
# collapsed patch; only comments/docstrings added.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from models.ecbsr.ecbsr_block import ECB, Conv3X3
except ModuleNotFoundError:
    from ecbsr_block import ECB, Conv3X3

try:
    from models.ecbsr1d.ecb1d_block import ECB1d_conv, ECB_all, ECB_all_test
except ModuleNotFoundError:
    from ecb1d_block import ECB1d_conv, ECB_all, ECB_all_test

from torch.quantization import QuantStub, DeQuantStub
from torch.nn.quantized import FloatFunctional
from torchsummaryX import summary

def create_model(args):
    """Factory entry point used by the training framework."""
    return ECBSR1d(args)

class ECBSR1d(nn.Module):
    """ECBSR backbone with 1-D ECB_all body blocks.

    Head/tail are standard ECB blocks; the global shortcut repeats the input
    to match the tail's ``colors*scale*scale`` channels before PixelShuffle.

    NOTE(review): the channel-repeat in forward() matches the backbone output
    only when ``colors == 1`` (repeat gives colors^2*scale^2 channels) —
    confirm multi-color support.
    """

    def __init__(self, args):
        super(ECBSR1d, self).__init__()
        self.m_ecbsr = args.m_ecbsr      # number of body blocks
        self.c_ecbsr = args.c_ecbsr      # body channel width
        self.scale = args.scale
        self.colors = args.colors
        self.chns_exp = 2.0              # depth multiplier for ECB expansion
        self.with_idt = args.with_idt
        self.with_bn = args.with_bn
        self.act_type = args.act_type
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        self.backbone = None
        self.upsampler = None

        backbone = []
        backbone += [ECB(self.colors, self.c_ecbsr, depth_multiplier=self.chns_exp, act_type=self.act_type, with_idt = self.with_idt, with_bn = self.with_bn)]
        for i in range(self.m_ecbsr):
            backbone += [ECB_all(self.c_ecbsr, self.c_ecbsr, depth_multiplier=self.chns_exp, act_type=self.act_type, with_idt = self.with_idt, with_bn = self.with_bn)]
        # tail: linear activation, expands to colors*scale^2 for PixelShuffle
        backbone += [ECB(self.c_ecbsr, self.colors*self.scale*self.scale, depth_multiplier=2.0, act_type='linear', with_idt = self.with_idt, with_bn = self.with_bn)]
        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(self.scale)
        # quantization-friendly add for the global residual
        self.shortcut = FloatFunctional()

    def fuse_model(self):
        """Re-parameterize every ECB into a plain Conv3X3 and fuse
        conv+relu pairs for quantized inference.

        NOTE(review): only ECB blocks are fused here; ECB_all body blocks
        are left as-is — confirm whether they have their own fusion path.
        """
        ## reparam as plainsr
        for idx, blk in enumerate(self.backbone):
            if type(blk) == ECB:
                RK, RB = blk.rep_params()
                conv3x3 = Conv3X3(blk.inp_planes, blk.out_planes, act_type=blk.act_type, with_bn=False)
                ## update weights & bias for conv3x3
                conv3x3.block[0].weight.data = RK
                conv3x3.block[0].bias.data = RB
                ## update weights & bias for activation
                if blk.act_type == 'prelu':
                    conv3x3.block[1].weight = blk.act.weight
                ## update block for backbone
                self.backbone[idx] = conv3x3.to(RK.device)
        ## fused modules
        for m in self.modules():
            if type(m) == Conv3X3:
                if m.act_type == 'relu':
                    torch.quantization.fuse_modules(m.block, ['0', '1'], inplace=True)
    def forward(self, x):
        x = self.quant(x)
        # global residual: input repeated to tail channel count, then
        # pixel-shuffled; clamp assumes 0-255 input range (uint8 images).
        y = self.shortcut.add(self.backbone(x), x.repeat(1, self.colors*self.scale*self.scale, 1, 1).contiguous())
        y = self.upsampler(y)
        y = torch.clamp(y, min=0.0, max=255.0)
        y = self.dequant(y)
        return y



class ECBSR1d_test(nn.Module):
    """Inference-shape control variant: plain convs for head/tail and
    ECB_all_test (plain 1x3/3x1 convs) for the body, used to profile the
    fused network's cost with torchsummaryX."""

    def __init__(self, args):
        super(ECBSR1d_test, self).__init__()
        self.m_ecbsr = args.m_ecbsr
        self.c_ecbsr = args.c_ecbsr
        self.scale = args.scale
        self.colors = args.colors
        self.chns_exp = 2.0
        self.with_idt = args.with_idt
        self.with_bn = args.with_bn
        self.act_type = args.act_type
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        self.backbone = None
        self.upsampler = None

        backbone = []
        backbone += [nn.Conv2d(self.colors, self.c_ecbsr, kernel_size=3, padding=1)]
        for i in range(self.m_ecbsr):
            backbone += [ECB_all_test(self.c_ecbsr, self.c_ecbsr, depth_multiplier=self.chns_exp, act_type=self.act_type, with_idt = self.with_idt, with_bn = self.with_bn)]
        backbone += [nn.Conv2d(self.c_ecbsr, self.colors*self.scale*self.scale, kernel_size=3, padding=1)]
        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(self.scale)
        self.shortcut = FloatFunctional()

    def forward(self, x):
        x = self.quant(x)
        y = self.shortcut.add(self.backbone(x), x.repeat(1, self.colors*self.scale*self.scale, 1, 1).contiguous())
        y = self.upsampler(y)
        y = torch.clamp(y, min=0.0, max=255.0)
        y = self.dequant(y)
        return y

if __name__ == '__main__':
    import argparse
    # NOTE(review): an ArgumentParser instance is (ab)used as a plain
    # attribute bag here — argparse.Namespace would be the conventional type.
    args = argparse.ArgumentParser(description='')
    args.m_ecbsr = 4
    args.c_ecbsr = 16
    args.with_idt = 1
    args.with_bn = 1
    args.act_type = 'relu'
    args.model = 'ecbsr'
    args.scale = 4
    args.colors = 1


    # model = ECBSR1d(args).eval().to('cuda')
    model = ECBSR1d_test(args).eval().to('cuda')


    in_ = torch.randn(1, 1, round(720/args.scale), round(1280/args.scale)).to('cuda')
    summary(model, in_)
# ---------------------------------------------------------------------------
# models/ecbsr1d/ecbsr_block.py
# Original 2-D ECBSR building blocks (Conv3X3, SeqConv3x3, ECB) with
# structural re-parameterization, followed by models/ecbsr1d/tf/plainsr.py
# (a separate file in the original patch). Reconstructed verbatim; only
# comments/docstrings added.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv3X3(nn.Module):
    """Plain 3x3 conv (+ optional BN) + activation — the deploy-time shape
    that a fused ECB collapses into."""

    def __init__(self, inp_planes, out_planes, act_type='prelu', with_bn=False):
        super(Conv3X3, self).__init__()

        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn

        self.block = [nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)]
        if self.with_bn:
            self.block += [nn.BatchNorm2d(self.out_planes)]
        ## activation selection
        if self.act_type == 'prelu':
            self.block += [nn.PReLU(num_parameters=self.out_planes)]
        elif self.act_type == 'relu':
            self.block += [nn.ReLU(inplace=True)]
        elif self.act_type == 'rrelu':
            self.block += [nn.RReLU(lower=-0.05, upper=0.05)]
        elif self.act_type == 'softplus':
            self.block += [nn.Softplus()]
        elif self.act_type == 'linear':
            pass
        else:
            # NOTE(review): message grammar — should read "is not supported".
            raise ValueError('The type of activation if not support!')
        ## initialize block
        self.block = nn.Sequential(*self.block)

    def forward(self, x):
        x = self.block(x)
        return x

class SeqConv3x3(nn.Module):
    """Re-parameterizable 2-D sequential branch: a 1x1 conv followed by a
    3x3 conv ('conv1x1-conv3x3') or a fixed, learnable-scaled edge filter
    ('conv1x1-sobelx' / 'conv1x1-sobely' / 'conv1x1-laplacian').

    rep_params() collapses the pair into a single 3x3 kernel/bias.
    """

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier, with_bn=False):
        super(SeqConv3x3, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.with_bn = with_bn

        if self.with_bn:
            self.bn = nn.BatchNorm2d(num_features=out_planes)

        if self.type == 'conv1x1-conv3x3':
            # depth_multiplier controls the 1x1 expansion width (unused for
            # the fixed-filter branches, which receive -1 from ECB).
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # unpadded 3x3: padding is applied explicitly (with b0) in forward
            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale & bias (small random, per output channel)
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes,))
            self.bias = nn.Parameter(bias)
            # init mask: horizontal Sobel kernel per channel (depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 2.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale & bias
            # NOTE(review): the extra torch.FloatTensor(...) wrapping (here
            # and in the laplacian branch) copies an existing tensor — the
            # sobelx branch omits it; behavior is equivalent.
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes,))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask: vertical Sobel kernel per channel
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 2.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-laplacian':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale & bias
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes,))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask: 4-neighbour Laplacian kernel per channel
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 1, 1] = -4.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    def forward(self, x):
        """Two-stage branch; the 1-pixel border of the intermediate is filled
        with the 1x1 bias so the unpadded second conv equals the fused,
        zero-padded 3x3 conv produced by rep_params()."""
        if self.type == 'conv1x1-conv3x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # depthwise scaled-mask conv
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)
        if self.with_bn:
            y1 = self.bn(y1)
        return y1

    def rep_params(self):
        """Collapse the two-stage branch into an equivalent single 3x3
        (kernel, bias) pair; folds BN statistics when with_bn."""
        device = self.k0.get_device()
        if device < 0:
            device = None  # CPU tensors report device < 0

        if self.type == 'conv1x1-conv3x3':
            # re-param conv kernel: compose 1x1 then 3x3 weights
            RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias: push b0 through the 3x3 conv, then add b1
            RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=self.k1).view(-1,) + self.b1
        else:
            # expand the depthwise scaled mask to a dense (diagonal) kernel
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
            for i in range(self.out_planes):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=k1).view(-1,) + b1

        if self.with_bn:
            # fold BatchNorm (running stats) into the fused kernel/bias
            v = torch.sqrt(self.bn.running_var + self.bn.eps)
            m = self.bn.running_mean
            s = self.bn.weight
            b = self.bn.bias
            RK = (s/v).reshape(self.out_planes, 1, 1, 1) * RK
            RB = (s/v) * (RB - m) + b
        return RK, RB

def pad_tensor(t, pattern):
    """Zero-pad ``t`` by 1 pixel on every side, then overwrite the border
    with ``pattern`` (a per-channel value, e.g. a conv bias)."""
    pattern = pattern.view(1, -1, 1, 1)
    t = F.pad(t, (1, 1, 1, 1), 'constant', 0)
    t[:, :, 0:1, :] = pattern
    t[:, :, -1:, :] = pattern
    t[:, :, :, 0:1] = pattern
    t[:, :, :, -1:] = pattern

    return t

class RRRB(nn.Module):
    """Residual-in-residual re-parameterizable block (replaces one 3x3 conv).

    Diagram::

        ---Conv1x1--Conv3x3-+-Conv1x1--+--
               |____________|          |
               |_______________________|

    Args:
        n_feats (int): number of feature maps.
        ratio (int): expansion ratio of the inner 1x1 conv.
    """

    def __init__(self, n_feats, ratio=2):
        super(RRRB, self).__init__()
        self.expand_conv = nn.Conv2d(n_feats, ratio*n_feats, 1, 1, 0)
        self.fea_conv = nn.Conv2d(ratio*n_feats, ratio*n_feats, 3, 1, 0)
        self.reduce_conv = nn.Conv2d(ratio*n_feats, n_feats, 1, 1, 0)

    def forward(self, x):

        out = self.expand_conv(x)
        out_identity = out

        # explicitly padding with bias for reparameterizing in the test phase
        b0 = self.expand_conv.bias
        out = pad_tensor(out, b0)

        # unpadded 3x3 restores the original spatial size before the add
        out = self.fea_conv(out) + out_identity
        out = self.reduce_conv(out)
        out += x

        return out


class SESR(nn.Module):
    """Collapsible expand->squeeze conv pair with identity shortcut
    (SESR-style linear over-parameterization of a 3x3 conv)."""

    def __init__(self, n_feats, ratio=2):
        super(SESR, self).__init__()
        self.expand_conv = nn.Conv2d(n_feats, ratio*n_feats, 3, 1, 1)
        self.squeeze_conv = nn.Conv2d(ratio*n_feats, n_feats, 1, 1, 0)

    def forward(self, x):

        identity = x
        out = self.expand_conv(x)
        out = self.squeeze_conv(out)

        out += identity

        return out


class ECB(nn.Module):
    """Edge-oriented Convolution Block.

    Training: sum of five parallel branches (3x3 conv, 1x1-3x3, and three
    fixed-edge branches) plus optional identity. Inference (eval mode):
    a single fused 3x3 conv from rep_params().
    """

    def __init__(self, inp_planes, out_planes, depth_multiplier, conv3_type='conv3', act_type='prelu', with_idt = False, with_bn = False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        self.with_bn = with_bn

        # identity shortcut only when channel counts match
        if with_idt and (self.inp_planes == self.out_planes):
            self.with_idt = True
        else:
            self.with_idt = False

        if with_bn:
            self.conv3x3 = nn.Sequential(
                nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1),
                nn.BatchNorm2d(self.out_planes)
            )
        elif conv3_type == 'conv3':
            self.conv3x3 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
        elif conv3_type == 'RRRB':
            self.conv3x3 = RRRB(self.inp_planes)
        else:
            self.conv3x3 = SESR(self.inp_planes, ratio=2)

        # depth_multiplier of -1 marks the fixed-filter branches (unused there)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier, self.with_bn)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1, self.with_bn)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1, self.with_bn)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1, self.with_bn)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'lrelu':
            self.act = nn.LeakyReLU(0.1)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            # NOTE(review): message grammar — should read "is not supported".
            raise ValueError('The type of activation if not support!')



    def forward(self, x):
        if self.training:
            # multi-branch training-time form
            y = self.conv3x3(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x) + \
                self.conv1x1_3x3(x)

            if self.with_idt:
                y += x
        else:
            # fused single-conv inference form
            # NOTE(review): rep_params() folds conv3x3 only when it is a plain
            # Conv2d/BN pair — with conv3_type 'RRRB'/'SESR' this would fail.
            RK, RB = self.rep_params()
            y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y


    def rep_params(self):
        """Sum the fused (kernel, bias) of every branch, folding BN and the
        identity shortcut, into one 3x3 conv."""
        if self.with_bn:
            K0, B0 = self.conv3x3[0].weight, self.conv3x3[0].bias
            v = torch.sqrt(self.conv3x3[1].running_var + self.conv3x3[1].eps)
            m = self.conv3x3[1].running_mean
            s = self.conv3x3[1].weight
            b = self.conv3x3[1].bias
            K0 = (s/v).reshape(self.out_planes, 1, 1, 1) * K0
            B0 = (s/v) * (B0 - m) + b
        else:
            K0, B0 = self.conv3x3.weight, self.conv3x3.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K2, B2 = self.conv1x1_sbx.rep_params()
        K3, B3 = self.conv1x1_sby.rep_params()
        K4, B4 = self.conv1x1_lpl.rep_params()
        RK, RB = (K0+K1+K2+K3+K4), (B0+B1+B2+B3+B4)

        if self.with_idt:
            device = RK.get_device()
            if device < 0:
                device = None
            # identity as a center-tap 3x3 kernel
            K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
            for i in range(self.out_planes):
                K_idt[i, i, 1, 1] = 1.0
            B_idt = 0.0
            RK, RB = RK + K_idt, RB + B_idt
        return RK, RB

if __name__ == '__main__':
    # test ecb: train-form vs fused-form outputs should match (print diff)
    x = torch.randn(1, 3, 5, 5).cuda() * 200
    ecb = ECB(3, 3, 2, act_type='linear', with_idt=True, with_bn=True).cuda().eval()
    y0 = ecb(x)

    RK, RB = ecb.rep_params()
    y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
    print(y0-y1)

# ---------------------------------------------------------------------------
# models/ecbsr1d/tf/plainsr.py (separate file in the original patch):
# TensorFlow/Keras re-build of the fused "plain SR" network for conversion.
# ---------------------------------------------------------------------------
import sys
import tensorflow as tf
import h5py
import math
import numpy as np
import cv2
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, PReLU, ReLU, UpSampling2D, Lambda
from tensorflow.keras.models import Model

import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"  # force CPU for conversion

def tf_conv3x3(inp, out_channels, act_type):
    """3x3 'same' conv + activation, mirroring the PyTorch Conv3X3 block."""
    y = Conv2D(filters=out_channels, kernel_size=3, strides=1, padding='same')(inp)
    if act_type == 'relu':
        y = ReLU()(y)
    elif act_type == 'prelu':
        # shared_axes=[1,2]: one PReLU slope per channel, as in PyTorch
        y = PReLU(shared_axes=[1,2])(y)
    elif act_type == 'linear':
        pass
    else:
        raise ValueError('invalid act-type for tensorflow!')
    return y

def plainsr_tf(module_nums, channel_nums, act_type, scale, colors, input_h, input_w):
    """Build the Keras plain-SR model (head + body + tail + depth_to_space)."""
    inp = Input(shape=(input_h, input_w, colors))
    ## head
    y = tf_conv3x3(inp, channel_nums, act_type)
    ## body
    for i in range(module_nums):
        y = tf_conv3x3(y, channel_nums, act_type)
    if colors == 1:
        ## tail
        y = tf_conv3x3(y, colors*scale*scale, 'linear')
        # global residual (broadcasts the single input channel)
        y = y + inp
        ## upscaling
        out = tf.nn.depth_to_space(y, scale, data_format='NHWC')
    elif colors == 3:
        ## since internal data layout between pytorch and tensorflow are quite
        ## different (NCHW vs NHWC), the pixel-shuffle input layout is handled
        ## per color channel here.

        ## tail
        y = tf_conv3x3(y, colors*scale*scale, 'linear')

        ## rgb layout
        r,g,b = tf.split(y, num_or_size_splits=colors, axis=3)
        ## upsampling
        # NOTE(review): tf_r/tf_g/tf_b are a second identical split of y, so
        # tf_r += r doubles the tail output instead of adding the input
        # residual — confirm the intended residual source.
        tf_r, tf_g, tf_b = tf.split(y, num_or_size_splits=colors, axis=3)

        tf_r += r
        tf_g += g
        tf_b += b

        tf_r = tf.nn.depth_to_space(tf_r, scale, data_format='NHWC')
        tf_r = tf.clip_by_value(tf_r, 0.0, 255.0)
        tf_g = tf.nn.depth_to_space(tf_g, scale, data_format='NHWC')
        tf_g = tf.clip_by_value(tf_g, 0.0, 255.0)
        tf_b = tf.nn.depth_to_space(tf_b, scale, data_format='NHWC')
        tf_b = tf.clip_by_value(tf_b, 0.0, 255.0)
        out = tf.concat(values=[tf_r, tf_g, tf_b], axis=3)
    else:
        raise ValueError('invalid colors!')
    return Model(inputs=inp, outputs=out)

if __name__ == '__main__':
    # NOTE(review): plainsr_tf() also requires input_h and input_w — this
    # call raises TypeError as written; confirm intended arguments.
    model_tf = plainsr_tf(module_nums=4, channel_nums=8, act_type='relu', scale=2, colors=3)
    for idx, layer in enumerate(model_tf.layers):
        wgt = layer.get_weights()
        nums = len(wgt)
        print(layer, idx)
        if nums == 1:
            pass
class HFAB(nn.Module):
    """High-frequency attention block (plain-conv variant used by fmen_network).

    Pipeline: squeeze (3x3 conv) -> up_blocks x ERB -> excitate (3x3 conv)
    -> sigmoid, producing a gate that multiplies the block input.

    Args:
        n_feats (int): channels of the gated input/output.
        up_blocks (int): number of ERBs in the bottleneck.
        mid_feats (int): bottleneck channel count.
    """

    def __init__(self, n_feats, up_blocks, mid_feats):
        super(HFAB, self).__init__()
        self.squeeze = nn.Conv2d(n_feats, mid_feats, 3, 1, 1)
        self.convs = nn.Sequential(*[ERB(mid_feats) for _ in range(up_blocks)])
        self.excitate = nn.Conv2d(mid_feats, n_feats, 3, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # build the attention gate at reduced width, then map it back
        gate = act(self.squeeze(x))
        gate = act(self.convs(gate))
        gate = self.sigmoid(self.excitate(gate))
        # gate in (0, 1) scales the input element-wise
        return gate * x
class ESA(nn.Module):
    """Enhanced Spatial Attention: derives a per-pixel sigmoid mask from a
    channel-squeezed, downsampled view of the input and multiplies it back
    onto the input.

    Args:
        n_feats: number of input/output channels.
        conv: conv layer constructor (e.g. ``nn.Conv2d``).
    """

    def __init__(self, n_feats, conv):
        super(ESA, self).__init__()

        # NOTE(review): bn1, conv_max and conv_sh are registered but never used
        # in forward() — they bloat the state_dict; confirm whether they can be
        # dropped without breaking existing checkpoints before removing.
        self.bn1 = nn.BatchNorm2d(n_feats)

        f = n_feats // 4  # channel squeeze factor for the attention branch
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)
        self.conv_max = conv(f, f, kernel_size=3, padding=1)
        self.conv_un = PixelUnshuffle(2)
        # NOTE(review): padding=1 on a 1x1 conv grows each spatial dim by 2;
        # F.interpolate below hides this, but it looks unintentional — confirm.
        self.con_ = conv(4 * f, f, kernel_size=1, padding=1)
        self.conv_sh = nn.PixelShuffle(2)
        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):


        c1_ = (self.conv1(x))              # channel squeeze: n_feats -> f
        c1 = self.conv_un(c1_)             # pixel-unshuffle: 2x spatial down, 4x channels
        c1_p = F.max_pool2d(c1, kernel_size=7, stride=3)  # coarse context via strided pooling
        c1_p = self.relu(c1_p)
        c2 = self.con_(c1_p)               # 4f -> f
        c2 = self.relu(c2)
        # resize the coarse attention features back to the input resolution
        c3 = F.interpolate(c2, (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
        cf = self.conv_f(c1_)              # full-resolution skip branch
        c4 = self.conv4(c3 + cf)           # f -> n_feats
        m = self.sigmoid(c4)               # attention mask in (0, 1)

        return x * m
self.warmup = nn.Sequential( + # nn.Conv2d(n_feats, n_feats, 3, 1, 1), + # # HFAB(n_feats, up_blocks[0], mid_feats-4) + # ESA(n_feats, nn.Conv2d) + # ) + + self.head = ECB(n_colors, n_feats, depth_multiplier=2, act_type='lrelu', with_idt = True) + self.warmup = nn.Sequential( + ECB(n_feats, n_feats, depth_multiplier=2, act_type='lrelu', with_idt = True), + # HFAB(n_feats, up_blocks[0], mid_feats-4) + ESA(n_feats, nn.Conv2d) + ) + + # define body module + ERBs = [ERB(n_feats) for _ in range(self.down_blocks)] + # HFABs = [HFAB(n_feats, up_blocks[i+1], mid_feats) for i in range(self.down_blocks)] + HFABs = [ESA(n_feats, nn.Conv2d) for i in range(self.down_blocks)] + + self.ERBs = nn.ModuleList(ERBs) + self.HFABs = nn.ModuleList(HFABs) + + + + self.lr_conv = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + + # define tail module + self.tail = nn.Sequential( + nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1), + nn.PixelShuffle(scale) + ) + + + def forward(self, x): + x = self.head(x) + + h = self.warmup(x) + for i in range(self.down_blocks): + h = self.ERBs[i](h) + h = self.HFABs[i](h) + + h = self.lr_conv(h) + + h += x + x = self.tail(h) + + return x + + + def load_state_dict(self, state_dict, strict=True): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') == -1: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
def pad_tensor(t, pattern):
    """Zero-pad ``t`` by one pixel on every side, then overwrite the one-pixel
    border with a per-channel constant taken from ``pattern``.

    This emulates the values a reparameterized conv+BN would have produced in
    the padded region (the BN bias leaks into the border).

    Args:
        t: (N, C, H, W) tensor.
        pattern: length-C tensor of per-channel border values.

    Returns:
        (N, C, H+2, W+2) tensor.
    """
    border = pattern.view(1, -1, 1, 1)
    padded = F.pad(t, (1, 1, 1, 1), 'constant', 0)
    # stamp the per-channel value onto all four one-pixel edges
    for h_slice, w_slice in (
        (slice(0, 1), slice(None)),
        (slice(-1, None), slice(None)),
        (slice(None), slice(0, 1)),
        (slice(None), slice(-1, None)),
    ):
        padded[:, :, h_slice, w_slice] = border
    return padded
+ """ + + def __init__(self, n_feats, ratio=2): + super(RRRB, self).__init__() + self.expand_conv = nn.Conv2d(n_feats, ratio*n_feats, 1, 1, 0) + self.fea_conv = nn.Conv2d(ratio*n_feats, ratio*n_feats, 3, 1, 0) + self.reduce_conv = nn.Conv2d(ratio*n_feats, n_feats, 1, 1, 0) + + def forward(self, x): + out = self.expand_conv(x) + out_identity = out + + # explicitly padding with bias for reparameterizing in the test phase + b0 = self.expand_conv.bias + out = pad_tensor(out, b0) + + out = self.fea_conv(out) + out_identity + out = self.reduce_conv(out) + out += x + + return out + + +class ERB(nn.Module): + """ Enhanced residual block for building FEMN. + Diagram: + --RRRB--LeakyReLU--RRRB-- + + Args: + n_feats (int): Number of feature maps. + ratio (int): Expand ratio in RRRB. + """ + + def __init__(self, n_feats, ratio=2): + super(ERB, self).__init__() + self.conv1 = RRRB(n_feats, ratio) + self.conv2 = RRRB(n_feats, ratio) + + def forward(self, x): + out = self.conv1(x) + out = act(out) + out = self.conv2(out) + + return out + + +class HFAB(nn.Module): + """ High-Frequency Attention Block. + Diagram: + ---BN--Conv--[ERB]*up_blocks--BN--Conv--BN--Sigmoid--*-- + |___________________________________________________| + Args: + n_feats (int): Number of HFAB input feature maps. + up_blocks (int): Number of ERBs for feature extraction in this HFAB. + mid_feats (int): Number of feature maps in ERB. + Note: + Batch Normalization (BN) is adopted to introduce global contexts and achieve sigmoid unsaturated area. 
+ """ + + def __init__(self, n_feats, up_blocks, mid_feats, ratio): + super(HFAB, self).__init__() + self.bn1 = nn.BatchNorm2d(n_feats) + self.bn2 = nn.BatchNorm2d(mid_feats) + self.bn3 = nn.BatchNorm2d(n_feats) + + self.squeeze = nn.Conv2d(n_feats, mid_feats, 3, 1, 0) + + convs = [ERB(mid_feats, ratio) for _ in range(up_blocks)] + self.convs = nn.Sequential(*convs) + + self.excitate = nn.Conv2d(mid_feats, n_feats, 3, 1, 0) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # explicitly padding with bn bias + out = self.bn1(x) + bn1_bias = get_bn_bias(self.bn1) + out = pad_tensor(out, bn1_bias) + + out = act(self.squeeze(out)) + out = act(self.convs(out)) + + # explicitly padding with bn bias + out = self.bn2(out) + bn2_bias = get_bn_bias(self.bn2) + out = pad_tensor(out, bn2_bias) + + out = self.excitate(out) + + out = self.sigmoid(self.bn3(out)) + + return out * x + + +class TRAIN_FMEN(nn.Module): + """ Fast and Memory-Efficient Network + Diagram: + --Conv--Conv-HFAB-[ERB-HFAB]*down_blocks-Conv-+-Upsample-- + |______________________________________| + Args: + down_blocks (int): Number of [ERB-HFAB] pairs. + up_blocks (list): Number of ERBs in each HFAB. + mid_feats (int): Number of feature maps in branch ERB. + n_feats (int): Number of feature maps in trunk ERB. + n_colors (int): Number of image channels. + scale (list): upscale factor. + backbone_expand_ratio (int): Expand ratio of RRRB in trunk ERB. + attention_expand_ratio (int): Expand ratio of RRRB in branch ERB. 
+ """ + + def __init__(self, args): + super(TRAIN_FMEN, self).__init__() + + self.down_blocks = args.down_blocks + + up_blocks = args.up_blocks + mid_feats = args.mid_feats + n_feats = args.n_feats + n_colors = args.n_colors + scale = args.scale[0] + backbone_expand_ratio = args.backbone_expand_ratio + attention_expand_ratio = args.attention_expand_ratio + + # define head module + self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1) + + # warm up + self.warmup = nn.Sequential( + nn.Conv2d(n_feats, n_feats, 3, 1, 1), + HFAB(n_feats, up_blocks[0], mid_feats-4, attention_expand_ratio) + ) + + # define body module + ERBs = [ERB(n_feats, backbone_expand_ratio) for _ in range(self.down_blocks)] + HFABs = [HFAB(n_feats, up_blocks[i+1], mid_feats, attention_expand_ratio) for i in range(self.down_blocks)] + + self.ERBs = nn.ModuleList(ERBs) + self.HFABs = nn.ModuleList(HFABs) + + self.lr_conv = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + + # define tail module + self.tail = nn.Sequential( + nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1), + nn.PixelShuffle(scale) + ) + + + def forward(self, x): + x = self.head(x) + + h = self.warmup(x) + for i in range(self.down_blocks): + h = self.ERBs[i](h) + h = self.HFABs[i](h) + h = self.lr_conv(h) + + h += x + x = self.tail(h) + + return x + + + def load_state_dict(self, state_dict, strict=True): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') == -1: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
class Args:
    """Default hyper-parameter bundle for TRAIN_FMEN (x4 SR)."""

    def __init__(self):
        # trunk / attention topology
        self.down_blocks = 4
        self.up_blocks = [2, 1, 1, 1, 1]

        # channel widths
        self.n_feats, self.mid_feats = 50, 16

        # RRRB expansion ratios (trunk and attention branches)
        self.backbone_expand_ratio = self.attention_expand_ratio = 2

        # image / task settings
        self.scale, self.rgb_range, self.n_colors = [4], 255, 3
(C_out) + """ + + out_feats = w.shape[0] + std = (var + eps).sqrt() + scale = gamma / std + bn_bias = beta - mean * gamma / std + + # Reparameterizing kernel + if before_conv: + rep_w = w * scale.reshape(1, -1, 1, 1) + else: + rep_w = torch.mm(torch.diag(scale), w.view(out_feats, -1)).view(w.shape) + + # Reparameterizing bias + if before_conv: + rep_b = torch.mm(torch.sum(w, dim=(2,3)), bn_bias.unsqueeze(1)).squeeze() + b + else: + rep_b = b.mul(scale) + bn_bias + + return rep_w, rep_b + + +def bn_parameter(pretrain_state_dict, k, dst='bn1'): + src = k.split('.')[-2] + gamma = pretrain_state_dict[k.replace(src, dst)] + beta = pretrain_state_dict[k.replace(f'{src}.weight', f'{dst}.bias')] + mean = pretrain_state_dict[k.replace(f'{src}.weight', f'{dst}.running_mean')] + var = pretrain_state_dict[k.replace(f'{src}.weight', f'{dst}.running_var')] + eps = 1e-05 + + return gamma, beta, mean, var, eps + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--pretrained_path', type=str, required=True) + args = parser.parse_args() + model_args = Args() + model = test_fmen.make_model(model_args).cuda() + rep_state_dict = model.state_dict() + pretrain_state_dict = torch.load(args.pretrained_path, map_location='cuda') + + for k, v in tqdm(rep_state_dict.items()): + # merge conv1x1-conv3x3-conv1x1 + if 'rep_conv.weight' in k: + k0 = pretrain_state_dict[k.replace('rep', 'expand')] + k1 = pretrain_state_dict[k.replace('rep', 'fea')] + k2 = pretrain_state_dict[k.replace('rep', 'reduce')] + + bias_str = k.replace('weight', 'bias') + b0 = pretrain_state_dict[bias_str.replace('rep', 'expand')] + b1 = pretrain_state_dict[bias_str.replace('rep', 'fea')] + b2 = pretrain_state_dict[bias_str.replace('rep', 'reduce')] + + mid_feats, n_feats = k0.shape[:2] + + # first step: remove the middle identity + for i in range(mid_feats): + k1[i, i, 1, 1] += 1.0 + + # second step: merge the first 1x1 convolution and the next 3x3 convolution + merge_k0k1 = F.conv2d(input=k1, 
weight=k0.permute(1, 0, 2, 3)) + merge_b0b1 = b0.view(1, -1, 1, 1) * torch.ones(1, mid_feats, 3, 3).cuda() + merge_b0b1 = F.conv2d(input=merge_b0b1, weight=k1, bias=b1) + + # third step: merge the remain 1x1 convolution + merge_k0k1k2 = F.conv2d(input=merge_k0k1.permute(1, 0, 2, 3), weight=k2).permute(1, 0, 2, 3) + merge_b0b1b2 = F.conv2d(input=merge_b0b1, weight=k2, bias=b2).view(-1) + + # last step: remove the global identity + for i in range(n_feats): + merge_k0k1k2[i, i, 1, 1] += 1.0 + + rep_state_dict[k] = merge_k0k1k2.float() + rep_state_dict[bias_str] = merge_b0b1b2.float() + + elif 'rep_conv.bias' in k: + pass + + # merge BN + elif 'squeeze.weight' in k: + bias_str = k.replace('weight', 'bias') + w = pretrain_state_dict[k] + b = pretrain_state_dict[bias_str] + gamma, beta, mean, var, eps = bn_parameter(pretrain_state_dict, k, dst='bn1') + + rep_w, rep_b = merge_bn(w, b, gamma, beta, mean, var, eps, before_conv=True) + + rep_state_dict[k] = rep_w + rep_state_dict[bias_str] = rep_b + + elif 'squeeze.bias' in k: + pass + + elif 'excitate.weight' in k: + bias_str = k.replace('weight', 'bias') + w = pretrain_state_dict[k] + b = pretrain_state_dict[bias_str] + gamma1, beta1, mean1, var1, eps1 = bn_parameter(pretrain_state_dict, k, dst='bn2') + gamma2, beta2, mean2, var2, eps2 = bn_parameter(pretrain_state_dict, k, dst='bn3') + rep_w, rep_b = merge_bn(w, b, gamma1, beta1, mean1, var1, eps1, before_conv=True) + rep_w, rep_b = merge_bn(rep_w, rep_b, gamma2, beta2, mean2, var2, eps2, before_conv=False) + + rep_state_dict[k] = rep_w + rep_state_dict[bias_str] = rep_b + + elif 'excitate.bias' in k: + pass + + elif k in pretrain_state_dict.keys(): + rep_state_dict[k] = pretrain_state_dict[k] + + else: + raise NotImplementedError('{} is not found in pretrain_state_dict.'.format(k)) + + torch.save(rep_state_dict, 'test.pt') + print('Reparameterize successfully!') \ No newline at end of file