Description
When Flax's Conv and ConvTranspose layers are used as a pair, ConvTranspose does not correctly restore the original input shape, even when the parameters are set in a way that should theoretically allow this. This behavior differs from PyTorch, where ConvXd and ConvTransposeXd used together reliably restore the input shape. Moreover, ConvTranspose can produce clearly incorrect output shapes, with spatial dimensions sometimes collapsing to zero; this is not just a mismatch with PyTorch, but makes the layer effectively unusable in those cases.
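For reference, PyTorch's shape arithmetic (with dilation 1 and output_padding 0) is o = floor((i + 2p - k) / s) + 1 for Conv2d and o = (i - 1) * s - 2p + k for ConvTranspose2d, so the transpose inverts the convolution. With i = 4, k = 3, s = 1, p = 0 this gives 4 -> 2 for the convolution and 2 -> 4 for the transpose, i.e. identical parameters should round-trip the shape.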
Reproduction Example
from jax import random
from flax import nnx
import torch
from torch import nn
key = random.PRNGKey(42)
batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 0
# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k),
                strides=(s, s),
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape)  # (4, 2, 2, 32)
assert y.shape[2] == 2

tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=p,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # (4, 0, 0, 128) -- spatial dims collapse to zero
if z.shape[2] != i:
    print("Flax ConvTranspose failed to restore the original input shape.")
# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape)  # torch.Size([4, 32, 2, 2])
assert y.shape == (batch_size, out_channels, 2, 2)

# For s = 1, ConvTranspose2d is equivalent to a regular convolution with
# kernel kp = k, stride sp = s, and padding pp = k - 1, so the output
# grows back: op = ip + (k - 1) = 2 + 2 = 4.
kp = k
sp = s
pp = k - 1
ip = 2
op = ip + (k - 1)

tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s,
                           padding=p)
z = tconv(y)
print(z.shape)  # torch.Size([4, 128, 4, 4])
assert z.shape[2] == i
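A possible workaround, based on my reading of jax.lax.conv_transpose's padding semantics (an assumption on my part, not a confirmed fix): an explicit padding p in Flax's ConvTranspose appears to be applied directly to the stride-dilated input, so p = 0 computes 2 - 3 + 1 = 0 along each spatial dim, whereas padding='VALID' (or, I believe, the explicit pairs ((k - 1, k - 1), (k - 1, k - 1))) should give o = (i - 1) * s + k = 4 and match PyTorch's padding=0. A minimal self-contained sketch under that assumption:

from jax import random
from flax import nnx

key = random.PRNGKey(42)
x = random.uniform(key, shape=(4, 4, 4, 128))
conv = nnx.Conv(in_features=128, out_features=32, kernel_size=(3, 3),
                strides=(1, 1), padding=0, rngs=nnx.Rngs(0))
y = conv(x)  # (4, 2, 2, 32)

# Assumption: padding='VALID' makes the transpose grow each spatial dim by
# k - 1 (o = (i - 1) * s + k for stride 1), matching PyTorch's padding=0.
tconv = nnx.ConvTranspose(in_features=32, out_features=128, kernel_size=(3, 3),
                          strides=(1, 1), padding='VALID', rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # expected (4, 4, 4, 128) if the assumption holds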