Hello authors, I noticed the following line in your code:
attn_mean = attn_mean.view(attn_mean.shape[0] // 4, -1).mean(dim=-1)
This operation seems to average every 4 consecutive rows into one token. However, in Qwen, token merging is usually performed as a 2×2 spatial merge of neighboring tokens. Could you clarify whether this implementation is intentional, or whether it should be a 2×2 merge instead of row-wise grouping?
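
To make the difference concrete, here is a minimal sketch in plain Python (no PyTorch) of which token indices each grouping would combine. It assumes the tokens are flattened in row-major order over an H×W patch grid, which I have not verified for your pipeline. The 4×4 grid and the helper names `consecutive_groups` and `spatial_2x2_groups` are hypothetical and only for illustration.

```python
def consecutive_groups(h, w, group=4):
    """Group flattened token indices 4 at a time, as view(N // 4, -1) would."""
    n = h * w
    return [list(range(i, i + group)) for i in range(0, n, group)]


def spatial_2x2_groups(h, w):
    """Group token indices into 2x2 spatial blocks of an h x w row-major grid."""
    groups = []
    for r in range(0, h, 2):
        for c in range(0, w, 2):
            top = r * w + c  # index of the block's top-left token
            groups.append([top, top + 1, top + w, top + w + 1])
    return groups


# Hypothetical 4x4 patch grid, tokens numbered 0..15 in row-major order.
H, W = 4, 4
print(consecutive_groups(H, W)[0])  # [0, 1, 2, 3]
print(spatial_2x2_groups(H, W)[0])  # [0, 1, 4, 5]
```

If the tokens are already reordered upstream so that each 2×2 block is stored contiguously, the two groupings would coincide and the line would be fine. I could not tell from the code which ordering applies at this point.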