diff --git a/implementations/srgan/models.py b/implementations/srgan/models.py index b8efd3c0..7722a21d 100644 --- a/implementations/srgan/models.py +++ b/implementations/srgan/models.py @@ -9,7 +9,7 @@ class FeatureExtractor(nn.Module): def __init__(self): super(FeatureExtractor, self).__init__() vgg19_model = vgg19(pretrained=True) - self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18]) + self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:35]) def forward(self, img): return self.feature_extractor(img) @@ -59,7 +59,7 @@ def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16): self.upsampling = nn.Sequential(*upsampling) # Final output layer - self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh()) + self.conv3 = nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4) def forward(self, x): out1 = self.conv1(x)