Skip to content

Commit 99e2cff

Browse files
Feature/zimage inpaint pipeline (#13006)
* Add ZImageInpaintPipeline Updated the pipeline structure to include ZImageInpaintPipeline alongside ZImagePipeline and ZImageImg2ImgPipeline. Implemented the ZImageInpaintPipeline class for inpainting tasks, including necessary methods for encoding prompts, preparing masked latents, and denoising. Enhanced the auto_pipeline to map the new ZImageInpaintPipeline for inpainting generation tasks. Added unit tests for ZImageInpaintPipeline to ensure functionality and performance. Updated dummy objects to include ZImageInpaintPipeline for testing purposes. * Add documentation and improve test stability for ZImageInpaintPipeline - Add torch.empty fix for x_pad_token and cap_pad_token in test - Add # Copied from annotations for encode_prompt methods - Add documentation with usage example and autodoc directive * Address PR review feedback for ZImageInpaintPipeline Add batch size validation and callback handling fixes per review, using diffusers conventions rather than suggested code verbatim. * Update src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com> * Update src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com> * Add input validation and fix XLA support for ZImageInpaintPipeline - Add missing is_torch_xla_available import for TPU support - Add xm.mark_step() in denoising loop for proper XLA execution - Add check_inputs() method for comprehensive input validation - Call check_inputs() at the start of __call__ Addresses PR review feedback from @asomoza. * Cleanup --------- Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
1 parent a3dcd98 commit 99e2cff

File tree

8 files changed

+1395
-3
lines changed

8 files changed

+1395
-3
lines changed

docs/source/en/api/pipelines/z_image.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,41 @@ image = pipe(
5353
image.save("zimage_img2img.png")
5454
```
5555

56+
## Inpainting
57+
58+
Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.
59+
60+
```python
61+
import torch
62+
import numpy as np
63+
from PIL import Image
64+
from diffusers import ZImageInpaintPipeline
65+
from diffusers.utils import load_image
66+
67+
pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
68+
pipe.to("cuda")
69+
70+
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
71+
init_image = load_image(url).resize((1024, 1024))
72+
73+
# Create a mask (white = inpaint, black = preserve)
74+
mask = np.zeros((1024, 1024), dtype=np.uint8)
75+
mask[256:768, 256:768] = 255 # Inpaint center region
76+
mask_image = Image.fromarray(mask)
77+
78+
prompt = "A beautiful lake with mountains in the background"
79+
image = pipe(
80+
prompt,
81+
image=init_image,
82+
mask_image=mask_image,
83+
strength=1.0,
84+
num_inference_steps=9,
85+
guidance_scale=0.0,
86+
generator=torch.Generator("cuda").manual_seed(42),
87+
).images[0]
88+
image.save("zimage_inpaint.png")
89+
```
90+
5691
## ZImagePipeline
5792

5893
[[autodoc]] ZImagePipeline
@@ -64,3 +99,9 @@ image.save("zimage_img2img.png")
6499
[[autodoc]] ZImageImg2ImgPipeline
65100
- all
66101
- __call__
102+
103+
## ZImageInpaintPipeline
104+
105+
[[autodoc]] ZImageInpaintPipeline
106+
- all
107+
- __call__

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@
696696
"ZImageControlNetInpaintPipeline",
697697
"ZImageControlNetPipeline",
698698
"ZImageImg2ImgPipeline",
699+
"ZImageInpaintPipeline",
699700
"ZImageOmniPipeline",
700701
"ZImagePipeline",
701702
]
@@ -1428,6 +1429,7 @@
14281429
ZImageControlNetInpaintPipeline,
14291430
ZImageControlNetPipeline,
14301431
ZImageImg2ImgPipeline,
1432+
ZImageInpaintPipeline,
14311433
ZImageOmniPipeline,
14321434
ZImagePipeline,
14331435
)

src/diffusers/pipelines/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,12 @@
410410
"Kandinsky5I2IPipeline",
411411
]
412412
_import_structure["z_image"] = [
413-
"ZImageImg2ImgPipeline",
414-
"ZImagePipeline",
415-
"ZImageControlNetPipeline",
416413
"ZImageControlNetInpaintPipeline",
414+
"ZImageControlNetPipeline",
415+
"ZImageImg2ImgPipeline",
416+
"ZImageInpaintPipeline",
417417
"ZImageOmniPipeline",
418+
"ZImagePipeline",
418419
]
419420
_import_structure["skyreels_v2"] = [
420421
"SkyReelsV2DiffusionForcingPipeline",
@@ -870,6 +871,7 @@
870871
ZImageControlNetInpaintPipeline,
871872
ZImageControlNetPipeline,
872873
ZImageImg2ImgPipeline,
874+
ZImageInpaintPipeline,
873875
ZImageOmniPipeline,
874876
ZImagePipeline,
875877
)

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
ZImageControlNetInpaintPipeline,
128128
ZImageControlNetPipeline,
129129
ZImageImg2ImgPipeline,
130+
ZImageInpaintPipeline,
130131
ZImageOmniPipeline,
131132
ZImagePipeline,
132133
)
@@ -235,6 +236,7 @@
235236
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
236237
("qwenimage", QwenImageInpaintPipeline),
237238
("qwenimage-edit", QwenImageEditInpaintPipeline),
239+
("z-image", ZImageInpaintPipeline),
238240
]
239241
)
240242

src/diffusers/pipelines/z_image/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
2727
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
2828
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
29+
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
2930
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
3031

3132

@@ -42,6 +43,7 @@
4243
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
4344
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
4445
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
46+
from .pipeline_z_image_inpaint import ZImageInpaintPipeline
4547
from .pipeline_z_image_omni import ZImageOmniPipeline
4648
else:
4749
import sys

0 commit comments

Comments
 (0)