@@ -18,7 +18,7 @@ def compute_memory(module, inp, out):
         return compute_Pool2d_memory(module, inp, out)
     else:
         print(f"[Memory]: {type(module).__name__} is not supported!")
-        return (0, 0)
+        return 0, 0
     pass


@@ -28,20 +28,21 @@ def num_params(module):

 def compute_ReLU_memory(module, inp, out):
     assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
-    batch_size = inp.size()[0]
-    mread = batch_size * inp.size()[1:].numel()
-    mwrite = batch_size * inp.size()[1:].numel()
-
-    return (mread, mwrite)
+
+    mread = inp.numel()
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_PReLU_memory(module, inp, out):
-    assert isinstance(module, (nn.PReLU))
+    assert isinstance(module, nn.PReLU)
+
     batch_size = inp.size()[0]
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = batch_size * inp.size()[1:].numel()
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()

-    return (mread, mwrite)
+    return mread, mwrite


 def compute_Conv2d_memory(module, inp, out):
@@ -50,39 +51,42 @@ def compute_Conv2d_memory(module, inp, out):
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

     batch_size = inp.size()[0]
-    in_c = inp.size()[1]
-    out_c, out_h, out_w = out.size()[1:]

-    # This includes weighs with bias if the module contains it.
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = batch_size * out_c * out_h * out_w
-    return (mread, mwrite)
+    # This includes weights with bias if the module contains it.
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()
+    return mread, mwrite


 def compute_BatchNorm2d_memory(module, inp, out):
     assert isinstance(module, nn.BatchNorm2d)
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
+
     batch_size, in_c, in_h, in_w = inp.size()
-
-    mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
-    mwrite = inp.size().numel()
-    return (mread, mwrite)
+    mread = batch_size * (inp[0].numel() + 2 * in_c)
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_Linear_memory(module, inp, out):
     assert isinstance(module, nn.Linear)
     assert len(inp.size()) == 2 and len(out.size()) == 2
+
     batch_size = inp.size()[0]
-    mread = batch_size * (inp.size()[1:].numel() + num_params(module))
-    mwrite = out.size().numel()

-    return (mread, mwrite)
+    # This includes weights with bias if the module contains it.
+    mread = batch_size * (inp[0].numel() + num_params(module))
+    mwrite = out.numel()
+
+    return mread, mwrite


 def compute_Pool2d_memory(module, inp, out):
     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
-    batch_size = inp.size()[0]
-    mread = batch_size * inp.size()[1:].numel()
-    mwrite = batch_size * out.size()[1:].numel()
-    return (mread, mwrite)
+
+    mread = inp.numel()
+    mwrite = out.numel()
+
+    return mread, mwrite
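
To sanity-check the new arithmetic outside the library, here is a minimal sketch that drives a single `Conv2d` through a forward hook and applies the same read/write counting as the updated `compute_Conv2d_memory`. The inline `num_params` stand-in (summing parameter element counts, so weights plus bias) is an assumption about the helper referenced in the second hunk header, not part of this commit, and `conv2d_memory` is just a local copy of the new formula for illustration.

```python
import torch
import torch.nn as nn


def num_params(module):
    # Assumed behaviour of the num_params helper referenced above:
    # total parameter elements, i.e. weights plus bias if present.
    return sum(p.numel() for p in module.parameters())


def conv2d_memory(module, inp, out):
    # Same arithmetic as the updated compute_Conv2d_memory in this commit:
    # reads cover one activation map per sample plus the parameters,
    # writes cover the full output tensor.
    batch_size = inp.size()[0]
    mread = batch_size * (inp[0].numel() + num_params(module))
    mwrite = out.numel()
    return mread, mwrite


conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
stats = {}


def hook(module, inputs, output):
    # inputs is a tuple; the memory helpers expect the first tensor.
    stats["mem"] = conv2d_memory(module, inputs[0], output)


handle = conv.register_forward_hook(hook)
conv(torch.randn(8, 3, 32, 32))
handle.remove()

mread, mwrite = stats["mem"]
# mread = 8 * (3*32*32 + 16*3*3*3 + 16) elements read,
# mwrite = 8*16*32*32 elements written.
print(mread, mwrite)
```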