From 8cddc2ef86094cb745a85887f7657da43f92e761 Mon Sep 17 00:00:00 2001
From: "wangmingfa@bytedance.com"
Date: Tue, 12 Aug 2025 15:38:31 +0800
Subject: [PATCH 1/2] feat(adapt-ascend): adapt to Ascend 910B

---
 native_sparse_attention/ops/parallel.py | 384 ++++++++++++++----------
 tests/test_nsa.py                       | 222 +++++++-------
 2 files changed, 337 insertions(+), 269 deletions(-)

diff --git a/native_sparse_attention/ops/parallel.py b/native_sparse_attention/ops/parallel.py
index e3bfff9..694db43 100644
--- a/native_sparse_attention/ops/parallel.py
+++ b/native_sparse_attention/ops/parallel.py
@@ -138,93 +138,112 @@ def parallel_nsa_compression_fwd_kernel(
 )
 @triton.jit(do_not_specialize=['T'])
 def parallel_nsa_compression_bwd_kernel_dq(
-    q,
-    k,
-    v,
-    lse,
-    delta,
-    do,
-    dq,
-    scale,
-    offsets,
-    token_indices,
+    # inputs
+    q,              # query tensor
+    k,              # key tensor
+    v,              # value tensor
+    lse,            # log-sum-exp values from the forward pass
+    delta,          # softmax correction term
+    do,             # gradient of the output o
+    dq,             # query gradient (output)
+    scale,          # attention scale factor
+    offsets,        # sequence offsets (varlen mode)
+    token_indices,  # token indices
     chunk_offsets,
-    T,
-    B: tl.constexpr,
-    H: tl.constexpr,
-    HQ: tl.constexpr,
-    G: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
-    BC: tl.constexpr,
-    BS: tl.constexpr,
-    BK: tl.constexpr,
-    BV: tl.constexpr,
-    USE_OFFSETS: tl.constexpr
+    T,                        # sequence length
+    # compile-time constants
+    B: tl.constexpr,          # batch size
+    H: tl.constexpr,          # number of KV heads
+    HQ: tl.constexpr,         # number of query heads
+    G: tl.constexpr,          # GQA group size (HQ // H)
+    K: tl.constexpr,          # key head dimension
+    V: tl.constexpr,          # value head dimension
+    BC: tl.constexpr,         # tile size over compressed representations
+    BS: tl.constexpr,         # compression block size along the sequence
+    BK: tl.constexpr,         # key tile size
+    BV: tl.constexpr,         # value tile size
+    USE_OFFSETS: tl.constexpr,  # whether variable-length offsets are used
 ):
+    # program IDs for the parallel grid
     i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    i_b, i_h = i_bh // H, i_bh % H
+    i_b, i_h = i_bh // H, i_bh % H  # batch and head indices

+    # resolve the sequence range depending on whether offsets are used
     if USE_OFFSETS:
+        # load the token index and the sequence offsets
         i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
-        T = eos - bos
+        T = eos - bos  # actual sequence length
         boc = tl.load(chunk_offsets + i_n).to(tl.int32)
     else:
+        # fixed-length sequences
         bos, eos = i_b * T, i_b * T + T
         boc = i_b * tl.cdiv(T, BS)

+    # shift base pointers to the current token
     q += (bos + i_t) * HQ*K
     do += (bos + i_t) * HQ*V
     lse += (bos + i_t) * HQ
     delta += (bos + i_t) * HQ
     dq += (i_v * B * T + bos + i_t) * HQ*K

+    # block pointers for the query and its gradient
     p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
     p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

-    # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1))
+    # load and scale the query block
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # [G, BK]
     b_q = (b_q * scale).to(b_q.dtype)

+    # block pointer for the output gradient
     p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
+    # pointers to lse and delta
     p_lse = lse + i_h * G + tl.arange(0, G)
     p_delta = delta + i_h * G + tl.arange(0, G)

+    # compression-related sizes
     # the number of compression representations in total
-    TC = tl.cdiv(T, BS)
+    TC = tl.cdiv(T, BS)  # total number of compressed blocks
     # the number of compression representations required to iterate over
     # incomplete compression blocks are not included
-    NC = (i_t + 1) // BS
+    NC = (i_t + 1) // BS  # number of compressed blocks to process

-    # [G, BV]
-    b_do = tl.load(p_do, boundary_check=(0, 1))
-    # [G]
-    b_lse = tl.load(p_lse)
-    b_delta = tl.load(p_delta)
+    # load the required tensors
+    b_do = tl.load(p_do, boundary_check=(0, 1))  # [G, BV] output gradient
+    b_lse = tl.load(p_lse)      # [G] log-sum-exp values
+    b_delta = tl.load(p_delta)  # [G] delta values

-    # [G, BK]
-    b_dq = tl.zeros([G, BK], dtype=tl.float32)
+    # initialize the query gradient accumulator
+    b_dq = tl.zeros([G, BK], dtype=tl.float32)  # [G, BK]
+
+    # main loop over all compressed blocks
     for i_c in range(0, NC, BC):
         o_c = i_c + tl.arange(0, BC)
+
+        # block pointers for key and value
         p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
         p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
-        # [BK, BC]
-        b_k = tl.load(p_k, boundary_check=(0, 1))
-        # [BV, BC]
-        b_v = tl.load(p_v, boundary_check=(0, 1))
-
-        # [G, BC]
-        b_s = tl.dot(b_q, b_k)
-        b_p = tl.exp(b_s - b_lse[:, None])
-        b_p = tl.where((o_c < NC)[None, :], b_p, 0)
-
-        # [G, BV] @ [BV, BC] -> [G, BC]
-        b_dp = tl.dot(b_do, b_v)
-        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
-        # [G, BC] @ [BC, BK] -> [G, BK]
+        # load key and value tiles
+        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BC]
+        b_v = tl.load(p_v, boundary_check=(0, 1))  # [BV, BC]
+
+        # attention scores and probabilities
+        b_s = tl.dot(b_q, b_k)                       # [G, BC] = [G, BK] @ [BK, BC]
+        b_p = tl.exp(b_s - b_lse[:, None])           # softmax probabilities
+        b_p = tl.where((o_c < NC)[None, :], b_p, 0)  # mask out-of-range blocks
+
+        # gradients
+        b_dp = tl.dot(b_do, b_v)  # [G, BV] @ [BV, BC] -> [G, BC]
+        # gradient of the attention scores: p * (dp - delta)
+        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])  # [G, BC]
+
+        # query gradient
+        # dQ = dS @ K^T
         b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
+
+    # finally apply the scale factor
     b_dq *= scale

+    # store the query gradient
     tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
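For cross-checking the backward math above: the kernel accumulates dQ = scale * (P * (dO @ V^T - delta)) @ K over the causally visible compressed blocks, with P = exp(S - lse). Below is a minimal PyTorch sketch of the same computation for a single query position; ref_compression_dq, k_cmp and v_cmp are illustrative names (not part of this patch), and the shapes are assumptions made for the sketch: q is [G, K], k_cmp is [TC, K], v_cmp is [TC, V], do is [G, V], lse and delta are [G], and nc is the number of visible compressed blocks.

import torch

def ref_compression_dq(q, k_cmp, v_cmp, do, lse, delta, scale, nc):
    # scores against the causally visible compressed keys: [G, nc]
    s = (q * scale) @ k_cmp[:nc].T
    p = torch.exp(s - lse[:, None])     # probabilities [G, nc]
    dp = do @ v_cmp[:nc].T              # dP = dO @ V^T, [G, nc]
    ds = p * (dp - delta[:, None])      # dS = P * (dP - delta)
    return scale * (ds @ k_cmp[:nc])    # dQ, [G, K]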
@@ -326,90 +345,102 @@ def parallel_nsa_compression_bwd_kernel_dkv(


 @triton.heuristics({
-    'USE_OFFSETS': lambda args: args['offsets'] is not None
+    'USE_OFFSETS': lambda args: args['offsets'] is not None  # mode is chosen by whether offsets is provided
 })
 @triton.autotune(
     configs=[
         triton.Config({}, num_warps=num_warps)
-        for num_warps in [1, 2, 4]
+        for num_warps in [1, 2, 4]  # autotune over different warp counts
     ],
-    key=['BS', 'BK'],
+    key=["BS", "BK"],  # tune keyed on the tile sizes
 )
 @triton.jit
 def parallel_nsa_kernel_topk(
-    q,
-    k,
-    lse,
-    scale,
-    block_indices,
-    offsets,
-    token_indices,
+    # inputs
+    q,              # query tensor
+    k,              # key tensor
+    lse,            # log-sum-exp values (optional)
+    scale,          # attention scale factor
+    block_indices,  # selected block indices (output)
+    offsets,        # sequence offsets (optional)
+    token_indices,  # token indices
     chunk_offsets,
-    T,
-    H: tl.constexpr,
-    HQ: tl.constexpr,
-    G: tl.constexpr,
-    K: tl.constexpr,
-    S: tl.constexpr,
-    BC: tl.constexpr,
-    BS: tl.constexpr,
-    BK: tl.constexpr,
-    USE_OFFSETS: tl.constexpr,
+    T,                          # sequence length
+    # compile-time constants
+    H: tl.constexpr,            # number of KV heads
+    HQ: tl.constexpr,           # number of query heads
+    G: tl.constexpr,            # GQA group size (HQ // H)
+    K: tl.constexpr,            # key head dimension
+    S: tl.constexpr,            # number of top blocks to select
+    BC: tl.constexpr,           # tile size over compressed representations
+    BS: tl.constexpr,           # compression block size along the sequence
+    BK: tl.constexpr,           # key tile size
+    USE_OFFSETS: tl.constexpr,  # whether variable-length offsets are used
 ):
+    # parallel grid indices
     i_t, i_bh = tl.program_id(0), tl.program_id(1)
-    i_b, i_h = i_bh // H, i_bh % H
+    i_b, i_h = i_bh // H, i_bh % H  # batch and head indices

+    # resolve the sequence range
     if USE_OFFSETS:
+        # with offsets, compute the actual sequence range
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
+        # fixed-length sequences
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

+    # create and load the query block
     p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
     # the Q block is kept in the shared memory throughout the whole kernel
-    # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1))
-    b_q = (b_q * scale).to(b_q.dtype)
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # [G, BK]
+    b_q = (b_q * scale).to(b_q.dtype)          # apply the scale

+    # compression-related sizes
     # the number of compression representations in total
-    TC = tl.cdiv(T, BS)
+    TC = tl.cdiv(T, BS)  # total number of compressed blocks
     # the number of compression representations required to iterate over
     # incomplete compression blocks are not included
-    NC = (i_t + 1) // BS
+    NC = (i_t + 1) // BS  # number of compressed blocks to process

     ################################
     # 1. lse computation
     ################################
     if lse is not None:
+        # if lse is provided, load it directly
         b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G))
     else:
+        # otherwise compute it here
+        # initialize the running max and accumulator
         # max scores for the current block
         b_m = tl.full([G], float('-inf'), dtype=tl.float32)
         # lse = log(acc) + m
         b_acc = tl.zeros([G], dtype=tl.float32)
+
+        # compute the LSE with an online softmax
         for i_c in range(0, NC, BC):
-            o_c = i_c + tl.arange(0, BC)
+            o_c = i_c + tl.arange(0, BC)  # offsets of this tile

+            # load the key tile
             p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
-            # [BK, BC]
-            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BC]

-            # [G, BC]
-            b_s = tl.dot(b_q, b_k)
-            b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))
+            # attention scores
+            b_s = tl.dot(b_q, b_k)  # [G, BC] = [G, BK] @ [BK, BC]
+            b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))  # mask out-of-range blocks

-            # [G]
-            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
-            b_r = tl.exp(b_mp - b_m)
-            # [G, BC]
-            b_p = tl.exp(b_s - b_m[:, None])
-            # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1)
+            # update the running max and accumulated sum
+            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m  # update the running max [G]
+            b_r = tl.exp(b_mp - b_m)              # rescaling factor
+            b_p = tl.exp(b_s - b_m[:, None])      # softmax probabilities [G, BC]
+            b_acc = b_acc * b_r + tl.sum(b_p, 1)  # update the accumulated sum [G]
             b_mp = b_m
+
+        # final LSE
         if NC == 0:
             b_lse = tl.zeros([G], dtype=tl.float32)
         else:
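The loop above is a standard streaming (online) log-sum-exp: each tile updates a running max m and a running sum acc of exp(s - m), and the final value is m + log(acc), exactly as in the NC > 0 branch. A minimal PyTorch sketch of the same recurrence, with illustrative names (online_lse, score_tiles) that are not part of this patch:

import torch

def online_lse(score_tiles):
    m = torch.tensor(float('-inf'))  # running max
    acc = torch.tensor(0.)           # running sum of exp(s - m)
    for s in score_tiles:            # each s holds one tile of scores
        m_new = torch.maximum(m, s.max())
        # rescale the old sum, then add the new tile's contribution
        acc = acc * torch.exp(m - m_new) + torch.exp(s - m_new).sum()
        m = m_new
    return m + torch.log(acc)        # lse = m + log(acc)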
@@ -418,139 +449,174 @@ def parallel_nsa_kernel_topk(
     ################################
     # 2. topk selection
     ################################
-    # [BC]
-    b_i = tl.full([BC], -1, dtype=tl.float32)
-    o_i = tl.zeros([BC], dtype=tl.int32)
-    m_i = tl.arange(0, BC) < BC//2
+    # buffers for the top-k selection
+    b_i = tl.full([BC], -1, dtype=tl.float32)  # importance scores [BC]
+    o_i = tl.zeros([BC], dtype=tl.int32)       # corresponding indices
+    m_i = tl.arange(0, BC) < BC//2             # top-k merge mask
+
+    # iterate over all blocks and compute importance scores
     for i_c in range(0, i_t // BS + 1, BC):
         o_c = i_c + tl.arange(0, BC)

+        # load the key tile
         p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
-        # [BK, BC]
-        b_k = tl.load(p_k, boundary_check=(0, 1))
-        # [G, BC]
-        b_s = tl.dot(b_q, b_k)
+        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BC]
+        # attention scores and probabilities
+        b_s = tl.dot(b_q, b_k)  # [G, BC]
+        # mask blocks beyond the current position with -inf
         b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf'))
-        # [G, BC]
-        b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), tl.exp(b_s - b_lse[:, None]))
+        # attention probabilities:
+        # 1.0 for the current block, exp(score - lse) for the others
+        b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), tl.exp(b_s - b_lse[:, None]))  # [G, BC]
+
+        # importance score of each block (probabilities summed over heads)
         # the importance scores of the current block
-        # [BC]
-        b_i, b_ip = tl.sum(b_p, 0), b_i
+        b_i, b_ip = tl.sum(b_p, 0), b_i  # [BC], keeping the previous values
+        # update the indices, clamped to the valid range
         o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i

+        # top-k selection via bitonic sort:
+        # number of merge stages required
         n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
+
+        # first stage: sort the new tile
         for i in tl.static_range(1, n_dims):
             b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)

-        if i_c != 0:
+        # second stage: merge with previous results
+        if i_c != 0:  # after the first tile, merge with the earlier results
             b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims)
+            # keep the values selected by the mask
             b_i_new = b_ip * m_i + b_i * (1 - m_i)
             o_i_new = o_ip * m_i + o_i * (1 - m_i)
+            # final merge
             b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims)
         else:
+            # the first tile is sorted directly
             b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)

-    m_top = tl.arange(0, BC//S) == 0
+    # extract the top-k result
+    m_top = tl.arange(0, BC//S) == 0  # mask selecting the top S entries
+    # reshape and pick the final top-k indices
     b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0)

+    # store the top-k indices
     p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
     tl.store(p_b, b_top.to(p_b.dtype.element_ty))
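Semantically, the bitonic-merge machinery above maintains a running top-BC set and finally emits the S highest-importance causally valid blocks, with invalid slots coming out as -1 (via the o_c + 1 then o_i - 1 trick). A slow PyTorch reference of that selection for one (token, head) pair; ref_topk_blocks and imp are illustrative names, and imp is assumed to hold the importance scores of the i_t // BS + 1 visible blocks, with the current block's score forced to 1.0 upstream:

import torch

def ref_topk_blocks(imp, S):
    k = min(S, imp.numel())
    idx = torch.topk(imp, k).indices            # the S highest-importance blocks
    out = torch.full((S,), -1, dtype=torch.long)
    out[:k] = idx
    return out                                  # padded with -1, like the kernel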


 @triton.heuristics({
-    'USE_OFFSETS': lambda args: args['offsets'] is not None,
-    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
+    "USE_OFFSETS": lambda args: args["offsets"] is not None,  # whether sequence offsets are used
+    "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),  # whether per-token block counts are used
 })
 @triton.autotune(
     configs=[
         triton.Config({}, num_warps=num_warps)
-        for num_warps in [1, 2, 4]
+        for num_warps in [1, 2, 4]  # autotune over different warp counts
     ],
-    key=['BS', 'BK', 'BV'],
+    key=["BS", "BK", "BV"],  # tune keyed on the tile sizes
 )
 @triton.jit
 def parallel_nsa_fwd_kernel(
-    q,
-    k,
-    v,
+    # inputs
+    q,              # query tensor
+    k,              # key tensor
+    v,              # value tensor
     o,
     lse,
-    scale,
-    block_indices,
-    block_counts,
-    offsets,
-    token_indices,
-    T,
-    H: tl.constexpr,
-    HQ: tl.constexpr,
-    G: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
-    S: tl.constexpr,
-    BS: tl.constexpr,
-    BK: tl.constexpr,
-    BV: tl.constexpr,
-    USE_OFFSETS: tl.constexpr,
-    USE_BLOCK_COUNTS: tl.constexpr
+    scale,          # attention scale factor
+    block_indices,  # block indices for selective attention
+    block_counts,   # number of blocks per position (optional)
+    offsets,        # sequence offsets (optional)
+    token_indices,  # token indices
+    T,                               # sequence length
+    # compile-time constants
+    H: tl.constexpr,                 # number of KV heads
+    HQ: tl.constexpr,                # number of query heads
+    G: tl.constexpr,                 # GQA group size (HQ // H)
+    K: tl.constexpr,                 # key head dimension
+    V: tl.constexpr,                 # value head dimension
+    S: tl.constexpr,                 # number of selected blocks
+    BS: tl.constexpr,                # sequence block size
+    BK: tl.constexpr,                # key tile size
+    BV: tl.constexpr,                # value tile size
+    USE_OFFSETS: tl.constexpr,       # whether offsets are used
+    USE_BLOCK_COUNTS: tl.constexpr,  # whether block counts are used
 ):
+    # parallel grid indices
     i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    i_b, i_h = i_bh // H, i_bh % H
+    i_b, i_h = i_bh // H, i_bh % H  # batch and head indices

+    # resolve the sequence range
     if USE_OFFSETS:
+        # with offsets, compute the actual sequence range
         i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
         T = eos - bos
     else:
+        # fixed-length sequences
         bos, eos = i_b * T, i_b * T + T

+    # shift base pointers
     k += (bos * H + i_h) * K
     v += (bos * H + i_h) * V
     block_indices += (bos + i_t) * H*S + i_h * S

+    # determine the number of blocks to process
     if USE_BLOCK_COUNTS:
         NS = tl.load(block_counts + (bos + i_t) * H + i_h)
     else:
         NS = S

+    # create and load the query block
     p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
     # the Q block is kept in the shared memory throughout the whole kernel
     # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1))
-    b_q = (b_q * scale).to(b_q.dtype)
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # [G, BK]
+    b_q = (b_q * scale).to(b_q.dtype)          # apply the scale factor

+    # output pointers for selective attention (SLC)
     p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
     p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)

-    # [G, BV]
-    b_o = tl.zeros([G, BV], dtype=tl.float32)
+    # initialize the output accumulator for selective attention
+    b_o = tl.zeros([G, BV], dtype=tl.float32)  # [G, BV], output accumulator

-    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
-    b_acc = tl.zeros([G], dtype=tl.float32)
+    b_m = tl.full([G], float('-inf'), dtype=tl.float32)  # running max
+    b_acc = tl.zeros([G], dtype=tl.float32)              # softmax accumulator
+
+    # selective attention: iterate over the selected blocks
     for i in range(NS):
+        # load the block index and compute its start position
         i_s = tl.load(block_indices + i).to(tl.int32) * BS
-        if i_s <= i_t and i_s >= 0:
+        if i_s <= i_t and i_s >= 0:  # skip invalid blocks
+            # block pointers for key and value
             p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
             p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
-            # [BK, BS]
-            b_k = tl.load(p_k, boundary_check=(0, 1))
-            # [BS, BV]
-            b_v = tl.load(p_v, boundary_check=(0, 1))
-            # [G, BS]
-            b_s = tl.dot(b_q, b_k)
+            # load key and value tiles
+            b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BS]
+            b_v = tl.load(p_v, boundary_check=(0, 1))  # [BS, BV]
+
+            # attention scores
+            b_s = tl.dot(b_q, b_k)  # [G, BS]
+            # apply the causal mask: attend only to tokens at or before the current position
             b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

-            # [G]
-            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
-            b_r = tl.exp(b_mp - b_m)
-            # [G, BS]
-            b_p = tl.exp(b_s - b_m[:, None])
-            # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1)
-            # [G, BV]
-            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
-
-            b_mp = b_m
-    b_o = b_o / b_acc[:, None]
-    b_m += tl.log(b_acc)
+            # online softmax update
+            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m  # [G], update the running max
+            b_r = tl.exp(b_mp - b_m)          # rescaling factor
+            b_p = tl.exp(b_s - b_m[:, None])  # attention weights [G, BS]
+            b_acc = b_acc * b_r + tl.sum(b_p, 1)  # update the softmax sum [G]
+            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)  # update the output accumulator [G, BV]

+            b_mp = b_m  # carry the max to the next iteration
+
+    # finalize selective attention
+    b_o = b_o / b_acc[:, None]  # normalize the output
+    b_m += tl.log(b_acc)        # final LSE

+    # store the selective attention results
     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))
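For reference, the selective branch above computes ordinary softmax attention restricted to the gathered columns of the selected blocks, causally clipped at the current position; a duplicated block index would simply contribute duplicated columns, matching the kernel's accumulation. A naive sketch for one query token and one KV head; ref_selective_attn is an illustrative name, and the shapes q: [G, K], k: [T, K], v: [T, V] are assumptions for the sketch:

import torch

def ref_selective_attn(q, k, v, idx, t, scale, BS):
    # gather the columns of every valid selected block, clipped causally
    cols = []
    for i in idx.tolist():
        s0 = i * BS
        if 0 <= s0 <= t:                 # same validity check as the kernel
            cols += list(range(s0, min(s0 + BS, t + 1)))
    cols = torch.tensor(cols)
    s = (q.float() * scale) @ k[cols].float().T  # [G, n] scores
    p = torch.softmax(s, dim=-1)                 # attention weights
    return p @ v[cols].float()                   # [G, V]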
@@ -815,12 +881,12 @@ def parallel_nsa_compression_fwd(
     H = k.shape[2]
     G = HQ // H
     BC = BS = block_size
-    if torch.cuda.get_device_capability()[0] >= 9:
-        BK = min(256, triton.next_power_of_2(K))
-        BV = min(256, triton.next_power_of_2(V))
-    else:
-        BK = min(128, triton.next_power_of_2(K))
-        BV = min(128, triton.next_power_of_2(V))
+    # if torch.cuda.get_device_capability()[0] >= 9:
+    #     BK = min(256, triton.next_power_of_2(K))
+    #     BV = min(256, triton.next_power_of_2(V))
+    # else:
+    BK = min(128, triton.next_power_of_2(K))
+    BV = min(128, triton.next_power_of_2(V))
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"
diff --git a/tests/test_nsa.py b/tests/test_nsa.py
index c94c909..3866aa3 100644
--- a/tests/test_nsa.py
+++ b/tests/test_nsa.py
@@ -6,7 +6,7 @@
 import torch
 import triton

-from fla.ops.common.utils import prepare_token_indices
+# from fla.ops.common.utils import prepare_token_indices
 from native_sparse_attention.ops.naive import naive_nsa
 from native_sparse_attention.ops.parallel import parallel_nsa

@@ -52,24 +52,25 @@ def test_parallel(
     torch.manual_seed(42)
     os.environ['TRITON_F32_DEFAULT'] = 'ieee'

-    perm_q = torch.randperm(T, device='cuda')
-    perm_k = torch.randperm(T, device='cuda')
-    perm_v = torch.randperm(T, device='cuda')
-    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
-    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
-    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
-    g_slc = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
-    g_swa = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
-    do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
-
-    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
+    perm_q = torch.randperm(T, device='npu')
+    perm_k = torch.randperm(T, device='npu')
+    perm_v = torch.randperm(T, device='npu')
+    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
+    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
+    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
+    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+    g_slc = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+    g_swa = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+    do = torch.randn((B, T, HQ, D), dtype=dtype, device='npu')
+
+    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
     for b in range(B):
         for t in range(T):
             for h in range(H):
                 i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
                 block_indices[b, t, h, :len(i_i)] = i_i
     block_indices = block_indices.sort(-1)[0]
-    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='cuda')
+    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')

     ref = naive_nsa(
         q=q,
@@ -95,6 +96,7 @@ def test_parallel(
         q=q,
         k=k,
         v=v,
+        g_cmp=g_cmp,
         g_slc=g_slc,
         g_swa=g_swa,
         block_indices=block_indices,
@@ -120,100 +122,100 @@ def test_parallel(
         assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)


-@pytest.mark.parametrize("N", [4])
-@pytest.mark.parametrize("T",
[64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048]) -@pytest.mark.parametrize("H", [4]) -@pytest.mark.parametrize("HQ", [64]) -@pytest.mark.parametrize("D", [100, 64]) -@pytest.mark.parametrize("S", [16]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("window_size", [0, 32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_parallel_varlen( - N: int, - T: int, - H: int, - HQ: int, - D: int, - S: int, - block_size: int, - window_size: int, - dtype: torch.dtype, -): - torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' - - # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).cuda().sort()[0] - # seq-first required for inputs with variable lengths - perm_q = torch.randperm(T, device='cuda') - perm_k = torch.randperm(T, device='cuda') - perm_v = torch.randperm(T, device='cuda') - q = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) - k = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - v = torch.linspace(0, 1, steps=T, dtype=dtype, device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda') - - token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda') - for i in range(T): - _, t = token_indices[i] - for h in range(H): - i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i - block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda') - - ref = naive_nsa( - q=q, - k=k, - v=v, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - window_size=window_size, - cu_seqlens=offsets - ) - ref.backward(do) - ref_dq, q.grad = q.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dv, v.grad = v.grad.clone(), None - ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None - if window_size > 0: - ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None - - tri = parallel_nsa( - q=q, - k=k, - v=v, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - window_size=window_size, - cu_seqlens=offsets - ) - tri.backward(do) - tri_dq, q.grad = q.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dv, v.grad = v.grad.clone(), None - tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None - if window_size > 0: - tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None - - assert_close(" o", ref, tri, 0.004) - assert_close("dq", ref_dq, tri_dq, 0.005) - assert_close("dk", ref_dk, tri_dk, 0.005) - assert_close("dv", ref_dv, tri_dv, 0.005) - assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005) - if window_size > 0: - assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005) +# @pytest.mark.parametrize("N", [4]) +# @pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048]) +# @pytest.mark.parametrize("H", [4]) +# @pytest.mark.parametrize("HQ", [64]) +# @pytest.mark.parametrize("D", [100, 
64])
+# @pytest.mark.parametrize("S", [16])
+# @pytest.mark.parametrize("block_size", [32])
+# @pytest.mark.parametrize("window_size", [0, 32])
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+# def test_parallel_varlen(
+#     N: int,
+#     T: int,
+#     H: int,
+#     HQ: int,
+#     D: int,
+#     S: int,
+#     block_size: int,
+#     window_size: int,
+#     dtype: torch.dtype,
+# ):
+#     torch.manual_seed(42)
+#     os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+
+#     # randomly split the sequence into N segments
+#     offsets = torch.cat([
+#         torch.tensor([0], dtype=torch.long),
+#         torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],
+#         torch.tensor([T], dtype=torch.long)
+#     ], 0).npu().sort()[0]
+#     # seq-first required for inputs with variable lengths
+#     perm_q = torch.randperm(T, device='npu')
+#     perm_k = torch.randperm(T, device='npu')
+#     perm_v = torch.randperm(T, device='npu')
+#     q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
+#     k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
+#     v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
+#     g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+#     g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+#     do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu')
+
+#     token_indices = prepare_token_indices(offsets).tolist()
+#     block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu')
+#     for i in range(T):
+#         _, t = token_indices[i]
+#         for h in range(H):
+#             i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
+#             block_indices[0, i, h, :len(i_i)] = i_i
+#     block_indices = block_indices.sort(-1)[0]
+#     block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')
+
+#     ref = naive_nsa(
+#         q=q,
+#         k=k,
+#         v=v,
+#         g_slc=g_slc,
+#         g_swa=g_swa,
+#         block_indices=block_indices,
+#         block_counts=block_counts,
+#         block_size=block_size,
+#         window_size=window_size,
+#         cu_seqlens=offsets
+#     )
+#     ref.backward(do)
+#     ref_dq, q.grad = q.grad.clone(), None
+#     ref_dk, k.grad = k.grad.clone(), None
+#     ref_dv, v.grad = v.grad.clone(), None
+#     ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
+#     if window_size > 0:
+#         ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None
+
+#     tri = parallel_nsa(
+#         q=q,
+#         k=k,
+#         v=v,
+#         g_slc=g_slc,
+#         g_swa=g_swa,
+#         block_indices=block_indices,
+#         block_counts=block_counts,
+#         block_size=block_size,
+#         window_size=window_size,
+#         cu_seqlens=offsets
+#     )
+#     tri.backward(do)
+#     tri_dq, q.grad = q.grad.clone(), None
+#     tri_dk, k.grad = k.grad.clone(), None
+#     tri_dv, v.grad = v.grad.clone(), None
+#     tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
+#     if window_size > 0:
+#         tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None
+
+#     assert_close(" o", ref, tri, 0.004)
+#     assert_close("dq", ref_dq, tri_dq, 0.005)
+#     assert_close("dk", ref_dk, tri_dk, 0.005)
+#     assert_close("dv", ref_dv, tri_dv, 0.005)
+#     assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
+#     if window_size > 0:
+#         assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)

From dceb41930e5308e3918b7581478daa8489a1941f Mon Sep 17 00:00:00 2001
From: "zhaorong.bd"
Date: Fri, 22 Aug 2025 18:22:51 +0800
Subject: [PATCH 2/2] add NPU modifications

---
 native_sparse_attention/ops/parallel.py |  26 +--
 tests/test_nsa_npu.py                   | 223 ++++++++++++++++++++++++
 2 files changed, 236 insertions(+), 13 deletions(-)
 create mode 100644 tests/test_nsa_npu.py
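The second patch replaces the commented-out capability check from patch 1 with a guarded query: the K/V tile cap stays at 256 by default and drops to 128 only when a CUDA device older than SM 9.0 is present; on Ascend builds, torch.cuda.is_available() is False, so the 256 cap applies. A standalone sketch of that policy under the same assumptions (choose_kv_tiles is a hypothetical helper name; the patch inlines this logic instead):

import torch
import triton

def choose_kv_tiles(K, V):
    cap = 256
    # only CUDA devices report a compute capability; Ascend NPUs keep the 256 cap
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9:
        cap = 128
    # tile sizes are powers of two, clamped to the device-dependent cap
    return min(cap, triton.next_power_of_2(K)), min(cap, triton.next_power_of_2(V))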
diff --git a/native_sparse_attention/ops/parallel.py b/native_sparse_attention/ops/parallel.py
index 694db43..d361181 100644
--- a/native_sparse_attention/ops/parallel.py
+++ b/native_sparse_attention/ops/parallel.py
@@ -162,7 +162,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
     BS: tl.constexpr,         # compression block size along the sequence
     BK: tl.constexpr,         # key tile size
     BV: tl.constexpr,         # value tile size
-    USE_OFFSETS: tl.constexpr,  # whether variable-length offsets are used
+    USE_OFFSETS: tl.constexpr  # whether variable-length offsets are used
 ):
     # program IDs for the parallel grid
     i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
@@ -881,12 +881,12 @@ def parallel_nsa_compression_fwd(
     H = k.shape[2]
     G = HQ // H
     BC = BS = block_size
-    # if torch.cuda.get_device_capability()[0] >= 9:
-    #     BK = min(256, triton.next_power_of_2(K))
-    #     BV = min(256, triton.next_power_of_2(V))
-    # else:
-    BK = min(128, triton.next_power_of_2(K))
-    BV = min(128, triton.next_power_of_2(V))
+    BKV_VALUE = 256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_VALUE = 128
+    BK = min(BKV_VALUE, triton.next_power_of_2(K))
+    BV = min(BKV_VALUE, triton.next_power_of_2(V))
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"
@@ -1130,12 +1130,12 @@ def parallel_nsa_fwd(
     HQ = q.shape[2]
     G = HQ // H
     BS = block_size
-    if torch.cuda.get_device_capability()[0] >= 9:
-        BK = min(256, triton.next_power_of_2(K))
-        BV = min(256, triton.next_power_of_2(V))
-    else:
-        BK = min(128, triton.next_power_of_2(K))
-        BV = min(128, triton.next_power_of_2(V))
+    BKV_VALUE = 256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_VALUE = 128
+    BK = min(BKV_VALUE, triton.next_power_of_2(K))
+    BV = min(BKV_VALUE, triton.next_power_of_2(V))
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"
diff --git a/tests/test_nsa_npu.py b/tests/test_nsa_npu.py
new file mode 100644
index 0000000..6e1f4db
--- /dev/null
+++ b/tests/test_nsa_npu.py
@@ -0,0 +1,223 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+import pytest
+import torch
+import triton
+import torch_npu
+from fla.ops.common.utils import prepare_token_indices
+from native_sparse_attention.ops.naive import naive_nsa
+from native_sparse_attention.ops.parallel import parallel_nsa
+
+
+def get_abs_err(x, y):
+    return (x-y).flatten().abs().max().item()
+
+
+def get_err_ratio(x, y):
+    err = (x-y).flatten().square().mean().sqrt().item()
+    base = (x).flatten().square().mean().sqrt().item()
+    return err / base
+
+
+def assert_close(prefix, ref, tri, ratio):
+    msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    print(msg)
+    assert get_err_ratio(ref, tri) < ratio, msg
+
+
+@pytest.mark.parametrize("B", [1])
+@pytest.mark.parametrize("T", [256, 1024, 2000])
+@pytest.mark.parametrize("H", [4])
+@pytest.mark.parametrize("HQ", [64])
+@pytest.mark.parametrize("D", [100, 64])
+@pytest.mark.parametrize("S", [16])
+@pytest.mark.parametrize("block_size", [32])
+@pytest.mark.parametrize("window_size", [0, 32])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("scale", [0.1])
+def test_parallel(
+    B: int,
+    H: int,
+    HQ: int,
+    T: int,
+    D: int,
+    S: int,
+    block_size: int,
+    window_size: int,
+    dtype: torch.dtype,
+    scale: float
+):
+    torch.manual_seed(42)
+    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+
+    perm_q = torch.randperm(T, device='npu')
+    perm_k = torch.randperm(T, device='npu')
+    perm_v = torch.randperm(T, device='npu')
+    q =
torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True) + g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + g_slc = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + g_swa = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device='npu') + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu') + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] + block_indices[b, t, h, :len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu') + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + if window_size > 0: + ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale + ) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + if window_size > 0: + tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None + + assert_close(" o", ref, tri, 0.005) + assert_close("dq", ref_dq, tri_dq, 0.005) + assert_close("dk", ref_dk, tri_dk, 0.005) + assert_close("dv", ref_dv, tri_dv, 0.005) + assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005) + if window_size > 0: + assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005) + + +@pytest.mark.parametrize("N", [4]) +@pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048]) +@pytest.mark.parametrize("H", [4]) +@pytest.mark.parametrize("HQ", [64]) +@pytest.mark.parametrize("D", [100, 64]) +@pytest.mark.parametrize("S", [16]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("window_size", [0, 32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_parallel_varlen( + N: int, + T: int, + H: int, + HQ: int, + D: int, + S: int, + block_size: int, + window_size: int, + dtype: torch.dtype, +): + torch.manual_seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + + # randomly split the sequence into N segments + offsets = torch.cat([ + torch.tensor([0], dtype=torch.long), + torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.tensor([T], dtype=torch.long) + ], 0).npu().sort()[0] + # seq-first required for inputs with variable lengths + perm_q = torch.randperm(T, device='npu') + perm_k = torch.randperm(T, device='npu') + perm_v = torch.randperm(T, device='npu') + q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = 
torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_cmp = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu') + + token_indices = prepare_token_indices(offsets).tolist() + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu') + for i in range(T): + _, t = token_indices[i] + for h in range(H): + i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] + block_indices[0, i, h, :len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (1, T, H), device='npu') + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + cu_seqlens=offsets + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + if window_size > 0: + ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + cu_seqlens=offsets + ) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + if window_size > 0: + tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None + + assert_close(" o", ref, tri, 0.004) + assert_close("dq", ref_dq, tri_dq, 0.005) + assert_close("dk", ref_dk, tri_dk, 0.005) + assert_close("dv", ref_dv, tri_dv, 0.005) + assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005) + if window_size > 0: + assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)