Skip to content

Commit

Permalink
Add type annotations to get_complexity funciton
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Nov 27, 2023
1 parent f6e4ee3 commit 4245bae
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions ptflops/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@
'''

import sys
from typing import Any, Callable, Dict, TextIO, Tuple, Union

import torch.nn as nn

from .pytorch_engine import get_flops_pytorch
from .utils import flops_to_string, params_to_string


def get_model_complexity_info(model, input_res,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None, ost=sys.stdout,
verbose=False, ignore_modules=[],
custom_modules_hooks={}, backend='pytorch',
flops_units=None, param_units=None,
output_precision=2):
def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...],
print_per_layer_stat: bool = True,
as_strings: bool = True,
input_constructor: Union[Callable, None] = None,
ost: TextIO = sys.stdout,
verbose: bool = False, ignore_modules=[],
custom_modules_hooks: Dict[nn.Module, Any] = {},
backend: str = 'pytorch',
flops_units: Union[str, None] = None,
param_units: Union[str, None] = None,
output_precision: int = 2):
assert type(input_res) is tuple
assert len(input_res) >= 1
assert isinstance(model, nn.Module)
Expand Down

0 comments on commit 4245bae

Please sign in to comment.