fix timestep export shapes in sd3 and flux and tests with diffusers 0.32 (#1094)
eaidova authored Dec 26, 2024
1 parent 8a56275 commit 014a840
Showing 2 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions optimum/exporters/openvino/model_configs.py
@@ -1806,7 +1806,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int

@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
- class UnetOpenVINOConfig(UNetOnnxConfig):
+ class UNetOpenVINOConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyUnetVisionInputGenerator,
DummyUnetTimestepInputGenerator,
@@ -1821,10 +1821,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
- class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
+ class SD3TransformerOpenVINOConfig(UNetOpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
(DummyTransformerTimestpsInputGenerator,)
- + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ (PooledProjectionsDummyInputGenerator,)
)
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
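In effect, the SD3 transformer export config now derives from the OpenVINO-specific UNet config instead of the base UNetOnnxConfig, so it picks up the same timestep input handling. A minimal sketch of the resulting class relationship; the class names and bases come from the diff above, but the method bodies are illustrative stand-ins, not the actual optimum implementation:

# Sketch only: stand-in classes to illustrate the inheritance change above.
# Only the class names and bases come from the diff; the bodies are hypothetical.
from typing import Dict


class UNetOnnxConfig:  # stand-in for the base ONNX export config
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # hypothetical base declaration of the timestep input
        return {"timestep": {0: "steps"}}


class UNetOpenVINOConfig(UNetOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # OpenVINO-specific adjustment of the exported timestep shape (illustrative)
        common_inputs = dict(super().inputs)
        common_inputs["timestep"] = {0: "batch_size"}
        return common_inputs


class SD3TransformerOpenVINOConfig(UNetOpenVINOConfig):
    # After this commit, the SD3 config inherits the OpenVINO timestep
    # handling instead of going straight to UNetOnnxConfig.
    pass


print(SD3TransformerOpenVINOConfig().inputs)  # {'timestep': {0: 'batch_size'}}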
14 changes: 7 additions & 7 deletions tests/openvino/test_diffusion.py
@@ -218,8 +218,8 @@ def test_shape(self, model_arch: str):
),
)
else:
- packed_height = height // pipeline.vae_scale_factor
- packed_width = width // pipeline.vae_scale_factor
+ packed_height = height // pipeline.vae_scale_factor // 2
+ packed_width = width // pipeline.vae_scale_factor // 2
channels = pipeline.transformer.config.in_channels
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
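The extra // 2 tracks a diffusers 0.32 behaviour change for these packed-latent pipelines: the pipeline's vae_scale_factor no longer folds in the 2x2 latent packing, so the test divides by the packing factor itself. A quick sanity check with the 128x128 resolution these tests generate and an assumed vae_scale_factor of 8:

# Sanity check of the packed latent sequence length asserted above.
# Assumes diffusers 0.32 behaviour: vae_scale_factor == 8, plus an extra
# factor of 2 for the 2x2 latent packing done by the transformer pipelines.
height, width, vae_scale_factor = 128, 128, 8

packed_height = height // vae_scale_factor // 2   # 128 // 8 // 2 == 8
packed_width = width // vae_scale_factor // 2     # 8 as well

print(packed_height * packed_width)  # 64 packed latent tokens per image

Under earlier diffusers releases the packing factor was effectively baked into the scale factor, so a single division gave the same count, which is why the old assertion had no // 2.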

@@ -426,7 +426,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
)

if "flux" == model_type:
if model_type in ["flux", "stable-diffusion-3"]:
inputs["height"] = height
inputs["width"] = width

@@ -529,8 +529,8 @@ def test_shape(self, model_arch: str):
),
)
else:
- packed_height = height // pipeline.vae_scale_factor
- packed_width = width // pipeline.vae_scale_factor
+ packed_height = height // pipeline.vae_scale_factor // 2
+ packed_width = width // pipeline.vae_scale_factor // 2
channels = pipeline.transformer.config.in_channels
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))

@@ -780,8 +780,8 @@ def test_shape(self, model_arch: str):
),
)
else:
- packed_height = height // pipeline.vae_scale_factor
- packed_width = width // pipeline.vae_scale_factor
+ packed_height = height // pipeline.vae_scale_factor // 2
+ packed_width = width // pipeline.vae_scale_factor // 2
channels = pipeline.transformer.config.in_channels
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))

