-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathexample_unfold.py
35 lines (27 loc) · 901 Bytes
/
example_unfold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""How to use ``unfoldNd``. A comparison with ``torch.nn.Unfold``."""
# Runs ``torch.nn.Unfold`` and ``unfoldNd.UnfoldNd`` on the same random input
# with identical hyperparameters and checks that they produce the same output.

# imports, make this example deterministic
import torch
import unfoldNd

torch.manual_seed(0)

# random batched RGB 32x32 image-shaped input tensor of batch size 64
inputs = torch.randn((64, 3, 32, 32))

# module hyperparameters
kernel_size = 3
dilation = 1
padding = 1
stride = 2

# both modules accept the same arguments and perform the same operation
torch_module = torch.nn.Unfold(
    kernel_size, dilation=dilation, padding=padding, stride=stride
)
lib_module = unfoldNd.UnfoldNd(
    kernel_size, dilation=dilation, padding=padding, stride=stride
)

# forward pass
torch_outputs = torch_module(inputs)
lib_outputs = lib_module(inputs)

# check
# Compare shapes first: torch.allclose compares element-wise after
# broadcasting, so differently shaped outputs could otherwise pass the value
# check (or fail with an unclear broadcasting error) instead of being reported
# as a mismatch.
if torch_outputs.shape != lib_outputs.shape:
    raise AssertionError(
        f"❌ Output shapes don't match: {tuple(torch_outputs.shape)} vs "
        f"{tuple(lib_outputs.shape)}"
    )

if torch.allclose(torch_outputs, lib_outputs):
    print("✔ Outputs of torch.nn.Unfold and unfoldNd.UnfoldNd match.")
else:
    raise AssertionError("❌ Outputs don't match")