From 87adf1cc01ec53854a872a1b62a5fde42ef76f00 Mon Sep 17 00:00:00 2001
From: jokerz0624 <2412711011@qq.com>
Date: Mon, 5 Jan 2026 18:32:24 +0800
Subject: [PATCH] [fix]: resolve performance issue in multi-GPU inference

---
 wan/utils/multitalk_utils.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/wan/utils/multitalk_utils.py b/wan/utils/multitalk_utils.py
index d33be09..22b9563 100644
--- a/wan/utils/multitalk_utils.py
+++ b/wan/utils/multitalk_utils.py
@@ -49,21 +49,25 @@ def torch_gc():
 
 def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):
     S = T * token_frame
-    split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
-    start = sum(split_sizes[:rank])
+
+    # compute split sizes per rank
+    base = S // world_size
+    rem = S % world_size
+    split_sizes = torch.full((world_size,), base, dtype=torch.long)
+    split_sizes[:rem] += 1
+
+    start = split_sizes[:rank].sum()
     end = start + split_sizes[rank]
-    counts = [0] * T
-    for idx in range(start, end):
-        t = idx // token_frame
-        counts[t] += 1
-
-    counts_filtered = []
-    frame_ids = []
-    for t, c in enumerate(counts):
-        if c > 0:
-            counts_filtered.append(c)
-            frame_ids.append(t)
-    return counts_filtered, frame_ids
+
+    # vectorized mapping: global index -> frame id
+    idx = torch.arange(start, end, dtype=torch.long)
+    frame_ids = idx // token_frame
+
+    # unique counts
+    unique_frames, counts = torch.unique(frame_ids, return_counts=True)
+
+    # return as Python list (optional)
+    return counts.tolist(), unique_frames.tolist()
 
 def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):