-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathexample_fold.py
37 lines (28 loc) · 914 Bytes
/
example_fold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""How to use ``foldNd``. A comparison with ``torch.nn.Fold``."""
# imports, make this example deterministic
import torch

import unfoldNd

torch.manual_seed(0)

# module hyperparameters
output_size = (4, 5)
kernel_size = 2
dilation = 1
padding = 0
stride = 1

# random output of an im2col operation, shape (batch, C * prod(kernel), L).
# The last two dimensions are derived from the hyperparameters above so the
# example stays valid if any of them is changed; for the current values the
# shape is (64, 3 * 2 * 2, 12).
batch_size = 64
channels = 3
# L = total number of sliding blocks, per the formula in the torch.nn.Fold docs
num_blocks = 1
for size in output_size:
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    num_blocks *= span // stride + 1
block_features = channels * kernel_size ** len(output_size)
inputs = torch.randn(batch_size, block_features, num_blocks)

# both modules accept the same arguments and perform the same operation
torch_module = torch.nn.Fold(
    output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
)
lib_module = unfoldNd.FoldNd(
    output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
)

# forward pass
torch_outputs = torch_module(inputs)
lib_outputs = lib_module(inputs)

# check: compare shapes explicitly first, because ``torch.allclose`` broadcasts
# its arguments and could otherwise accept outputs of differing shapes
if torch_outputs.shape == lib_outputs.shape and torch.allclose(
    torch_outputs, lib_outputs
):
    print("✔ Outputs of torch.nn.Fold and unfoldNd.FoldNd match.")
else:
    raise AssertionError("❌ Outputs don't match")