
Add support for timm module in MODULES_MAPPING
OliverXUZY committed Nov 26, 2023
1 parent 316cda9 commit b6c6da4
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions ptflops/pytorch_ops.py
@@ -237,6 +237,29 @@ def multihead_attention_counter_hook(multihead_attention_module, input, output):
    multihead_attention_module.__flops__ += int(flops)


def timm_attention_counter_hook(attention_module, input, output):
    flops = 0
    B, N, C = input[0].shape  # [Batch_size, Seq_len, Dimension]

    # QKV projection is already covered in MODULES_MAPPING

    # Q scaling
    flops += N * attention_module.head_dim * attention_module.num_heads

    # per-head attention flops
    head_flops = (
        (N * N * attention_module.head_dim)  # QK^T
        + (N * N)  # softmax
        + (N * N * attention_module.head_dim)  # AV
    )
    flops += head_flops * attention_module.num_heads

    # Final projection is already covered in MODULES_MAPPING

    flops *= B
    attention_module.__flops__ += int(flops)


CUSTOM_MODULES_MAPPING = {}

MODULES_MAPPING = {
@@ -300,6 +323,12 @@ def multihead_attention_counter_hook(multihead_attention_module, input, output):
except ImportError:
    pass

try:
    import timm
    MODULES_MAPPING[timm.models.vision_transformer.Attention] = timm_attention_counter_hook
except ImportError:
    pass


def _linear_functional_flops_hook(input, weight, bias=None):
    out_features = weight.shape[0]
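
Below is a minimal usage sketch (not part of the commit) of how the new mapping would be exercised. It assumes ptflops' get_model_complexity_info entry point and an arbitrary timm ViT name such as 'vit_tiny_patch16_224'. With timm installed, timm.models.vision_transformer.Attention modules are matched by MODULES_MAPPING, so timm_attention_counter_hook adds the Q-scaling, QK^T, softmax, and AV costs on top of the Linear layers already counted for the QKV and output projections.

import timm
from ptflops import get_model_complexity_info

# Hypothetical model choice; any timm ViT built on vision_transformer.Attention should work.
model = timm.create_model('vit_tiny_patch16_224', pretrained=False)

# The timm Attention hook registered in MODULES_MAPPING is picked up automatically.
macs, params = get_model_complexity_info(
    model, (3, 224, 224), as_strings=True, print_per_layer_stat=False
)
print(macs, params)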
