Skip to content

Commit

Permalink
added safety checking test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 4, 2024
1 parent 42244e0 commit 1b1c991
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,51 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
inputs = _generate_inputs(batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
def test_safety_checker(self, model_arch: str):
ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_pipeline.text_encoder, OVModelTextEncoder)
self.assertIsInstance(ov_pipeline.vae_encoder, OVModelVaeEncoder)
self.assertIsInstance(ov_pipeline.vae_decoder, OVModelVaeDecoder)
self.assertIsInstance(ov_pipeline.unet, OVModelUnet)
self.assertIsInstance(ov_pipeline.config, Dict)

from diffusers import LatentConsistencyModelPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipeline = LatentConsistencyModelPipeline.from_pretrained(
MODEL_NAMES[model_arch], safety_checker=safety_checker
)

batch_size, num_images_per_prompt, height, width = 2, 3, 64, 128
latents = ov_pipeline.prepare_latents(
batch_size * num_images_per_prompt,
ov_pipeline.unet.config["in_channels"],
height,
width,
dtype=np.float32,
generator=np.random.RandomState(0),
)

kwargs = {
"prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size,
"num_inference_steps": 1,
"num_images_per_prompt": num_images_per_prompt,
"height": height,
"width": width,
"guidance_scale": 8.5,
}

for output_type in ["latent", "np"]:
ov_outputs = ov_pipeline(latents=latents, output_type=output_type, **kwargs).images
self.assertIsInstance(ov_outputs, np.ndarray)
with torch.no_grad():
outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images

# Compare model outputs
self.assertTrue(np.allclose(ov_outputs, outputs, atol=1e-4))
# Compare model devices
self.assertEqual(pipeline.device.type, ov_pipeline.device)

0 comments on commit 1b1c991

Please sign in to comment.