diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 7a8e3cc67877..c7005e3fe73d 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -72,6 +72,47 @@ from torchao.quantization import Int8WeightOnlyConfig +class TorchAoQuantizerModuleTest(unittest.TestCase): + def test_logger_defined_before_module_level_code(self): + """ + Regression test for https://github.com/huggingface/diffusers/issues/13104. + Ensures that `logger` is defined at the module level in torchao_quantizer.py + before any module-level code that uses it (such as `_update_torch_safe_globals`). + Previously, `logger` was defined after the function that uses it, causing a + `NameError: name 'logger' is not defined` at import time when the exception + handler in `_update_torch_safe_globals` was triggered. + """ + import ast + import inspect + + from diffusers.quantizers.torchao import torchao_quantizer + + source = inspect.getsource(torchao_quantizer) + tree = ast.parse(source) + + logger_lineno = None + func_lineno = None + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "logger": + logger_lineno = node.lineno + if isinstance(node, ast.FunctionDef) and node.name == "_update_torch_safe_globals": + func_lineno = node.lineno + + self.assertIsNotNone(logger_lineno, "logger must be defined at module level in torchao_quantizer.py") + self.assertIsNotNone( + func_lineno, "_update_torch_safe_globals must be defined in torchao_quantizer.py" + ) + self.assertLess( + logger_lineno, + func_lineno, + f"logger (line {logger_lineno}) must be defined before _update_torch_safe_globals " + f"(line {func_lineno}) to avoid NameError at import time (see issue #13104)", + ) + + @require_torch @require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0")