Skip to content

Commit

Permalink
support any input resolution in stable diffusion models (#1087)
Browse files Browse the repository at this point in the history
* support any input resolution in stable diffusion models

* Update optimum/exporters/openvino/model_configs.py
  • Loading branch information
eaidova authored Dec 20, 2024
1 parent 8ef3997 commit 420fa87
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
17 changes: 17 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,23 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return super().generate(input_name, framework, int_dtype, float_dtype)


class DummyUnetVisionInputGenerator(DummyVisionInputGenerator):
    """Vision dummy-input generator with special handling of the UNet latent inputs.

    For the ``sample`` and ``latent_sample`` inputs, the generated tensor uses
    spatial dimensions one smaller than the configured ``height`` and ``width``.
    Every other input name is delegated unchanged to the parent generator.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return a dummy tensor for ``input_name``.

        Args:
            input_name: Name of the model input to generate a tensor for.
            framework: Framework identifier forwarded to the tensor factory.
            int_dtype: Integer dtype name, forwarded to the parent generator.
            float_dtype: Float dtype name used for the generated tensor.

        Returns:
            A random float tensor of shape
            ``[batch_size, num_channels, height - 1, width - 1]`` for the latent
            inputs, otherwise whatever the parent ``generate`` returns.
        """
        if input_name in ("sample", "latent_sample"):
            # Height and width are reduced by one on purpose: per the original
            # author's note this is what enables generation at any resolution.
            latent_shape = [self.batch_size, self.num_channels, self.height - 1, self.width - 1]
            return self.random_float_tensor(shape=latent_shape, framework=framework, dtype=float_dtype)

        return super().generate(input_name, framework, int_dtype, float_dtype)


@register_in_tasks_manager("unet", "semantic-segmentation", library_name="diffusers")
class UnetOpenVINOConfig(UNetOnnxConfig):
    """UNet export config registered under the ``unet`` model type for ``diffusers``.

    Identical to ``UNetOnnxConfig`` except that the first dummy input generator
    is replaced by ``DummyUnetVisionInputGenerator``.
    """

    # Keep every parent generator after the first and put the UNet-specific
    # vision generator in front of them.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyUnetVisionInputGenerator,
        *UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:],
    )


@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
Expand Down
37 changes: 37 additions & 0 deletions tests/openvino/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test on inputs whose resolution is not divisible by 64
height, width, batch_size = 96, 96, 1

for output_type in ["latent", "np", "pt"]:
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

@parameterized.expand(CALLBACK_SUPPORT_ARCHITECTURES)
@require_diffusers
def test_callback(self, model_arch: str):
Expand Down Expand Up @@ -541,6 +552,20 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test generation when the input resolution is not divisible by 64
height, width, batch_size = 96, 96, 1

inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

for output_type in ["latent", "np", "pt"]:
print(output_type)
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
def test_image_reproducibility(self, model_arch: str):
Expand Down Expand Up @@ -777,6 +802,18 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test generation when the input resolution is not divisible by 64
height, width, batch_size = 96, 96, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

for output_type in ["latent", "np", "pt"]:
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
def test_image_reproducibility(self, model_arch: str):
Expand Down

0 comments on commit 420fa87

Please sign in to comment.