From b6c6da48815c1ccbff4d59ae45d43029ebef6f8a Mon Sep 17 00:00:00 2001
From: Zhuoyan Xu
Date: Sun, 26 Nov 2023 11:00:40 -0600
Subject: [PATCH] Add support for timm module in MODULES_MAPPING

---
 ptflops/pytorch_ops.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/ptflops/pytorch_ops.py b/ptflops/pytorch_ops.py
index 3801070..1bd2214 100644
--- a/ptflops/pytorch_ops.py
+++ b/ptflops/pytorch_ops.py
@@ -237,6 +237,29 @@ def multihead_attention_counter_hook(multihead_attention_module, input, output):
     multihead_attention_module.__flops__ += int(flops)
 
 
+def timm_attention_counter_hook(attention_module, input, output):
+    flops = 0
+    B, N, C = input[0].shape  # [Batch_size, Seq_len, Dimension]
+
+    # QKV projection is already covered in MODULES_MAPPING
+
+    # Q scaling
+    flops += N * attention_module.head_dim * attention_module.num_heads
+
+    # head flops
+    head_flops = (
+        (N * N * attention_module.head_dim)  # QK^T
+        + (N * N)  # softmax
+        + (N * N * attention_module.head_dim)  # AV
+    )
+    flops += head_flops * attention_module.num_heads
+
+    # Final projection is already covered in MODULES_MAPPING
+
+    flops *= B
+    attention_module.__flops__ += int(flops)
+
+
 CUSTOM_MODULES_MAPPING = {}
 
 MODULES_MAPPING = {
@@ -300,6 +323,12 @@ def multihead_attention_counter_hook(multihead_attention_module, input, output):
 except ImportError:
     pass
 
+try:
+    import timm
+    MODULES_MAPPING[timm.models.vision_transformer.Attention] = timm_attention_counter_hook
+except ImportError:
+    pass
+
 
 def _linear_functional_flops_hook(input, weight, bias=None):
     out_features = weight.shape[0]
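
Not part of the patch: a minimal usage sketch showing how the new hook could be exercised end to end, assuming ptflops' public get_model_complexity_info API and a standard timm model; the model name 'vit_base_patch16_224' and the input resolution are illustrative choices, not taken from the patch.

# Minimal verification sketch (not part of the patch). Assumes ptflops and
# timm are installed; model name and input shape are illustrative.
import timm
from ptflops import get_model_complexity_info

model = timm.create_model('vit_base_patch16_224', pretrained=False)

# With the patch applied, timm's vision_transformer.Attention blocks are
# matched in MODULES_MAPPING and counted by timm_attention_counter_hook
# instead of contributing only their Linear sub-layers.
macs, params = get_model_complexity_info(
    model, (3, 224, 224),
    as_strings=True,
    print_per_layer_stat=False,
)
print('Computational complexity:', macs)
print('Number of parameters:', params)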