diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index f32417af34..50c0c9197f 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -646,7 +646,7 @@ def tensor_subclass_dce(trace: TraceCtx) -> TraceCtx: """ start_time_ns = time.perf_counter_ns() swap_map: dict[Variable, TensorProxy] = {} - producer_map, consumer_map = utils.producers_and_consumers(trace) + producer_map = utils.producers(trace) bsym_to_exclude: set[BoundSymbol] = set() subclass_flatten_bsym: BoundSymbol