diff --git a/README.md b/README.md
index c7d905f..0cbbe20 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,9 @@ Supported layers:
 Experimental support:
 - RNN, LSTM, GRU (NLH layout is assumed)
 - RNNCell, LSTMCell, GRUCell
-- MultiheadAttention
+- torch.nn.MultiheadAttention
 - torchvision.ops.DeformConv2d
+- vision transformers from [timm](https://github.com/huggingface/pytorch-image-models)
 
 Requirements: Pytorch >= 1.1, torchvision >= 0.3
 
@@ -27,7 +28,8 @@ Thanks to @warmspringwinds for the initial version of the script.
 
 ## Usage tips
 
-- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
+- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore, unsupported operations do
+not contribute to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
 - `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input, one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. This dict is then passed to the model as keyword arguments (see the sketch after this diff).
 - The `verbose` parameter allows one to get information about modules that don't contribute to the final numbers.
 - The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
@@ -72,84 +74,26 @@ If ptflops was useful for your paper or tech report, please cite me:
 
 ## Benchmark
 
-### [torchvision](https://pytorch.org/docs/1.0.0/torchvision/models.html)
-
-Model | Input Resolution | Params(M) | MACs(G) | Top-1 error | Top-5 error
---- |--- |--- |--- |--- |---
-alexnet |224x224 | 61.1 | 0.72 | 43.45 | 20.91
-vgg11 |224x224 | 132.86 | 7.63 | 30.98 | 11.37
-vgg13 |224x224 | 133.05 | 11.34 | 30.07 | 10.75
-vgg16 |224x224 | 138.36 | 15.5 | 28.41 | 9.62
-vgg19 |224x224 | 143.67 | 19.67 | 27.62 | 9.12
-vgg11_bn |224x224 | 132.87 | 7.64 | 29.62 | 10.19
-vgg13_bn |224x224 | 133.05 | 11.36 | 28.45 | 9.63
-vgg16_bn |224x224 | 138.37 | 15.53 | 26.63 | 8.50
-vgg19_bn |224x224 | 143.68 | 19.7 | 25.76 | 8.15
-resnet18 |224x224 | 11.69 | 1.82 | 30.24 | 10.92
-resnet34 |224x224 | 21.8 | 3.68 | 26.70 | 8.58
-resnet50 |224x224 | 25.56 | 4.12 | 23.85 | 7.13
-resnet101 |224x224 | 44.55 | 7.85 | 22.63 | 6.44
-resnet152 |224x224 | 60.19 | 11.58 | 21.69 | 5.94
-squeezenet1_0 |224x224 | 1.25 | 0.83 | 41.90 | 19.58
-squeezenet1_1 |224x224 | 1.24 | 0.36 | 41.81 | 19.38
-densenet121 |224x224 | 7.98 | 2.88 | 25.35 | 7.83
-densenet169 |224x224 | 14.15 | 3.42 | 24.00 | 7.00
-densenet201 |224x224 | 20.01 | 4.37 | 22.80 | 6.43
-densenet161 |224x224 | 28.68 | 7.82 | 22.35 | 6.20
-inception_v3 |224x224 | 27.16 | 2.85 | 22.55 | 6.44
-
-* Top-1 error - ImageNet single-crop top-1 error (224x224)
-* Top-5 error - ImageNet single-crop top-5 error (224x224)
-
-### [Cadene/pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch)
-
-Model | Input Resolution | Params(M) | MACs(G) | Acc@1 | Acc@5
---- |--- |--- |--- |--- |---
-alexnet | 224x224 | 61.1 | 0.72 | 56.432 | 79.194
-bninception | 224x224 | 11.3 | 2.05 | 73.524 | 91.562
-cafferesnet101 | 224x224 | 44.55 | 7.62 | 76.2 | 92.766
-densenet121 | 224x224 | 7.98 | 2.88 | 74.646 | 92.136
-densenet161 | 224x224 | 28.68 | 7.82 | 77.56 | 93.798
-densenet169 | 224x224 | 14.15 | 3.42 | 76.026 | 92.992
-densenet201 | 224x224 | 20.01 | 4.37 | 77.152 | 93.548
-dpn107 | 224x224 | 86.92 | 18.42 | 79.746 | 94.684
-dpn131 | 224x224 | 79.25 | 16.13 | 79.432 | 94.574
-dpn68 | 224x224 | 12.61 | 2.36 | 75.868 | 92.774
-dpn68b | 224x224 | 12.61 | 2.36 | 77.034 | 93.59
-dpn92 | 224x224 | 37.67 | 6.56 | 79.4 | 94.62
-dpn98 | 224x224 | 61.57 | 11.76 | 79.224 | 94.488
-fbresnet152 | 224x224 | 60.27 | 11.6 | 77.386 | 93.594
-inceptionresnetv2 | 299x299 | 55.84 | 13.22 | 80.17 | 95.234
-inceptionv3 | 299x299 | 27.16 | 5.73 | 77.294 | 93.454
-inceptionv4 | 299x299 | 42.68 | 12.31 | 80.062 | 94.926
-nasnetalarge | 331x331 | 88.75 | 24.04 | 82.566 | 96.086
-nasnetamobile | 224x224 | 5.29 | 0.59 | 74.08 | 91.74
-pnasnet5large | 331x331 | 86.06 | 25.21 | 82.736 | 95.992
-polynet | 331x331 | 95.37 | 34.9 | 81.002 | 95.624
-resnet101 | 224x224 | 44.55 | 7.85 | 77.438 | 93.672
-resnet152 | 224x224 | 60.19 | 11.58 | 78.428 | 94.11
-resnet18 | 224x224 | 11.69 | 1.82 | 70.142 | 89.274
-resnet34 | 224x224 | 21.8 | 3.68 | 73.554 | 91.456
-resnet50 | 224x224 | 25.56 | 4.12 | 76.002 | 92.98
-resnext101_32x4d | 224x224 | 44.18 | 8.03 | 78.188 | 93.886
-resnext101_64x4d | 224x224 | 83.46 | 15.55 | 78.956 | 94.252
-se_resnet101 | 224x224 | 49.33 | 7.63 | 78.396 | 94.258
-se_resnet152 | 224x224 | 66.82 | 11.37 | 78.658 | 94.374
-se_resnet50 | 224x224 | 28.09 | 3.9 | 77.636 | 93.752
-se_resnext101_32x4d | 224x224 | 48.96 | 8.05 | 80.236 | 95.028
-se_resnext50_32x4d | 224x224 | 27.56 | 4.28 | 79.076 | 94.434
-senet154 | 224x224 | 115.09 | 20.82 | 81.304 | 95.498
-squeezenet1_0 | 224x224 | 1.25 | 0.83 | 58.108 | 80.428
-squeezenet1_1 | 224x224 | 1.24 | 0.36 | 58.25 | 80.8
-vgg11 | 224x224 | 132.86 | 7.63 | 68.97 | 88.746
-vgg11_bn | 224x224 | 132.87 | 7.64 | 70.452 | 89.818
-vgg13 | 224x224 | 133.05 | 11.34 | 69.662 | 89.264
-vgg13_bn | 224x224 | 133.05 | 11.36 | 71.508 | 90.494
-vgg16 | 224x224 | 138.36 | 15.5 | 71.636 | 90.354
-vgg16_bn | 224x224 | 138.37 | 15.53 | 73.518 | 91.608
-vgg19 | 224x224 | 143.67 | 19.67 | 72.08 | 90.822
-vgg19_bn | 224x224 | 143.68 | 19.7 | 74.266 | 92.066
-xception | 299x299 | 22.86 | 8.42 | 78.888 | 94.292
-
-* Acc@1 - ImageNet single-crop top-1 accuracy on validation images of the same size used during the training process.
-* Acc@5 - ImageNet single-crop top-5 accuracy on validation images of the same size used during the training process.
+### [torchvision](https://pytorch.org/vision/0.16/models.html)
+
+Model | Input Resolution | Params(M) | MACs(G)
+--- |--- |--- |---
+alexnet | 224x224 | 61.10 | 0.72
+convnext_base | 224x224 | 88.59 | 15.43
+densenet121 | 224x224 | 7.98 | 2.90
+efficientnet_b0 | 224x224 | 5.29 | 0.41
+efficientnet_v2_m | 224x224 | 54.14 | 5.43
+googlenet | 224x224 | 13.00 | 1.51
+inception_v3 | 224x224 | 27.16 | 2.86
+maxvit_t | 224x224 | 30.92 | 5.48
+mnasnet1_0 | 224x224 | 4.38 | 0.33
+mobilenet_v2 | 224x224 | 3.50 | 0.32
+mobilenet_v3_large | 224x224 | 5.48 | 0.23
+regnet_y_1_6gf | 224x224 | 11.20 | 1.65
+resnet18 | 224x224 | 11.69 | 1.83
+resnet50 | 224x224 | 25.56 | 4.13
+resnext50_32x4d | 224x224 | 25.03 | 4.29
+shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
+squeezenet1_0 | 224x224 | 1.25 | 0.84
+vgg16 | 224x224 | 138.36 | 15.52
+wide_resnet50_2 | 224x224 | 68.88 | 11.45
diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py
index b7ed8bd..9f6d125 100644
--- a/ptflops/pytorch_engine.py
+++ b/ptflops/pytorch_engine.py
@@ -415,35 +415,41 @@ def unpatch_functional():
     F.interpolate = F.interpolate.op
 
 
+def wrap_tensor_op(op, collector):
+    tensor_op_handler = torch_function_wrapper(
+        op, TENSOR_OPS_MAPPING[op], collector)
+
+    def wrapper(*args, **kwargs):
+        return tensor_op_handler(*args, **kwargs)
+
+    wrapper.op = tensor_op_handler.op
+
+    return wrapper
+
+
 def patch_tensor_ops(collector):
     torch.matmul = torch_function_wrapper(
         torch.matmul, TENSOR_OPS_MAPPING[torch.matmul], collector)
-    torch.Tensor.matmul = torch_function_wrapper(
-        torch.Tensor.matmul, TENSOR_OPS_MAPPING[torch.Tensor.matmul], collector)
+    torch.Tensor.matmul = wrap_tensor_op(torch.Tensor.matmul, collector)
     torch.mm = torch_function_wrapper(
         torch.mm, TENSOR_OPS_MAPPING[torch.mm], collector)
-    torch.Tensor.mm = torch_function_wrapper(
-        torch.Tensor.mm, TENSOR_OPS_MAPPING[torch.Tensor.mm], collector)
+    torch.Tensor.mm = wrap_tensor_op(torch.Tensor.mm, collector)
     torch.bmm = torch_function_wrapper(
         torch.bmm, TENSOR_OPS_MAPPING[torch.bmm], collector)
-    torch.Tensor.bmm = torch_function_wrapper(
-        torch.Tensor.bmm, TENSOR_OPS_MAPPING[torch.Tensor.bmm], collector)
+    torch.Tensor.bmm = wrap_tensor_op(torch.Tensor.bmm, collector)
     torch.addmm = torch_function_wrapper(
         torch.addmm, TENSOR_OPS_MAPPING[torch.addmm], collector)
-    torch.Tensor.addmm = torch_function_wrapper(
-        torch.Tensor.addmm, TENSOR_OPS_MAPPING[torch.Tensor.addmm], collector)
+    torch.Tensor.addmm = wrap_tensor_op(torch.Tensor.addmm, collector)
     torch.baddbmm = torch_function_wrapper(
         torch.baddbmm, TENSOR_OPS_MAPPING[torch.baddbmm], collector)
     torch.mul = torch_function_wrapper(
         torch.mul, TENSOR_OPS_MAPPING[torch.mul], collector)
-    torch.Tensor.mul = torch_function_wrapper(
-        torch.Tensor.mul, TENSOR_OPS_MAPPING[torch.Tensor.mul], collector)
+    torch.Tensor.mul = wrap_tensor_op(torch.Tensor.mul, collector)
     torch.add = torch_function_wrapper(
         torch.add, TENSOR_OPS_MAPPING[torch.add], collector)
-    torch.Tensor.add = torch_function_wrapper(
-        torch.Tensor.add, TENSOR_OPS_MAPPING[torch.Tensor.add], collector)
+    torch.Tensor.add = wrap_tensor_op(torch.Tensor.add, collector)
 
 
 def unpatch_tensor_ops():
diff --git a/tests/common_test.py b/tests/common_test.py
index 9a89def..510babf 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -95,3 +95,19 @@ def forward(self, x):
                                       print_per_layer_stat=False)
         assert params == 0
         assert macs > 0
+
+    def test_ten_matmul(self):
+        class CustomModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.matmul(x.t())
+
+        macs, params = \
+            get_model_complexity_info(CustomModel(), (10, ),
+                                      as_strings=False,
+                                      print_per_layer_stat=False)
+
+        assert params == 0
+        assert macs > 0
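
The `input_constructor` argument mentioned in the usage tips above deserves a concrete illustration. Below is a minimal sketch: `TwoInputModel`, its optional `mask` argument, and all shapes are made up for illustration, while `get_model_complexity_info` and its `input_constructor`, `as_strings`, and `print_per_layer_stat` parameters are the ones that appear in the diff above.

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info


class TwoInputModel(nn.Module):
    """Hypothetical model with a required and an optional input."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 4)

    def forward(self, x, mask=None):
        out = self.fc(x)
        if mask is not None:
            out = out * mask
        return out


def input_constructor(input_res):
    # Receives the resolution tuple passed to get_model_complexity_info,
    # here (10,), and returns a dict of named model inputs. The dict is
    # then passed to the model as keyword arguments.
    return {'x': torch.ones((1, *input_res)),
            'mask': torch.ones((1, 4))}


macs, params = get_model_complexity_info(TwoInputModel(), (10,),
                                         input_constructor=input_constructor,
                                         as_strings=False,
                                         print_per_layer_stat=False)
```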
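
The `wrap_tensor_op` helper introduced in `pytorch_engine.py` follows the same patch/unpatch convention the file already uses for `F.interpolate`: the replacement callable keeps the original operation in its `.op` attribute so the `unpatch_*` functions can restore it. Here is a toy, self-contained sketch of that pattern; it counts calls instead of estimating FLOPs and is not part of the ptflops API.

```python
import torch

calls = {'matmul': 0}


def make_counting_wrapper(op, name):
    def wrapper(*args, **kwargs):
        calls[name] += 1  # a real handler would add a FLOPs estimate here
        return op(*args, **kwargs)
    wrapper.op = op       # keep the original so unpatching can restore it
    return wrapper


torch.matmul = make_counting_wrapper(torch.matmul, 'matmul')
_ = torch.matmul(torch.ones(2, 3), torch.ones(3, 4))
assert calls['matmul'] == 1
torch.matmul = torch.matmul.op  # unpatch, mirroring unpatch_tensor_ops()
```

The new `test_ten_matmul` test exercises the `torch.Tensor.matmul` patch end to end; it can be run with `python -m pytest tests/common_test.py -k test_ten_matmul`.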