diff --git a/mnist/main.py b/mnist/main.py index 1bee55c4..8f0bc49f 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -20,8 +20,9 @@ def __init__(self): def forward(self, x): x = self.conv1(x) - x = F.relu(x) + x = F.relu(x) x = self.conv2(x) + x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1)