Skip to content

Commit 4ed48f6

Browse files
committed
memory annotation
1 parent a218e1d commit 4ed48f6

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

thunder/examine/memory_calculation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def clear_mutable_collection_argument_memory(
147147
return memory_size
148148

149149

150-
def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]:
150+
def get_alloc_memory(trc: TraceCtx, *, annotate=False) -> tuple[int, dict[str, int]]:
151151
"""
152152
Calculate the memory usage based on the executable trace.
153153
The memory calculation is based only on the compile-time trace, i.e. the input and output shape
@@ -189,6 +189,10 @@ def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]:
189189
impl = partial(impl, is_argument=is_argument)
190190

191191
allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
192+
if annotate:
193+
if bsym.header:
194+
bsym.header += " "
195+
bsym.header += f"mem after next op: ~{allocated/(2**30):2f}GB"
192196
max_allocated = max(max_allocated, allocated)
193197

194198
return max_allocated, name_to_alloc_memory

0 commit comments

Comments
 (0)