Skip to content

Commit

Permalink
Merge pull request #123 from sovrasov/vs/upd_linear
Browse files Browse the repository at this point in the history
Fix imprecise bias flops calculation in linear hook
  • Loading branch information
sovrasov authored Nov 25, 2023
2 parents 36c1be8 + 2c85b93 commit 316cda9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ptflops/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ def linear_flops_counter_hook(module, input, output):
input = input[0]
# pytorch checks dimensions, so here we don't care much
output_last_dim = output.shape[-1]
input_last_dim = input.shape[-1]
pre_last_dims_prod = np.prod(input.shape[0:-1], dtype=np.int64)
bias_flops = output_last_dim if module.bias is not None else 0
module.__flops__ += int(np.prod(input.shape, dtype=np.int64) *
output_last_dim + bias_flops)
module.__flops__ += int((input_last_dim * output_last_dim + bias_flops)
* pre_last_dims_prod)


def pool_flops_counter_hook(module, input, output):
Expand Down
9 changes: 9 additions & 0 deletions tests/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def test_fc(self):
assert params == 3 * 2 + 2
assert int(macs) == 8

def test_fc_multidim(self):
net = nn.Sequential(nn.Linear(3, 2, bias=True))
macs, params = get_model_complexity_info(net, (4, 5, 3),
as_strings=False,
print_per_layer_stat=False)

assert params == (3 * 2 + 2)
assert int(macs) == (3 * 2 + 2) * 4 * 5

def test_input_constructor_tensor(self):
net = nn.Sequential(nn.Linear(3, 2, bias=True))

Expand Down

0 comments on commit 316cda9

Please sign in to comment.