From eb43de88c40ac7e40c163004256b800c234fdda6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 08:48:00 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/core/rematerialization.py | 4 +++- thunder/core/transform_common.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 3817770ce5..1a5241f237 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple( + sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]) + ) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 3eb4eb6d52..bfe4123dc6 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -441,6 +441,7 @@ def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]: """ counter = 0 proxy_order: dict[str, int] = {} # names to order + def process_bound_symbols(bound_symbols): nonlocal counter for bsym in bound_symbols: