
Commit 4f0ed9f

Fix an issue (Swall0w#6)
1 parent b52a3b0 commit 4f0ed9f

File tree

1 file changed, +31 -27 lines changed


torchstat/compute_memory.py

Lines changed: 31 additions & 27 deletions
@@ -18,7 +18,7 @@ def compute_memory(module, inp, out):
         return compute_Pool2d_memory(module, inp, out)
     else:
         print(f"[Memory]: {type(module).__name__} is not supported!")
-        return (0, 0)
+        return 0, 0
     pass


@@ -28,20 +28,21 @@ def num_params(module):

 def compute_ReLU_memory(module, inp, out):
     assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
-    batch_size = inp.size()[0]
-    mread = batch_size * inp.size()[1:].numel()
-    mwrite = batch_size * inp.size()[1:].numel()
-
-    return (mread, mwrite)
+
+    mread = inp.numel()
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_PReLU_memory(module, inp, out):
-    assert isinstance(module, (nn.PReLU))
+    assert isinstance(module, nn.PReLU)
+
     batch_size = inp.size()[0]
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = batch_size * inp.size()[1:].numel()
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()

-    return (mread, mwrite)
+    return mread, mwrite


 def compute_Conv2d_memory(module, inp, out):
@@ -50,39 +51,42 @@ def compute_Conv2d_memory(module, inp, out):
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

     batch_size = inp.size()[0]
-    in_c = inp.size()[1]
-    out_c, out_h, out_w = out.size()[1:]

-    # This includes weighs with bias if the module contains it.
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = batch_size * out_c * out_h * out_w
-    return (mread, mwrite)
+    # This includes weights with bias if the module contains it.
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()
+    return mread, mwrite


 def compute_BatchNorm2d_memory(module, inp, out):
     assert isinstance(module, nn.BatchNorm2d)
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
+
     batch_size, in_c, in_h, in_w = inp.size()
-
-    mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
-    mwrite = inp.size().numel()
-    return (mread, mwrite)
+    mread = batch_size * (inp[0].numel() + 2 * in_c)
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_Linear_memory(module, inp, out):
     assert isinstance(module, nn.Linear)
     assert len(inp.size()) == 2 and len(out.size()) == 2
+
     batch_size = inp.size()[0]
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = out.size().numel()

-    return (mread, mwrite)
+    # This includes weights with bias if the module contains it.
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_Pool2d_memory(module, inp, out):
     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
-    batch_size = inp.size()[0]
-    mread = batch_size * inp.size()[1:].numel()
-    mwrite = batch_size * out.size()[1:].numel()
-    return (mread, mwrite)
+
+    mread = inp.numel()
+    mwrite = out.numel()
+
+    return mread, mwrite
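
For reference, a minimal sketch (not part of this commit) of how the updated compute_Conv2d_memory could be exercised; the layer shape, the input size, and the import path are illustrative assumptions only:

import torch
import torch.nn as nn

from torchstat.compute_memory import compute_Conv2d_memory  # import path assumed from the file above

# Dummy layer and input batch; shapes are arbitrary and only for illustration.
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
inp = torch.randn(8, 3, 32, 32)
out = conv(inp)

mread, mwrite = compute_Conv2d_memory(conv, inp, out)

# mread  = batch * (per-sample input elements + parameter count)
#        = 8 * (3*32*32 + 448) = 28160
# mwrite = out.numel() = 8*16*32*32 = 131072
print(mread, mwrite)

With these shapes, the reads count the input activations plus the weights (and bias) once per sample, while the writes are simply out.numel(), matching the inp[0].numel()- and out.numel()-based formulas introduced above.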
