From 4245bae7bf2d4f74590a037026074bd908066bfb Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 28 Nov 2023 05:59:52 +0900 Subject: [PATCH] Add type annotations to get_complexity funciton --- ptflops/flops_counter.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index a4def46..296ea49 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -7,6 +7,7 @@ ''' import sys +from typing import Any, Callable, Dict, TextIO, Tuple, Union import torch.nn as nn @@ -14,14 +15,17 @@ from .utils import flops_to_string, params_to_string -def get_model_complexity_info(model, input_res, - print_per_layer_stat=True, - as_strings=True, - input_constructor=None, ost=sys.stdout, - verbose=False, ignore_modules=[], - custom_modules_hooks={}, backend='pytorch', - flops_units=None, param_units=None, - output_precision=2): +def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...], + print_per_layer_stat: bool = True, + as_strings: bool = True, + input_constructor: Union[Callable, None] = None, + ost: TextIO = sys.stdout, + verbose: bool = False, ignore_modules=[], + custom_modules_hooks: Dict[nn.Module, Any] = {}, + backend: str = 'pytorch', + flops_units: Union[str, None] = None, + param_units: Union[str, None] = None, + output_precision: int = 2): assert type(input_res) is tuple assert len(input_res) >= 1 assert isinstance(model, nn.Module)