diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index e97b543ff85d..c94f41935938 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -1,4 +1,6 @@ import gc +import json +import os import tempfile from typing import Callable @@ -349,6 +351,33 @@ def test_save_from_pretrained(self): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_modular_index_consistency(self): + pipe = self.get_pipeline() + components_spec = pipe._component_specs + components = sorted(components_spec.keys()) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + index_file = os.path.join(tmpdir, "modular_model_index.json") + assert os.path.exists(index_file) + + with open(index_file) as f: + index_contents = json.load(f) + + compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"} + for k in compulsory_keys: + assert k in index_contents + + to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"} + for component in components: + spec = components_spec[component] + for attr in to_check_attrs: + if getattr(spec, "pretrained_model_name_or_path", None) is not None: + for attr in to_check_attrs: + assert component in index_contents, f"{component} should be present in index but isn't." + attr_value_from_index = index_contents[component][2][attr] + assert getattr(spec, attr) == attr_value_from_index + def test_workflow_map(self): blocks = self.pipeline_blocks_class() if blocks._workflow_map is None: