Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
Copy link
Collaborator

@dg845 dg845 Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following autoencoder modules currently don't inherit from AutoencoderTesterMixin:

  • AsymmetricAutoencoderKL
  • AutoencoderKLCosmos
  • AutoencoderDC
  • AutoencoderKLTemporalDecoder
  • AutoencoderKLMagvit
  • AutoencoderKLMochi
  • AutoencoderOobleck
  • VQModel

Is this because they don't have enable_tiling or enable_slicing methods? Since test_enable_disable_tiling and test_enable_disable_slicing check for these methods, it seems like it should be safe for them to inherit from AutoencoderTesterMixin. It would also have the added benefit of allowing them to have the one test that was moved from UNetTesterMixin into AutoencoderTesterMixin.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No one has used it :/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also have the added benefit of allowing them to have the one test that was moved from UNetTesterMixin into AutoencoderTesterMixin.

Do you mean test_forward_with_norm_groups?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I edited my comment above to try to consolidate all of the models that don't inherit from AutoencoderTesterMixin. Do you think it would make sense for all autoencoder test classes to inherit from AutoencoderTesterMixin, as the slicing/tiling tests will be skipped if enable_slicing/enable_tiling methods are not available?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean test_forward_with_norm_groups?

Yeah I think that's the one

model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
Expand Down
4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from diffusers import AutoencoderKLCosmos

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCosmosTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
Expand Down
4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderDCTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
Expand Down
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin


enable_full_determinism()


class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -87,68 +87,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
Expand Down
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -83,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Expand Down
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin


enable_full_determinism()


class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -82,68 +82,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
Expand Down
37 changes: 3 additions & 34 deletions tests/models/autoencoders/test_models_autoencoder_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin


enable_full_determinism()


class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_forward_with_norm_groups(self):
pass


class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -167,34 +167,3 @@ def test_outputs_equivalence(self):
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_magvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from diffusers import AutoencoderKLMagvit

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLMagvitTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
Expand Down
Loading
Loading