Skip to content

Commit

Permalink
transpose to diff format
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 4, 2024
1 parent 50d867c commit 42244e0
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
tokenizer: Optional["CLIPTokenizer"] = None,
tokenizer_2: Optional["CLIPTokenizer"] = None,
feature_extractor: Optional["CLIPFeatureExtractor"] = None,
safety_checker: Optional["StableDiffusionSafetyChecker"] = None,
device: str = "CPU",
dynamic_shapes: bool = True,
compile: bool = True,
Expand Down Expand Up @@ -135,7 +137,7 @@ def __init__(
self.tokenizer_2 = tokenizer_2
self.scheduler = scheduler
self.feature_extractor = feature_extractor
self.safety_checker = None
self.safety_checker = safety_checker
self.preprocessors = []

if self.is_dynamic:
Expand Down Expand Up @@ -1082,6 +1084,22 @@ def __call__(
**kwargs,
)

def run_safety_checker(self, image: np.ndarray):
if self.safety_checker is None:
has_nsfw_concept = None
else:
# Transpose the image to NHWC
image = image.transpose(0, 2, 3, 1)

feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt")
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

# Transpose the image back to NCHW
image = image.transpose(0, 3, 1, 2)

return image, has_nsfw_concept


def _raise_invalid_batch_size(
expected_batch_size: int, batch_size: int, num_images_per_prompt: int, guidance_scale: float
Expand Down

0 comments on commit 42244e0

Please sign in to comment.