# prof.py
import ast
import ctypes
import os
from dataclasses import dataclass
from time import perf_counter
from typing import Dict, Optional
import torch
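
# Handle to the native profiling library; the relative path is resolved against
# the current working directory, so the module is presumably imported from the
# project root after the library has been built.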
_lib = ctypes.cdll.LoadLibrary('build/libltprof.so')
_enabled = False


def metrics_enabled():
    return os.getenv('ENABLE_METRICS') is not None


def initialize_metrics():
    _lib.initializeMetrics()


def finalize_metrics():
    _lib.finalizeMetrics()


def begin_profiler_pass():
    _lib.beginProfilerPass()


def end_profiler_pass():
    _lib.endProfilerPass()


def all_passes_submitted():
    return _lib.allPassesSubmitted()
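

# A sketch of how the pass bookkeeping above might be driven (an assumption;
# the exact protocol is defined by the native library, and run_workload is a
# hypothetical function standing in for the measured program):
#
#     initialize_metrics()
#     while not all_passes_submitted():
#         begin_profiler_pass()
#         run_workload()
#         end_profiler_pass()
#     finalize_metrics()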


def enable_profiling():
    global _enabled
    _lib.enableProfiling()
    _enabled = True


def disable_profiling():
    global _enabled
    _lib.disableProfiling()
    _enabled = False


@dataclass
class Record:
    # Aggregated wall-clock statistics for one profiling label.
    total: float = 0
    min: float = float('inf')
    max: float = 0
    count: int = 0
    begin: Optional[float] = None  # start timestamp of the currently open range


_records: Dict[str, Record] = {}


def prof_begin(label: str):
    if not _enabled:
        return
    # Synchronize so previously queued GPU work does not leak into this range.
    torch.cuda.synchronize()
    if label not in _records:
        _records[label] = Record()
    _records[label].begin = perf_counter()


def prof_end(label: str):
    if not _enabled:
        return
    # Synchronize so GPU work issued inside the range is included in the timing.
    torch.cuda.synchronize()
    record = _records[label]
    assert record.begin is not None
    dur = perf_counter() - record.begin
    record.begin = None
    record.count += 1
    record.total += dur
    record.min = min(record.min, dur)
    record.max = max(record.max, dur)
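

# Example usage (a sketch; `step()` stands in for an arbitrary workload):
#
#     enable_profiling()
#     prof_begin('step')
#     step()
#     prof_end('step')
#     print_profiling_results(1)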


def fmt_duration(dur: float):
    units = ['s', 'ms', 'us', 'ns']
    idx = 0
    while idx < len(units) - 1 and dur < 1:
        dur *= 1e3
        idx += 1
    return '{:.4}{}'.format(dur, units[idx])
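

# e.g. fmt_duration(3.14159) -> '3.142s', fmt_duration(0.00042) -> '420.0us'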
_record_fmt = '{:<16}{:>10}{:>10}{:>10}{:>10}{:>10}'


def print_profiling_results(count: int):
    _lib.printProfilingResults(ctypes.c_size_t(count))
    if len(_records) == 0:
        return
    print('\nRanges:')
    print(_record_fmt.format(
        'Label', 'Count', 'Total', 'Mean', 'Min', 'Max'))
    for label, record in _records.items():
        print(_record_fmt.format(
            label, record.count, fmt_duration(record.total),
            fmt_duration(record.total / record.count),
            fmt_duration(record.min), fmt_duration(record.max)))


class ProfileRewriter(ast.NodeTransformer):
    # Rewrites sentinel calls of the form print('<label>', True) into
    # prof_begin('<label>') and print('<label>', False) into prof_end('<label>').

    def __init__(self) -> None:
        super().__init__()

    def visit_Call(self, node: ast.Call):
        # Filter profiling prints
        func = node.func
        if not isinstance(func, ast.Name):
            return node
        if func.id != 'print':
            return node
        if len(node.args) != 2:
            return node
        label = node.args[0]
        if not isinstance(label, ast.Constant) or not isinstance(label.value, str):
            return node
        begin = node.args[1]
        if not isinstance(begin, ast.Constant) or not isinstance(begin.value, bool):
            return node
        # Replace with profiling function call
        if begin.value:
            prof_func_id = prof_begin.__name__
        else:
            prof_func_id = prof_end.__name__
        new_node = ast.Call(func=ast.Name(
            id=prof_func_id, ctx=ast.Load()), args=[label], keywords=[])
        return ast.fix_missing_locations(new_node)
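

# A minimal demonstration of the rewriter (an assumption; the original module
# does not show its driver, and running this still requires build/libltprof.so
# to be loadable at import time). Only sentinel print(label, True/False) calls
# are rewritten; everything else is left untouched.
if __name__ == '__main__':
    src = (
        "print('region', True)\n"
        "x = sum(range(1000))\n"
        "print('region', False)\n"
    )
    tree = ProfileRewriter().visit(ast.parse(src))
    # Requires Python 3.9+ for ast.unparse; prints the prof_begin/prof_end calls.
    print(ast.unparse(tree))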