From f87504bf454a1beb9042a421e47b40418d3ff180 Mon Sep 17 00:00:00 2001 From: Arthur Khaibrakhmanov Date: Thu, 12 Dec 2024 22:00:01 +0300 Subject: [PATCH] Auto commit: 2024-12-12 22:00:01 --- models/superformer.py | 30 ++++++++++++++++++++++-------- requirements.txt | 1 + 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/models/superformer.py b/models/superformer.py index d76afb636..e343bb337 100644 --- a/models/superformer.py +++ b/models/superformer.py @@ -735,12 +735,12 @@ class SuperFormer(nn.Module): """ def __init__(self, img_size=12, patch_size=3, in_chans=1, - embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], - window_size=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, + embed_dim=72, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=2, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, rpb=True ,patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - output_type = "residual",num_feat=64,**kwargs): + output_type = "",num_feat=64,**kwargs): super(SuperFormer, self).__init__() num_in_ch = in_chans num_out_ch = in_chans @@ -951,12 +951,26 @@ def forward(self, x): res_deep = (res_deep_feat + res_deep_vol)/2 # Modify the upsampling to actually scale up the features res = self.upsample_feat(res_deep) # Make sure this actually upsamples - res = res + F.interpolate(x_first, scale_factor=self.upscale, mode='trilinear', align_corners=False) + x_first_upsampled = F.interpolate(x_first, + size=(H*self.upscale, W*self.upscale, D*self.upscale), + mode='trilinear', + align_corners=False) + res_upsampled = F.interpolate(res, + size=(H*self.upscale, W*self.upscale, D*self.upscale), + mode='trilinear', + align_corners=False) + res = res_upsampled + x_first_upsampled else: - res = self.conv_after_body(self.forward_features(x_first)) - # Add upsampling here - res = F.interpolate(res, scale_factor=self.upscale, mode='trilinear', align_corners=False) - res = res + F.interpolate(x_first, scale_factor=self.upscale, mode='trilinear', align_corners=False) + res = self.conv_after_body(self.forward_features(x, x_first)) + res = F.interpolate(res, + size=(H*self.upscale, W*self.upscale, D*self.upscale), + mode='trilinear', + align_corners=False) + x_first_upsampled = F.interpolate(x_first, + size=(H*self.upscale, W*self.upscale, D*self.upscale), + mode='trilinear', + align_corners=False) + res = res + x_first_upsampled if self.output_type == 'residual': x = x + self.conv_last(res) else: diff --git a/requirements.txt b/requirements.txt index 70d704fe7..e3ca1f006 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ seaborn==0.13.2 torch==2.3.0 torchmetrics tqdm==4.66.4 +timm