Skip to content

Commit

Permalink
Auto commit: 2024-12-12 22:00:01
Browse files Browse the repository at this point in the history
  • Loading branch information
blackwood168 committed Dec 12, 2024
1 parent 453e4e3 commit f87504b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
30 changes: 22 additions & 8 deletions models/superformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,12 @@ class SuperFormer(nn.Module):
"""

def __init__(self, img_size=12, patch_size=3, in_chans=1,
embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=6, mlp_ratio=4., qkv_bias=True, qk_scale=None,
embed_dim=72, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=2, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, rpb=True ,patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
output_type = "residual",num_feat=64,**kwargs):
output_type = "",num_feat=64,**kwargs):
super(SuperFormer, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
Expand Down Expand Up @@ -951,12 +951,26 @@ def forward(self, x):
res_deep = (res_deep_feat + res_deep_vol)/2
# Modify the upsampling to actually scale up the features
res = self.upsample_feat(res_deep) # Make sure this actually upsamples
res = res + F.interpolate(x_first, scale_factor=self.upscale, mode='trilinear', align_corners=False)
x_first_upsampled = F.interpolate(x_first,
size=(H*self.upscale, W*self.upscale, D*self.upscale),
mode='trilinear',
align_corners=False)
res_upsampled = F.interpolate(res,
size=(H*self.upscale, W*self.upscale, D*self.upscale),
mode='trilinear',
align_corners=False)
res = res_upsampled + x_first_upsampled
else:
res = self.conv_after_body(self.forward_features(x_first))
# Add upsampling here
res = F.interpolate(res, scale_factor=self.upscale, mode='trilinear', align_corners=False)
res = res + F.interpolate(x_first, scale_factor=self.upscale, mode='trilinear', align_corners=False)
res = self.conv_after_body(self.forward_features(x, x_first))
res = F.interpolate(res,
size=(H*self.upscale, W*self.upscale, D*self.upscale),
mode='trilinear',
align_corners=False)
x_first_upsampled = F.interpolate(x_first,
size=(H*self.upscale, W*self.upscale, D*self.upscale),
mode='trilinear',
align_corners=False)
res = res + x_first_upsampled
if self.output_type == 'residual':
x = x + self.conv_last(res)
else:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ seaborn==0.13.2
torch==2.3.0
torchmetrics
tqdm==4.66.4
timm

0 comments on commit f87504b

Please sign in to comment.