Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,47 @@
from torchao.quantization import Int8WeightOnlyConfig


class TorchAoQuantizerModuleTest(unittest.TestCase):
def test_logger_defined_before_module_level_code(self):
"""
Regression test for https://github.com/huggingface/diffusers/issues/13104.
Ensures that `logger` is defined at the module level in torchao_quantizer.py
before any module-level code that uses it (such as `_update_torch_safe_globals`).
Previously, `logger` was defined after the function that uses it, causing a
`NameError: name 'logger' is not defined` at import time when the exception
handler in `_update_torch_safe_globals` was triggered.
"""
import ast
import inspect

from diffusers.quantizers.torchao import torchao_quantizer

source = inspect.getsource(torchao_quantizer)
tree = ast.parse(source)

logger_lineno = None
func_lineno = None

for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == "logger":
logger_lineno = node.lineno
if isinstance(node, ast.FunctionDef) and node.name == "_update_torch_safe_globals":
func_lineno = node.lineno

self.assertIsNotNone(logger_lineno, "logger must be defined at module level in torchao_quantizer.py")
self.assertIsNotNone(
func_lineno, "_update_torch_safe_globals must be defined in torchao_quantizer.py"
)
self.assertLess(
logger_lineno,
func_lineno,
f"logger (line {logger_lineno}) must be defined before _update_torch_safe_globals "
f"(line {func_lineno}) to avoid NameError at import time (see issue #13104)",
)


@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
Expand Down