Skip to content

Commit

Permalink
[update]: embed print function to show report.
Browse files Browse the repository at this point in the history
  • Loading branch information
Swall0w committed Oct 28, 2018
1 parent dcadc0f commit 97021bb
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 19 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ from torchstat import stat
from torchvision.models as models

model = models.resnet18()
report = stat(model, (3, 224, 224))
print(report)
stat(model, (3, 224, 224))
```

## Features
**Note**: These features work only nn.Module. Modules in torch.nn.functional are not supported yet.
- [x] FLOPs
- [x] Number of Parameters
- [x] Total memory
Expand Down
3 changes: 1 addition & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ def forward(self, x):

if __name__ == '__main__':
model = Net()
report = stat(model, (3, 224, 224))
print(report)
stat(model, (3, 224, 224))
6 changes: 3 additions & 3 deletions torchstat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__copyright__ = 'Copyright (C) 2018 Swall0w'
__version__ = '0.0.3'
__version__ = '0.0.4'
__author__ = 'Swall0w'
__url__ = 'https://github.com/Swall0w/torchstat'

Expand All @@ -8,7 +8,7 @@
from torchstat.stat_tree import StatTree, StatNode
from torchstat.model_hook import ModelHook
from torchstat.reporter import report_format
from torchstat.statistics import stat
from torchstat.statistics import stat, ModelStat

__all__ = ['report_format', 'StatTree', 'StatNode', 'compute_madd',
'compute_flops', 'ModelHook', 'stat', '__main__']
'compute_flops', 'ModelHook', 'stat','ModelStat', '__main__']
4 changes: 1 addition & 3 deletions torchstat/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,4 @@ def main():
import sys; sys.exit()

input_size = tuple(int(x) for x in args.size.split('x'))
report = stat(model, input_size, query_granularity=1)

print(report)
stat(model, input_size, query_granularity=1)
29 changes: 20 additions & 9 deletions torchstat/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,24 @@ def convert_leaf_modules_to_stat_tree(leaf_modules):
return StatTree(root_node)


class ModelStat(object):
def __init__(self, model, input_size, query_granularity=1):
assert isinstance(model, nn.Module)
assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
self._model = model
self._input_size = input_size
self._query_granularity = query_granularity


def show_report(self):
model_hook = ModelHook(self._model, self._input_size)
leaf_modules = model_hook.retrieve_leaf_modules()
stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
collected_nodes = stat_tree.get_collected_stat_nodes(self._query_granularity)
report = report_format(collected_nodes)
print(report)


def stat(model, input_size, query_granularity=1):
assert isinstance(model, nn.Module)
assert isinstance(input_size, (tuple, list)) and len(input_size) == 3

model_hook = ModelHook(model, input_size)
leaf_modules = model_hook.retrieve_leaf_modules()
stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
collected_nodes = stat_tree.get_collected_stat_nodes(query_granularity)
report = report_format(collected_nodes)
return report
ms = ModelStat(model, input_size, query_granularity)
ms.show_report()

0 comments on commit 97021bb

Please sign in to comment.