From e4db749fd7a830af23649163fc9cb81d033481cb Mon Sep 17 00:00:00 2001 From: Dhruvil Darji Date: Mon, 23 Feb 2026 02:02:53 -0800 Subject: [PATCH] tests: add regression test for logger NameError in torchao_quantizer (fixes #13104) Add a test to prevent regression of the NameError that occurred when `logger` was used inside `_update_torch_safe_globals()` before being defined at module level. The fix (moving `logger` before the function) was included in PR #12901, but this test ensures it cannot regress. --- tests/quantization/torchao/test_torchao.py | 41 ++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 7a8e3cc67877..c7005e3fe73d 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -72,6 +72,47 @@ from torchao.quantization import Int8WeightOnlyConfig +class TorchAoQuantizerModuleTest(unittest.TestCase): + def test_logger_defined_before_module_level_code(self): + """ + Regression test for https://github.com/huggingface/diffusers/issues/13104. + Ensures that `logger` is defined at the module level in torchao_quantizer.py + before any module-level code that uses it (such as `_update_torch_safe_globals`). + Previously, `logger` was defined after the function that uses it, causing a + `NameError: name 'logger' is not defined` at import time when the exception + handler in `_update_torch_safe_globals` was triggered. + """ + import ast + import inspect + + from diffusers.quantizers.torchao import torchao_quantizer + + source = inspect.getsource(torchao_quantizer) + tree = ast.parse(source) + + logger_lineno = None + func_lineno = None + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "logger": + logger_lineno = node.lineno + if isinstance(node, ast.FunctionDef) and node.name == "_update_torch_safe_globals": + func_lineno = node.lineno + + self.assertIsNotNone(logger_lineno, "logger must be defined at module level in torchao_quantizer.py") + self.assertIsNotNone( + func_lineno, "_update_torch_safe_globals must be defined in torchao_quantizer.py" + ) + self.assertLess( + logger_lineno, + func_lineno, + f"logger (line {logger_lineno}) must be defined before _update_torch_safe_globals " + f"(line {func_lineno}) to avoid NameError at import time (see issue #13104)", + ) + + @require_torch @require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0")