Update docs #128

Merged on Dec 24, 2023 (2 commits)
110 changes: 27 additions & 83 deletions README.md
@@ -18,16 +18,18 @@ Supported layers:
Experimental support:
- RNN, LSTM, GRU (NLH layout is assumed)
- RNNCell, LSTMCell, GRUCell
- MultiheadAttention
- torch.nn.MultiheadAttention
- torchvision.ops.DeformConv2d
- visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)

Requirements: PyTorch >= 1.1, torchvision >= 0.3

Thanks to @warmspringwinds for the initial version of the script.

## Usage tips

- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations, so unsupported operations do not contribute to the final complexity estimate. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
- `ptflops` runs the given model on a random tensor and estimates the amount of computation during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input, use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict of named input arguments for the model. This dict is then passed to the model as keyword arguments.
- The `verbose` parameter lets you get information about modules that don't contribute to the final numbers.
- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
@@ -72,84 +74,26 @@ If ptflops was useful for your paper or tech report, please cite me:

## Benchmark

### [torchvision](https://pytorch.org/docs/1.0.0/torchvision/models.html)

Model | Input Resolution | Params(M) | MACs(G) | Top-1 error | Top-5 error
--- |--- |--- |--- |--- |---
alexnet |224x224 | 61.1 | 0.72 | 43.45 | 20.91
vgg11 |224x224 | 132.86 | 7.63 | 30.98 | 11.37
vgg13 |224x224 | 133.05 | 11.34 | 30.07 | 10.75
vgg16 |224x224 | 138.36 | 15.5 | 28.41 | 9.62
vgg19 |224x224 | 143.67 | 19.67 | 27.62 | 9.12
vgg11_bn |224x224 | 132.87 | 7.64 | 29.62 | 10.19
vgg13_bn |224x224 | 133.05 | 11.36 | 28.45 | 9.63
vgg16_bn |224x224 | 138.37 | 15.53 | 26.63 | 8.50
vgg19_bn |224x224 | 143.68 | 19.7 | 25.76 | 8.15
resnet18 |224x224 | 11.69 | 1.82 | 30.24 | 10.92
resnet34 |224x224 | 21.8 | 3.68 | 26.70 | 8.58
resnet50 |224x224 | 25.56 | 4.12 | 23.85 | 7.13
resnet101 |224x224 | 44.55 | 7.85 | 22.63 | 6.44
resnet152 |224x224 | 60.19 | 11.58 | 21.69 | 5.94
squeezenet1_0 |224x224 | 1.25 | 0.83 | 41.90 | 19.58
squeezenet1_1 |224x224 | 1.24 | 0.36 | 41.81 | 19.38
densenet121 |224x224 | 7.98 | 2.88 | 25.35 | 7.83
densenet169 |224x224 | 14.15 | 3.42 | 24.00 | 7.00
densenet201 |224x224 | 20.01 | 4.37 | 22.80 | 6.43
densenet161 |224x224 | 28.68 | 7.82 | 22.35 | 6.20
inception_v3 |224x224 | 27.16 | 2.85 | 22.55 | 6.44

* Top-1 error - ImageNet single-crop top-1 error (224x224)
* Top-5 error - ImageNet single-crop top-5 error (224x224)

### [Cadene/pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch)

Model | Input Resolution | Params(M) | MACs(G) | Acc@1 | Acc@5
--- |--- |--- |--- |--- |---
alexnet | 224x224 | 61.1 | 0.72 | 56.432 | 79.194
bninception | 224x224 | 11.3 | 2.05 | 73.524 | 91.562
cafferesnet101 | 224x224 | 44.55 | 7.62 | 76.2 | 92.766
densenet121 | 224x224 | 7.98 | 2.88 | 74.646 | 92.136
densenet161 | 224x224 | 28.68 | 7.82 | 77.56 | 93.798
densenet169 | 224x224 | 14.15 | 3.42 | 76.026 | 92.992
densenet201 | 224x224 | 20.01 | 4.37 | 77.152 | 93.548
dpn107 | 224x224 | 86.92 | 18.42 | 79.746 | 94.684
dpn131 | 224x224 | 79.25 | 16.13 | 79.432 | 94.574
dpn68 | 224x224 | 12.61 | 2.36 | 75.868 | 92.774
dpn68b | 224x224 | 12.61 | 2.36 | 77.034 | 93.59
dpn92 | 224x224 | 37.67 | 6.56 | 79.4 | 94.62
dpn98 | 224x224 | 61.57 | 11.76 | 79.224 | 94.488
fbresnet152 | 224x224 | 60.27 | 11.6 | 77.386 | 93.594
inceptionresnetv2 | 299x299 | 55.84 | 13.22 | 80.17 | 95.234
inceptionv3 | 299x299 | 27.16 | 5.73 | 77.294 | 93.454
inceptionv4 | 299x299 | 42.68 | 12.31 | 80.062 | 94.926
nasnetalarge | 331x331 | 88.75 | 24.04 | 82.566 | 96.086
nasnetamobile | 224x224 | 5.29 | 0.59 | 74.08 | 91.74
pnasnet5large | 331x331 | 86.06 | 25.21 | 82.736 | 95.992
polynet | 331x331 | 95.37 | 34.9 | 81.002 | 95.624
resnet101 | 224x224 | 44.55 | 7.85 | 77.438 | 93.672
resnet152 | 224x224 | 60.19 | 11.58 | 78.428 | 94.11
resnet18 | 224x224 | 11.69 | 1.82 | 70.142 | 89.274
resnet34 | 224x224 | 21.8 | 3.68 | 73.554 | 91.456
resnet50 | 224x224 | 25.56 | 4.12 | 76.002 | 92.98
resnext101_32x4d | 224x224 | 44.18 | 8.03 | 78.188 | 93.886
resnext101_64x4d | 224x224 | 83.46 | 15.55 | 78.956 | 94.252
se_resnet101 | 224x224 | 49.33 | 7.63 | 78.396 | 94.258
se_resnet152 | 224x224 | 66.82 | 11.37 | 78.658 | 94.374
se_resnet50 | 224x224 | 28.09 | 3.9 | 77.636 | 93.752
se_resnext101_32x4d | 224x224 | 48.96 | 8.05 | 80.236 | 95.028
se_resnext50_32x4d | 224x224 | 27.56 | 4.28 | 79.076 | 94.434
senet154 | 224x224 | 115.09 | 20.82 | 81.304 | 95.498
squeezenet1_0 | 224x224 | 1.25 | 0.83 | 58.108 | 80.428
squeezenet1_1 | 224x224 | 1.24 | 0.36 | 58.25 | 80.8
vgg11 | 224x224 | 132.86 | 7.63 | 68.97 | 88.746
vgg11_bn | 224x224 | 132.87 | 7.64 | 70.452 | 89.818
vgg13 | 224x224 | 133.05 | 11.34 | 69.662 | 89.264
vgg13_bn | 224x224 | 133.05 | 11.36 | 71.508 | 90.494
vgg16 | 224x224 | 138.36 | 15.5 | 71.636 | 90.354
vgg16_bn | 224x224 | 138.37 | 15.53 | 73.518 | 91.608
vgg19 | 224x224 | 143.67 | 19.67 | 72.08 | 90.822
vgg19_bn | 224x224 | 143.68 | 19.7 | 74.266 | 92.066
xception | 299x299 | 22.86 | 8.42 | 78.888 | 94.292

* Acc@1 - ImageNet single-crop top-1 accuracy on validation images of the same size used during the training process.
* Acc@5 - ImageNet single-crop top-5 accuracy on validation images of the same size used during the training process.
### [torchvision](https://pytorch.org/vision/0.16/models.html)

Model | Input Resolution | Params(M) | MACs(G)
--- |--- |--- |---
alexnet | 224x224 | 61.10 | 0.72
convnext_base | 224x224 | 88.59 | 15.43
densenet121 | 224x224 | 7.98 | 2.90
efficientnet_b0 | 224x224 | 5.29 | 0.41
efficientnet_v2_m | 224x224 | 54.14 | 5.43
googlenet | 224x224 | 13.00 | 1.51
inception_v3 | 224x224 | 27.16 | 2.86
maxvit_t | 224x224 | 30.92 | 5.48
mnasnet1_0 | 224x224 | 4.38 | 0.33
mobilenet_v2 | 224x224 | 3.50 | 0.32
mobilenet_v3_large | 224x224 | 5.48 | 0.23
regnet_y_1_6gf | 224x224 | 11.20 | 1.65
resnet18 | 224x224 | 11.69 | 1.83
resnet50 | 224x224 | 25.56 | 4.13
resnext50_32x4d | 224x224 | 25.03 | 4.29
shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
squeezenet1_0 | 224x224 | 1.25 | 0.84
vgg16 | 224x224 | 138.36 | 15.52
wide_resnet50_2 | 224x224 | 68.88 | 11.45
30 changes: 18 additions & 12 deletions ptflops/pytorch_engine.py
@@ -415,35 +415,41 @@ def unpatch_functional():
    F.interpolate = F.interpolate.op


def wrap_tensor_op(op, collector):
    tensor_op_handler = torch_function_wrapper(
        op, TENSOR_OPS_MAPPING[op], collector)

    def wrapper(*args, **kwargs):
        return tensor_op_handler(*args, **kwargs)

    wrapper.op = tensor_op_handler.op

    return wrapper


def patch_tensor_ops(collector):
    torch.matmul = torch_function_wrapper(
        torch.matmul, TENSOR_OPS_MAPPING[torch.matmul], collector)
    torch.Tensor.matmul = torch_function_wrapper(
        torch.Tensor.matmul, TENSOR_OPS_MAPPING[torch.Tensor.matmul], collector)
    torch.Tensor.matmul = wrap_tensor_op(torch.Tensor.matmul, collector)
    torch.mm = torch_function_wrapper(
        torch.mm, TENSOR_OPS_MAPPING[torch.mm], collector)
    torch.Tensor.mm = torch_function_wrapper(
        torch.Tensor.mm, TENSOR_OPS_MAPPING[torch.Tensor.mm], collector)
    torch.Tensor.mm = wrap_tensor_op(torch.Tensor.mm, collector)
    torch.bmm = torch_function_wrapper(
        torch.bmm, TENSOR_OPS_MAPPING[torch.bmm], collector)
    torch.Tensor.bmm = torch_function_wrapper(
        torch.Tensor.bmm, TENSOR_OPS_MAPPING[torch.Tensor.bmm], collector)
    torch.Tensor.bmm = wrap_tensor_op(torch.Tensor.bmm, collector)

    torch.addmm = torch_function_wrapper(
        torch.addmm, TENSOR_OPS_MAPPING[torch.addmm], collector)
    torch.Tensor.addmm = torch_function_wrapper(
        torch.Tensor.addmm, TENSOR_OPS_MAPPING[torch.Tensor.addmm], collector)
    torch.Tensor.addmm = wrap_tensor_op(torch.Tensor.addmm, collector)
    torch.baddbmm = torch_function_wrapper(
        torch.baddbmm, TENSOR_OPS_MAPPING[torch.baddbmm], collector)

    torch.mul = torch_function_wrapper(
        torch.mul, TENSOR_OPS_MAPPING[torch.mul], collector)
    torch.Tensor.mul = torch_function_wrapper(
        torch.Tensor.mul, TENSOR_OPS_MAPPING[torch.Tensor.mul], collector)
    torch.Tensor.mul = wrap_tensor_op(torch.Tensor.mul, collector)
    torch.add = torch_function_wrapper(
        torch.add, TENSOR_OPS_MAPPING[torch.add], collector)
    torch.Tensor.add = torch_function_wrapper(
        torch.Tensor.add, TENSOR_OPS_MAPPING[torch.Tensor.add], collector)
    torch.Tensor.add = wrap_tensor_op(torch.Tensor.add, collector)
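
The diff does not say why the `torch.Tensor.*` methods now go through `wrap_tensor_op` instead of being assigned the `torch_function_wrapper` result directly. One plausible reason, assuming `torch_function_wrapper` returns a callable object rather than a plain function (its definition is not shown here), is Python's method binding: a plain function stored on a class becomes a bound method and receives the instance as its first argument, while an arbitrary callable object does not. The sketch below reproduces that behaviour in pure Python; `CallableHandler`, `wrap_as_function`, and `Vec` are hypothetical stand-ins, not ptflops code.

```python
class CallableHandler:
    """Hypothetical stand-in for a class-based op wrapper (an assumption)."""

    def __init__(self, op, log):
        self.op = op    # original callable, kept so the patch can be undone
        self.log = log  # records how many positional args each call received

    def __call__(self, *args, **kwargs):
        self.log.append(len(args))
        return self.op(*args, **kwargs)


def wrap_as_function(handler):
    """Wrap a callable object in a plain function so it binds as a method."""
    def wrapper(*args, **kwargs):
        return handler(*args, **kwargs)

    wrapper.op = handler.op
    return wrapper


class Vec:
    def __init__(self, value):
        self.value = value

    def double(self):
        return self.value * 2


log = []
original = Vec.double

# Patch with the callable object directly: attribute lookup returns the
# object unbound, so the instance is never passed and the call fails.
Vec.double = CallableHandler(original, log)
try:
    Vec(3).double()
    bound_without_wrapper = True
except TypeError:
    bound_without_wrapper = False

# Patch with a plain function: Python binds it, so the instance arrives as
# the first positional argument.
Vec.double = wrap_as_function(CallableHandler(original, log))
result = Vec(3).double()

# Undo the patch through the saved .op attribute.
Vec.double = Vec.double.op
```

Keeping the `.op` attribute on the wrapper mirrors the `wrapper.op = tensor_op_handler.op` line in the diff, which is presumably what lets the unpatch step restore the original method.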


def unpatch_tensor_ops():
16 changes: 16 additions & 0 deletions tests/common_test.py
@@ -95,3 +95,19 @@ def forward(self, x):
                                      print_per_layer_stat=False)
        assert params == 0
        assert macs > 0

    def test_ten_matmul(self):
        class CustomModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.matmul(x.t())

        macs, params = \
            get_model_complexity_info(CustomModel(), (10, ),
                                      as_strings=False,
                                      print_per_layer_stat=False)

        assert params == 0
        assert macs > 0
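
The new test only asserts `macs > 0`. As a rough sanity check on the magnitude, here is a sketch of the textbook multiply-accumulate count for a 2-D matrix product. It assumes the `(10, )` input gets a leading batch dimension of 1, so that `x` is `(1, 10)` and `x.t()` is `(10, 1)`; whether ptflops counts `matmul` exactly this way is not shown in this diff, and `matmul_macs` is a hypothetical helper.

```python
def matmul_macs(a_shape, b_shape):
    """MACs for (n, k) @ (k, m): one multiply-accumulate per output element per inner index."""
    n, k = a_shape
    k_b, m = b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    return n * k * m


# Shapes assumed for the test above: x is (1, 10), x.t() is (10, 1).
x_shape = (1, 10)
xt_shape = (10, 1)
expected_macs = matmul_macs(x_shape, xt_shape)
print(expected_macs)
```

Under these assumptions the product has a single output element built from 10 multiply-accumulates, which is consistent with the test's `macs > 0` check.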