diff --git a/implementations/bicyclegan/models.py b/implementations/bicyclegan/models.py index a5701321..91f7f947 100644 --- a/implementations/bicyclegan/models.py +++ b/implementations/bicyclegan/models.py @@ -150,7 +150,7 @@ def discriminator_block(in_filters, out_filters, normalize=True): ), ) - self.downsample = nn.AvgPool2d(in_channels, stride=2, padding=[1, 1], count_include_pad=False) + self.downsample = nn.AvgPool2d(channels, stride=2, padding=[1, 1], count_include_pad=False) def compute_loss(self, x, gt): """Computes the MSE between model output and scalar gt"""