Skip to content

Commit

Permalink
Merge branch 'main' into update-changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jan 10, 2024
2 parents 4717d6e + f6cb681 commit 7c6ec66
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
7 changes: 0 additions & 7 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,3 @@ jobs:
if: contains('refs/heads/main refs/heads/development refs/heads/release', github.ref) != 1
run: |
make test-light
- name: Test coveralls - python ${{ matrix.python-version }}
run: coveralls --service=github
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
flag-name: run-${{ matrix.python-version }}
parallel: true
8 changes: 4 additions & 4 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ uninstall:
.PHONY: install-dev

install-dev:
@pip install -e .[test]
@pip install -e .[lint]
@pip install -e ."[test]"
@pip install -e ."[lint]"

.PHONY: install-test

install-test:
@pip install -e .[test]
@pip install -e ."[test]"

.PHONY: test test-light

Expand All @@ -66,7 +66,7 @@ test-light:
.PHONY: install-lint

install-lint:
@pip install -e .[lint]
@pip install -e ."[lint]"

.PHONY: isort isort-check

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ ignore =
W291, # trailing whitespace
W503, # line break before binary operator
W504, # line break after binary operator
B905, # `zip()` without an explicit `strict=` parameter
exclude = docs, docs_src, build, .git, .eggs

[isort]
Expand Down
10 changes: 10 additions & 0 deletions test/fold_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
"kernel_size": (2, 2),
},
},
{
"seed": 0,
"input_fn": lambda: torch.rand(2, 3 * 2 * 2, 5 * 9),
"fold_kwargs": {
"output_size": (4, 8),
"kernel_size": 2,
"padding": 1,
},
"id": "bug30-fold-with-padding",
},
]
PROBLEMS_2D_IDS = [make_id(problem) for problem in PROBLEMS_2D]

Expand Down
36 changes: 25 additions & 11 deletions unfoldNd/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,48 @@ def foldNd(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
"""
device = input.device

if isinstance(output_size, tuple):
N = len(output_size)
output_size_numel = _get_numel_from_shape(output_size)
else:
if not isinstance(output_size, tuple):
raise ValueError(f"'output_size' must be tuple. Got {type(output_size)}.")

N = len(output_size)
kernel_size = _tuple(kernel_size, N)
kernel_size_numel = _get_kernel_size_numel(kernel_size)

batch_size = input.shape[0]
in_channels_kernel_size_numel = input.shape[1]
in_channels = in_channels_kernel_size_numel // kernel_size_numel

_check_output_size(output_size_numel)
idx = torch.arange(output_size_numel, dtype=torch.float32, device=device).reshape(
1, 1, *output_size
)
idx = unfoldNd(idx, kernel_size, dilation=dilation, padding=padding, stride=stride)
# Set up an array containing the locations on the padded image
padding = _tuple(padding, N)
padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))
padded_output_size_numel = _get_numel_from_shape(padded_output_size)
_check_output_size(padded_output_size_numel)

idx = torch.arange(
padded_output_size_numel, dtype=torch.float32, device=device
).reshape(1, 1, *padded_output_size)
idx = unfoldNd(idx, kernel_size, dilation=dilation, padding=0, stride=stride)

# Replicate indices over batch and channels, then scatter the patch values
# back to the padded image
input = input.reshape(batch_size, in_channels, -1)
idx = idx.reshape(1, 1, -1).long().expand(batch_size, in_channels, -1)

output = torch.zeros(
batch_size, in_channels, output_size_numel, device=device, dtype=input.dtype
batch_size,
in_channels,
padded_output_size_numel,
device=device,
dtype=input.dtype,
)
output.scatter_add_(2, idx, input)
output = output.reshape(batch_size, in_channels, *padded_output_size)

# Remove the pixels that correspond to padding
for n, (out_n, padding_n) in enumerate(zip(output_size, padding), start=2):
output = output.narrow(n, padding_n, out_n)

return output.reshape(batch_size, in_channels, *output_size)
return output


def _check_output_size(output_size_numel):
Expand Down

0 comments on commit 7c6ec66

Please sign in to comment.