diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index 8a92d58..7fe96b3 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -1,3 +1,4 @@ +import inspect import pytest import torch @@ -32,8 +33,12 @@ def test_nsa_cmp_fwd_varlen_op( assert group % 16 == 0, "Group size must be a multiple of 16 in NSA" - # Use locals() to create params dictionary from function arguments - params = locals().copy() + # Create params dictionary from function arguments using the function signature + # to avoid including pytest-injected local variables. + # Note: Need to capture locals() before list comprehension due to scope issues + local_vars = locals() + sig = inspect.signature(globals()[inspect.stack()[0].function]) + params = {name: local_vars[name] for name in sig.parameters} benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs()