Flax nnx ConvTranspose Does Not Restore Input Shape When Used with Conv (Unexpected Behavior) #4593

@Stella-S-Yan

When Flax's Conv and ConvTranspose layers are used as a pair, ConvTranspose does not restore the original input shape, even when both layers are given the same kernel size, stride, and padding. This differs from PyTorch, where ConvXd followed by ConvTransposeXd with matching parameters restores the input shape.

The ConvTranspose layer appears to produce incorrect output shapes, and in some cases spatial dimensions collapse to zero. This is more than a mismatch with PyTorch: it makes the layer effectively unusable in those cases.

Reproduction Example

from jax import random
from flax import nnx

import torch
from torch import nn

key = random.PRNGKey(42) 

batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 0

# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k), 
                strides=(s, s), 
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape) # (4, 2, 2, 32)
assert y.shape[2] == 2

tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=p,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape) # (4, 0, 0, 128)
if z.shape[2] != i:
    print("Flax ConvTranspose failed to restore the original input shape.")

# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape) # torch.Size([4, 32, 2, 2])
assert y.shape == (batch_size, out_channels, 2, 2)

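# Parameters of the equivalent direct convolution, kept for reference only;
# they are not passed to ConvTranspose2d below.
kp = k
sp = s
pp = k - 1
ip = 2
op = ip + (k-1)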
tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s, 
                           padding=p)
z = tconv(y)
print(z.shape) # torch.Size([4, 128, 4, 4])
assert z.shape[2] == i
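
The two results above are consistent with the shape arithmetic sketched below. The PyTorch formula is the one given in the ConvTranspose2d documentation. The Flax formula is an assumption, not a confirmed description of the implementation: it supposes that an explicit integer padding is applied directly to the stride-dilated input before a stride-1 convolution. The helper names are hypothetical and exist only for this illustration.

```python
def conv_out(size, k, s=1, p=0):
    """Output length of a plain convolution with symmetric padding p."""
    return (size + 2 * p - k) // s + 1


def torch_tconv_out(size, k, s=1, p=0, output_padding=0, dilation=1):
    """Transposed-convolution output length, per the PyTorch docs formula."""
    return (size - 1) * s - 2 * p + dilation * (k - 1) + output_padding + 1


def explicit_pad_tconv_out(size, k, s=1, p=0):
    """ASSUMED Flax behavior for explicit integer padding:
    dilate the input by the stride, pad p on each side,
    then apply a stride-1 convolution with a kernel of length k."""
    dilated = (size - 1) * s + 1
    return dilated + 2 * p - k + 1


i, k, s, p = 4, 3, 1, 0
y = conv_out(i, k, s, p)                       # forward conv: 4 -> 2
print(torch_tconv_out(y, k, s, p))             # 4, matches the PyTorch output
print(explicit_pad_tconv_out(y, k, s, p))      # 0, matches the Flax output
print(explicit_pad_tconv_out(y, k, s, k - 1))  # 4, padding of k - 1 per side
```

If this assumption holds, PyTorch's padding=p corresponds to an explicit padding of k - 1 - p on the Flax side, so padding=k - 1 (or possibly padding='VALID') may be the setting that restores the input shape here. This was not verified against Flax itself.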
