[FIX] Bug (#30) in fold with non-zero padding (#32)
* [BUG] Reproduce bug for fold with padding=1 (#30)

* [FIX] Treat padding explicitly when scattering, then unpad

This should fix #30.

The problem was in the creation of the tensor that maps the patch values
back to the positions in the original image. The code was using `unfold`
with non-zero padding, which would lead to a lot of patch values being
assigned to location 0. The fix is to deal with padding explicitly so
that the `unfold` operation uses `padding=0`. Then the index tensor is
correct, and after scattering we have to un-pad the image to obtain the
output (a minimal sketch of this approach follows the commit message below).

* [CI] Ignore B905

* [DEL] Remove upload to coveralls

GitHub Actions keep failing for unknown reasons.
The error is:
```
coveralls.exception.CoverallsException: Could not submit coverage: 422 Client Error: Unprocessable Entity for url: https://coveralls.io/api/v1/jobs
```
Deciding to remove coverage uploads to coveralls.
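The following is a minimal, stand-alone sketch of the scatter-then-unpad idea described in the fix bullet above, written for a single 2D channel directly against `torch.nn.functional` rather than the library's helpers; the shapes, names, and the final comparison against `F.fold` are illustrative and not part of this commit.

```python
import torch
import torch.nn.functional as F

# Illustration of the fix: index the *padded* output, unfold the indices with
# padding=0 (so no patch entry collapses to location 0), scatter the patch
# values back, then crop the padding away.
B, C, H, W = 1, 1, 4, 8
kernel_size, padding, stride = 2, 1, 1

x = torch.rand(B, C, H, W)
patches = F.unfold(x, kernel_size, padding=padding, stride=stride)  # (B, C*k*k, L)

# Flat indices of every pixel of the padded output, unfolded WITHOUT padding.
Hp, Wp = H + 2 * padding, W + 2 * padding
idx = torch.arange(Hp * Wp, dtype=torch.float32).reshape(1, 1, Hp, Wp)
idx = F.unfold(idx, kernel_size, padding=0, stride=stride).long()  # (1, k*k, L)

# Scatter-add the patch values onto the padded canvas, then remove the padding.
out = torch.zeros(B, C, Hp * Wp)
out.scatter_add_(2, idx.reshape(1, 1, -1).expand(B, C, -1), patches.reshape(B, C, -1))
out = out.reshape(B, C, Hp, Wp)[:, :, padding : padding + H, padding : padding + W]

# The result should agree with PyTorch's reference fold.
print(torch.allclose(out, F.fold(patches, (H, W), kernel_size, padding=padding, stride=stride)))
# True
```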
f-dangel authored Jan 10, 2024
1 parent 426dbf2 commit f6cb681
Showing 4 changed files with 36 additions and 18 deletions.
7 changes: 0 additions & 7 deletions .github/workflows/test.yaml
@@ -44,10 +44,3 @@ jobs:
        if: contains('refs/heads/main refs/heads/development refs/heads/release', github.ref) != 1
        run: |
          make test-light
      - name: Test coveralls - python ${{ matrix.python-version }}
        run: coveralls --service=github
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          flag-name: run-${{ matrix.python-version }}
          parallel: true
1 change: 1 addition & 0 deletions setup.cfg
@@ -80,6 +80,7 @@ ignore =
    W291, # trailing whitespace
    W503, # line break before binary operator
    W504, # line break after binary operator
    B905, # `zip()` without an explicit `strict=` parameter
exclude = docs, docs_src, build, .git, .eggs

[isort]
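For reference, B905 is flake8-bugbear's check for `zip()` calls without an explicit `strict=` argument (available since Python 3.10); the new `zip(output_size, padding)` calls in `unfoldNd/fold.py` below are presumably what triggers it. A small illustration of the pattern the check flags:

```python
# What B905 flags: zip() silently truncates to the shortest iterable.
output_size = (4, 8)
padding = (1, 1, 1)  # accidentally one entry too many -- goes unnoticed
print(list(zip(output_size, padding)))  # [(4, 1), (8, 1)]

# Passing strict= makes the intent explicit and satisfies B905 (Python >= 3.10):
print(list(zip(output_size, padding, strict=False)))  # same result, but deliberate
# list(zip(output_size, padding, strict=True))        # would raise ValueError here
```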
10 changes: 10 additions & 0 deletions test/fold_settings.py
@@ -16,6 +16,16 @@
"kernel_size": (2, 2),
},
},
{
"seed": 0,
"input_fn": lambda: torch.rand(2, 3 * 2 * 2, 5 * 9),
"fold_kwargs": {
"output_size": (4, 8),
"kernel_size": 2,
"padding": 1,
},
"id": "bug30-fold-with-padding",
},
]
PROBLEMS_2D_IDS = [make_id(problem) for problem in PROBLEMS_2D]

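The test file that consumes these settings is not part of this diff; the snippet below is a hypothetical sketch of how the new `bug30-fold-with-padding` entry could be exercised, comparing `foldNd` (from `unfoldNd/fold.py` in this commit) against `torch.nn.Fold`:

```python
import torch

from unfoldNd.fold import foldNd

# Hypothetical use of the new PROBLEMS_2D entry (the actual test harness is not
# shown in this commit): seed, build the input, run both implementations, compare.
problem = {
    "seed": 0,
    "input_fn": lambda: torch.rand(2, 3 * 2 * 2, 5 * 9),
    "fold_kwargs": {"output_size": (4, 8), "kernel_size": 2, "padding": 1},
}

torch.manual_seed(problem["seed"])
inputs = problem["input_fn"]()  # (batch=2, C*k*k=12, L=45)

result = foldNd(inputs, **problem["fold_kwargs"])
expected = torch.nn.Fold(**problem["fold_kwargs"])(inputs)

assert torch.allclose(result, expected)
print(result.shape)  # torch.Size([2, 3, 4, 8])
```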
36 changes: 25 additions & 11 deletions unfoldNd/fold.py
@@ -56,34 +56,48 @@ def foldNd(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
"""
device = input.device

if isinstance(output_size, tuple):
N = len(output_size)
output_size_numel = _get_numel_from_shape(output_size)
else:
if not isinstance(output_size, tuple):
raise ValueError(f"'output_size' must be tuple. Got {type(output_size)}.")

N = len(output_size)
kernel_size = _tuple(kernel_size, N)
kernel_size_numel = _get_kernel_size_numel(kernel_size)

batch_size = input.shape[0]
in_channels_kernel_size_numel = input.shape[1]
in_channels = in_channels_kernel_size_numel // kernel_size_numel

_check_output_size(output_size_numel)
idx = torch.arange(output_size_numel, dtype=torch.float32, device=device).reshape(
1, 1, *output_size
)
idx = unfoldNd(idx, kernel_size, dilation=dilation, padding=padding, stride=stride)
# Set up an array containing the locations on the padded image
padding = _tuple(padding, N)
padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))
padded_output_size_numel = _get_numel_from_shape(padded_output_size)
_check_output_size(padded_output_size_numel)

idx = torch.arange(
padded_output_size_numel, dtype=torch.float32, device=device
).reshape(1, 1, *padded_output_size)
idx = unfoldNd(idx, kernel_size, dilation=dilation, padding=0, stride=stride)

# Replicate indices over batch and channels, then scatter the patch values
# back to the padded image
input = input.reshape(batch_size, in_channels, -1)
idx = idx.reshape(1, 1, -1).long().expand(batch_size, in_channels, -1)

output = torch.zeros(
batch_size, in_channels, output_size_numel, device=device, dtype=input.dtype
batch_size,
in_channels,
padded_output_size_numel,
device=device,
dtype=input.dtype,
)
output.scatter_add_(2, idx, input)
output = output.reshape(batch_size, in_channels, *padded_output_size)

# Remove the pixels that correspond to padding
for n, (out_n, padding_n) in enumerate(zip(output_size, padding), start=2):
output = output.narrow(n, padding_n, out_n)

return output.reshape(batch_size, in_channels, *output_size)
return output


def _check_output_size(output_size_numel):
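The final loop above strips the padding with `Tensor.narrow`, one spatial dimension at a time (starting at dim 2, since dims 0 and 1 are batch and channels). A small stand-alone illustration of that un-padding step, with shapes chosen to match the bug-30 test case:

```python
import torch

# Un-padding via narrow, as in the loop above: along each spatial dim n, keep
# out_n elements starting at offset padding_n.
output_size, padding = (4, 8), (1, 1)
padded = torch.arange(2 * 3 * 6 * 10, dtype=torch.float32).reshape(2, 3, 6, 10)

cropped = padded
for n, (out_n, padding_n) in enumerate(zip(output_size, padding), start=2):
    cropped = cropped.narrow(n, padding_n, out_n)

print(cropped.shape)  # torch.Size([2, 3, 4, 8])
# Equivalent to slicing off a one-pixel border:
assert torch.equal(cropped, padded[:, :, 1:5, 1:9])
```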
