Skip to content

Commit e4db749

Browse files
committed
tests: add regression test for logger NameError in torchao_quantizer (fixes #13104)
Add a test to prevent regression of the NameError that occurred when `logger` was used inside `_update_torch_safe_globals()` before being defined at module level. The fix (moving `logger` before the function) was included in PR #12901, but this test ensures it cannot regress.
1 parent a80b192 commit e4db749

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

tests/quantization/torchao/test_torchao.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,47 @@
7272
from torchao.quantization import Int8WeightOnlyConfig
7373

7474

75+
class TorchAoQuantizerModuleTest(unittest.TestCase):
76+
def test_logger_defined_before_module_level_code(self):
77+
"""
78+
Regression test for https://github.com/huggingface/diffusers/issues/13104.
79+
Ensures that `logger` is defined at the module level in torchao_quantizer.py
80+
before any module-level code that uses it (such as `_update_torch_safe_globals`).
81+
Previously, `logger` was defined after the function that uses it, causing a
82+
`NameError: name 'logger' is not defined` at import time when the exception
83+
handler in `_update_torch_safe_globals` was triggered.
84+
"""
85+
import ast
86+
import inspect
87+
88+
from diffusers.quantizers.torchao import torchao_quantizer
89+
90+
source = inspect.getsource(torchao_quantizer)
91+
tree = ast.parse(source)
92+
93+
logger_lineno = None
94+
func_lineno = None
95+
96+
for node in ast.walk(tree):
97+
if isinstance(node, ast.Assign):
98+
for target in node.targets:
99+
if isinstance(target, ast.Name) and target.id == "logger":
100+
logger_lineno = node.lineno
101+
if isinstance(node, ast.FunctionDef) and node.name == "_update_torch_safe_globals":
102+
func_lineno = node.lineno
103+
104+
self.assertIsNotNone(logger_lineno, "logger must be defined at module level in torchao_quantizer.py")
105+
self.assertIsNotNone(
106+
func_lineno, "_update_torch_safe_globals must be defined in torchao_quantizer.py"
107+
)
108+
self.assertLess(
109+
logger_lineno,
110+
func_lineno,
111+
f"logger (line {logger_lineno}) must be defined before _update_torch_safe_globals "
112+
f"(line {func_lineno}) to avoid NameError at import time (see issue #13104)",
113+
)
114+
115+
75116
@require_torch
76117
@require_torch_accelerator
77118
@require_torchao_version_greater_or_equal("0.7.0")

0 commit comments

Comments
 (0)