|
72 | 72 | from torchao.quantization import Int8WeightOnlyConfig |
73 | 73 |
|
74 | 74 |
|
| 75 | +class TorchAoQuantizerModuleTest(unittest.TestCase): |
| 76 | + def test_logger_defined_before_module_level_code(self): |
| 77 | + """ |
| 78 | + Regression test for https://github.com/huggingface/diffusers/issues/13104. |
| 79 | + Ensures that `logger` is defined at the module level in torchao_quantizer.py |
| 80 | + before any module-level code that uses it (such as `_update_torch_safe_globals`). |
| 81 | + Previously, `logger` was defined after the function that uses it, causing a |
| 82 | + `NameError: name 'logger' is not defined` at import time when the exception |
| 83 | + handler in `_update_torch_safe_globals` was triggered. |
| 84 | + """ |
| 85 | + import ast |
| 86 | + import inspect |
| 87 | + |
| 88 | + from diffusers.quantizers.torchao import torchao_quantizer |
| 89 | + |
| 90 | + source = inspect.getsource(torchao_quantizer) |
| 91 | + tree = ast.parse(source) |
| 92 | + |
| 93 | + logger_lineno = None |
| 94 | + func_lineno = None |
| 95 | + |
| 96 | + for node in ast.walk(tree): |
| 97 | + if isinstance(node, ast.Assign): |
| 98 | + for target in node.targets: |
| 99 | + if isinstance(target, ast.Name) and target.id == "logger": |
| 100 | + logger_lineno = node.lineno |
| 101 | + if isinstance(node, ast.FunctionDef) and node.name == "_update_torch_safe_globals": |
| 102 | + func_lineno = node.lineno |
| 103 | + |
| 104 | + self.assertIsNotNone(logger_lineno, "logger must be defined at module level in torchao_quantizer.py") |
| 105 | + self.assertIsNotNone( |
| 106 | + func_lineno, "_update_torch_safe_globals must be defined in torchao_quantizer.py" |
| 107 | + ) |
| 108 | + self.assertLess( |
| 109 | + logger_lineno, |
| 110 | + func_lineno, |
| 111 | + f"logger (line {logger_lineno}) must be defined before _update_torch_safe_globals " |
| 112 | + f"(line {func_lineno}) to avoid NameError at import time (see issue #13104)", |
| 113 | + ) |
| 114 | + |
| 115 | + |
75 | 116 | @require_torch |
76 | 117 | @require_torch_accelerator |
77 | 118 | @require_torchao_version_greater_or_equal("0.7.0") |
|
0 commit comments