Skip to content

Commit

Permalink
Merge pull request #120 from sovrasov/fix_linear_args
Browse files Browse the repository at this point in the history
Fix parsing of args in interpolate hook
  • Loading branch information
sovrasov authored Oct 30, 2023
2 parents ea08d72 + b077c64 commit 36c1be8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ptflops/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,10 @@ def _numel_functional_flops_hook(input, *args, **kwargs):


def _interpolate_functional_flops_hook(*args, **kwargs):
input = args[0]
input = kwargs.get('input', None)
if input is None and len(args) > 0:
input = args[0]

size = kwargs.get('size', None)
if size is None and len(args) > 1:
size = args[1]
Expand Down
26 changes: 26 additions & 0 deletions tests/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,29 @@ def input_constructor(input_res):
print_per_layer_stat=False)

assert (macs, params) == (8, 8)

def test_func_interpolate_args(self):
class CustomModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return nn.functional.interpolate(input=x, size=(20, 20),
mode='bilinear', align_corners=False)

macs, params = \
get_model_complexity_info(CustomModel(), (3, 10, 10),
as_strings=False,
print_per_layer_stat=False)
assert params == 0
assert macs > 0

CustomModel.forward = lambda self, x: nn.functional.interpolate(x, size=(20, 20),
mode='bilinear')

macs, params = \
get_model_complexity_info(CustomModel(), (3, 10, 10),
as_strings=False,
print_per_layer_stat=False)
assert params == 0
assert macs > 0

0 comments on commit 36c1be8

Please sign in to comment.