diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 3817770ce5..1a5241f237 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple( + sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]) + ) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 3eb4eb6d52..bfe4123dc6 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -441,6 +441,7 @@ def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]: """ counter = 0 proxy_order: dict[str, int] = {} # names to order + def process_bound_symbols(bound_symbols): nonlocal counter for bsym in bound_symbols: