diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b022f1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +IP-Adapter/ +models/ +sdxl_models/ +.gradio/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/README.md b/README.md index 3b31fa1..65a7f5d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,70 @@ -# marble -This is the official code base for MARBLE +# MARBLE + + + + + +This is the official implementation of MARBLE: Material Recomposition and Blending in CLIP-Space. Given an input image, MARBLE enables edits from material transfer, blending, to parametric control. + +![arch](static/teaser.jpg) + + +## Installation +MARBLE material blocks are built from the code base of [InstantStyle](https://github.com/instantX-research/InstantStyle). Additional functions are added into the `ip_adapter_instantstyle/ip_adapter.py` but please cite their papers accordingly. + +The current code base is tested on Python 3.9.7. + +We will begin by cloning this repo: + +``` +git clone https://github.com/Stability-AI/marble.git +``` + +Then, install the latest libraries with: + +``` +cd marble +pip install -r requirements.txt +``` + +## Usage + +After installation and downloading the models, you can use the two demos `try_blend.ipynb` for material blending and `parametric_control.ipynb` for material transfer + multi-attribute parametric control. + +We also provide a gradio demo which can be run with `python gradio_demo.py`. + + +### Using your own materials +For material transfer, you could add your images to `input_images/texture/`. + + +### ComfyUI extension + +Custom nodes and an [example workflow](./example_workflow.json) are provided for [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
+ +To install: + +* Clone this repo into ```custom_nodes```: + ```shell + $ cd ComfyUI/custom_nodes + $ git clone https://github.com/Stability-AI/marble + ``` +* Install dependencies: + ```shell + $ cd marble + $ pip install -r requirements.txt + ``` +* Restart ComfyUI + + +## Citation +If you find MARBLE helpful in your research/applications, please cite using this BibTeX: + +```bibtex +@article{cheng2024marble, + title={MARBLE: Material Recomposition and Blending in CLIP-Space}, + author={Cheng, Ta-Ying and Sharma, Prafull and Boss, Mark and Jampani, Varun}, + journal={CVPR}, + year={2025} +} +``` diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..1168fd9 --- /dev/null +++ b/__init__.py @@ -0,0 +1,247 @@ +import sys +import os + +sys.path.append(os.path.dirname(__file__)) + +import comfy.model_management +import torch +from PIL import Image +import numpy as np +from .marble import ( + setup_control_mlps, + setup_pipeline, + run_blend, + run_parametric_control, +) + + +# Add conversion functions +def tensor_to_pil(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.squeeze(0) + # Convert to numpy and scale to 0-255 + image = (tensor.cpu().numpy() * 255).astype(np.uint8) + return Image.fromarray(image) + return tensor + + +def pil_to_tensor(pil_image): + if isinstance(pil_image, Image.Image): + # Convert PIL to numpy array + image = np.array(pil_image) + # Convert to tensor and normalize to 0-1 + tensor = torch.from_numpy(image).float() / 255.0 + tensor = tensor.unsqueeze(0) + device = comfy.model_management.get_torch_device() + tensor = tensor.to(device) + return tensor + return pil_image + + +MARBLE_CATEGORY = "marble" + + +class MarbleControlMLPLoader: + CATEGORY = MARBLE_CATEGORY + FUNCTION = "load" + RETURN_NAMES = ["control_mlp"] + RETURN_TYPES = ["CONTROL_MLP"] + + @classmethod + def INPUT_TYPES(cls): + return {} + + def load(self): + device = comfy.model_management.get_torch_device() + mlps = 
setup_control_mlps(device=device) + return (mlps,) + + +class MarbleIPAdapterLoader: + CATEGORY = MARBLE_CATEGORY + FUNCTION = "load" + RETURN_NAMES = ["ip_adapter"] + RETURN_TYPES = ["IP_ADAPTER"] + + @classmethod + def INPUT_TYPES(cls): + return {} + + def load(self): + device = comfy.model_management.get_torch_device() + ip_adapter = setup_pipeline(device=device) + return (ip_adapter,) + + +class MarbleBlendNode: + CATEGORY = MARBLE_CATEGORY + FUNCTION = "blend" + RETURN_NAMES = ["image"] + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ip_adapter": ("IP_ADAPTER",), + "image": ("IMAGE",), + "texture_image1": ("IMAGE",), + "texture_image2": ("IMAGE",), + "edit_strength": ( + "FLOAT", + {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}, + ), + "num_inference_steps": ( + "INT", + {"default": 20, "min": 1, "max": 100, "step": 1}, + ), + "seed": ( + "INT", + {"default": 42, "min": 0, "max": 2147483647, "step": 1}, + ), + }, + "optional": { + "mask": ("MASK", {"default": None}), + "depth_map": ("IMAGE", {"default": None}), + }, + } + + def blend( + self, + ip_adapter, + image, + texture_image1, + texture_image2, + edit_strength, + num_inference_steps, + seed, + mask=None, + depth_map=None, + ): + # Convert all inputs to PIL + pil_image = tensor_to_pil(image) + pil_texture1 = tensor_to_pil(texture_image1) + pil_texture2 = tensor_to_pil(texture_image2) + pil_depth_map = tensor_to_pil(depth_map) if depth_map is not None else None + + result = run_blend( + ip_adapter, + pil_image, + pil_texture1, + pil_texture2, + edit_strength=edit_strength, + num_inference_steps=num_inference_steps, + seed=seed, + depth_map=pil_depth_map, + mask=mask, + ) + # Convert result back to tensor + return (pil_to_tensor(result),) + + +class MarbleParametricControl: + CATEGORY = MARBLE_CATEGORY + FUNCTION = "parametric_control" + RETURN_NAMES = ["image"] + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(s): + return { + "required": 
{ + "ip_adapter": ("IP_ADAPTER",), + "image": ("IMAGE",), + "control_mlps": ("CONTROL_MLP",), + "num_inference_steps": ( + "INT", + {"default": 30, "min": 1, "max": 100, "step": 1}, + ), + "seed": ( + "INT", + {"default": 42, "min": 0, "max": 2147483647, "step": 1}, + ), + }, + "optional": { + "mask": ("MASK", {"default": None}), + "texture_image": ("IMAGE", {"default": None}), + "depth_map": ("IMAGE", {"default": None}), + "metallic_strength": ( + "FLOAT", + {"default": 0.0, "min": -20.0, "max": 20.0, "step": 0.1}, + ), + "roughness_strength": ( + "FLOAT", + {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05}, + ), + "transparency_strength": ( + "FLOAT", + {"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.1}, + ), + "glow_strength": ( + "FLOAT", + {"default": 0.0, "min": -1.0, "max": 3.0, "step": 0.1}, + ), + }, + } + + def parametric_control( + self, + ip_adapter, + image, + control_mlps, + num_inference_steps, + seed, + mask=None, + texture_image=None, + depth_map=None, + metallic_strength=0.0, + roughness_strength=0.0, + transparency_strength=0.0, + glow_strength=0.0, + ): + # Convert inputs to PIL + pil_image = tensor_to_pil(image) + pil_texture = ( + tensor_to_pil(texture_image) if texture_image is not None else None + ) + pil_depth_map = tensor_to_pil(depth_map) if depth_map is not None else None + + edit_mlps = {} + for mlp_name, strength in [ + ("metallic", metallic_strength), + ("roughness", roughness_strength), + ("transparency", transparency_strength), + ("glow", glow_strength), + ]: + if mlp_name in control_mlps and strength != 0.0: + edit_mlps[control_mlps[mlp_name]] = strength + + result = run_parametric_control( + ip_adapter, + pil_image, + edit_mlps, + texture_image=pil_texture, + num_inference_steps=num_inference_steps, + seed=seed, + depth_map=pil_depth_map, + mask=mask, + ) + # Convert result back to tensor + return (pil_to_tensor(result),) + + +NODE_CLASS_MAPPINGS = { + "MarbleControlMLPLoader": MarbleControlMLPLoader, + 
"MarbleIPAdapterLoader": MarbleIPAdapterLoader, + "MarbleBlendNode": MarbleBlendNode, + "MarbleParametricControl": MarbleParametricControl, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "MarbleControlMLPLoader": "Marble Control MLP Loader", + "MarbleIPAdapterLoader": "Marble IP Adapter Loader", + "MarbleBlendNode": "Marble Blend Node", + "MarbleParametricControl": "Marble Parametric Control", +} + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/example_workflow.json b/example_workflow.json new file mode 100644 index 0000000..671235f --- /dev/null +++ b/example_workflow.json @@ -0,0 +1,475 @@ +{ + "id": "dac030d6-0516-49af-a557-37e0cabcfecc", + "revision": 0, + "last_node_id": 26, + "last_link_id": 31, + "nodes": [ + { + "id": 13, + "type": "LoadImage", + "pos": [ + 26.386796951293945, + 336.0611572265625 + ], + "size": [ + 315, + 314.0000305175781 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "slot_index": 0, + "links": [ + 24, + 30 + ] + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "beetle.png", + "image" + ] + }, + { + "id": 11, + "type": "MarbleIPAdapterLoader", + "pos": [ + 53.263065338134766, + -7.059659957885742 + ], + "size": [ + 302.4000244140625, + 26 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "ip_adapter", + "type": "IP_ADAPTER", + "slot_index": 0, + "links": [ + 25, + 29 + ] + } + ], + "properties": { + "Node name for S&R": "MarbleIPAdapterLoader" + }, + "widgets_values": [] + }, + { + "id": 24, + "type": "LoadImage", + "pos": [ + 30.925029754638672, + 717.4205322265625 + ], + "size": [ + 274.080078125, + 314 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 26 + ] + }, + { + "name": "MASK", + "type": "MASK", + "links": 
null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "low_roughness.png", + "image" + ] + }, + { + "id": 25, + "type": "LoadImage", + "pos": [ + 33.45697021484375, + 1090.3607177734375 + ], + "size": [ + 274.080078125, + 314 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 27 + ] + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "high_roughness.png", + "image" + ] + }, + { + "id": 23, + "type": "PreviewImage", + "pos": [ + 1525.1435546875, + 782.8226928710938 + ], + "size": [ + 646.029541015625, + 524.1215209960938 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 23 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 10, + "type": "MarbleControlMLPLoader", + "pos": [ + 43.088321685791016, + 110.31293487548828 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "control_mlp", + "type": "CONTROL_MLP", + "slot_index": 0, + "links": [ + 28 + ] + } + ], + "properties": { + "Node name for S&R": "MarbleControlMLPLoader" + }, + "widgets_values": [] + }, + { + "id": 26, + "type": "MarbleParametricControl", + "pos": [ + 890.8665771484375, + -15.984257698059082 + ], + "size": [ + 283.40234375, + 302 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "ip_adapter", + "type": "IP_ADAPTER", + "link": 29 + }, + { + "name": "image", + "type": "IMAGE", + "link": 30 + }, + { + "name": "control_mlps", + "type": "CONTROL_MLP", + "link": 28 + }, + { + "name": "mask", + "shape": 7, + "type": "MASK", + "link": null + }, + { + "name": "texture_image", + "shape": 7, + "type": "IMAGE", + "link": null + }, + { + "name": "depth_map", + 
"shape": 7, + "type": "IMAGE", + "link": null + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 31 + ] + } + ], + "properties": { + "Node name for S&R": "MarbleParametricControl" + }, + "widgets_values": [ + 30, + 1678732183, + "randomize", + 0, + -0.39000000000000007, + 0, + 0 + ] + }, + { + "id": 22, + "type": "MarbleBlendNode", + "pos": [ + 902.6680908203125, + 776.9732055664062 + ], + "size": [ + 278.73828125, + 230 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "ip_adapter", + "type": "IP_ADAPTER", + "link": 25 + }, + { + "name": "image", + "type": "IMAGE", + "link": 24 + }, + { + "name": "texture_image1", + "type": "IMAGE", + "link": 26 + }, + { + "name": "texture_image2", + "type": "IMAGE", + "link": 27 + }, + { + "name": "mask", + "shape": 7, + "type": "MASK", + "link": null + }, + { + "name": "depth_map", + "shape": 7, + "type": "IMAGE", + "link": null + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 23 + ] + } + ], + "properties": { + "Node name for S&R": "MarbleBlendNode" + }, + "widgets_values": [ + 1, + 20, + 1012675178, + "randomize" + ] + }, + { + "id": 18, + "type": "PreviewImage", + "pos": [ + 1529.1248779296875, + -12.8727445602417 + ], + "size": [ + 634.3290405273438, + 617.0318603515625 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 31 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + } + ], + "links": [ + [ + 23, + 22, + 0, + 23, + 0, + "IMAGE" + ], + [ + 24, + 13, + 0, + 22, + 1, + "IMAGE" + ], + [ + 25, + 11, + 0, + 22, + 0, + "IP_ADAPTER" + ], + [ + 26, + 24, + 0, + 22, + 2, + "IMAGE" + ], + [ + 27, + 25, + 0, + 22, + 3, + "IMAGE" + ], + [ + 28, + 10, + 0, + 26, + 2, + "CONTROL_MLP" + ], + [ + 29, + 11, + 0, + 26, + 0, + "IP_ADAPTER" + ], + [ + 30, + 13, + 0, + 26, + 1, + "IMAGE" + ], + [ + 31, + 26, + 0, + 18, + 0, + "IMAGE" 
+ ] + ], + "groups": [ + { + "id": 1, + "title": "Loaders", + "bounding": [ + -3.8137731552124023, + -104.17147064208984, + 413.9270935058594, + 313.8103332519531 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], + "config": {}, + "extra": { + "ds": { + "scale": 0.7513148009015782, + "offset": [ + 211.50521766916663, + 238.1879270976759 + ] + }, + "frontendVersion": "1.18.9" + }, + "version": 0.4 +} \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py new file mode 100644 index 0000000..5a26910 --- /dev/null +++ b/gradio_demo.py @@ -0,0 +1,272 @@ +import gradio as gr +from PIL import Image + +from marble import ( + get_session, + run_blend, + run_parametric_control, + setup_control_mlps, + setup_pipeline, +) + +# Setup the pipeline and control MLPs +control_mlps = setup_control_mlps() +ip_adapter = setup_pipeline() +get_session() + +# Load example images +EXAMPLE_IMAGES = { + "blend": { + "target": "input_images/context_image/beetle.png", + "texture1": "input_images/texture/low_roughness.png", + "texture2": "input_images/texture/high_roughness.png", + }, + "parametric": { + "target": "input_images/context_image/toy_car.png", + "texture": "input_images/texture/metal_bowl.png", + }, +} + + +def blend_images(target_image, texture1, texture2, edit_strength): + """Blend between two texture images""" + result = run_blend( + ip_adapter, target_image, texture1, texture2, edit_strength=edit_strength + ) + return result + + +def parametric_control( + target_image, + texture_image, + control_type, + metallic_strength, + roughness_strength, + transparency_strength, + glow_strength, +): + """Apply parametric control based on selected control type""" + edit_mlps = {} + + if control_type == "Roughness + Metallic": + edit_mlps = { + control_mlps["metallic"]: metallic_strength, + control_mlps["roughness"]: roughness_strength, + } + elif control_type == "Transparency": + edit_mlps = { + control_mlps["transparency"]: transparency_strength, + } + 
elif control_type == "Glow": + edit_mlps = { + control_mlps["glow"]: glow_strength, + } + + # Use target image as texture if no texture is provided + texture_to_use = texture_image if texture_image is not None else target_image + + result = run_parametric_control( + ip_adapter, + target_image, + edit_mlps, + texture_to_use, + ) + return result + + +# Create the Gradio interface +with gr.Blocks( + title="MARBLE: Material Recomposition and Blending in CLIP-Space" +) as demo: + gr.Markdown( + """ + # MARBLE: Material Recomposition and Blending in CLIP-Space + +
+ +
+ + MARBLE is a tool for material recomposition and blending in CLIP-Space. + We provide two modes of operation: + - **Texture Blending**: Blend the material properties of two texture images and apply it to a target image. + - **Parametric Control**: Apply parametric material control to a target image. You can either provide a texture image, transferring the material properties of the texture to the original image, or you can just provide a target image, and edit the material properties of the original image. + """ + ) + + with gr.Row(variant="panel"): + with gr.Tabs(): + with gr.TabItem("Texture Blending"): + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + texture1 = gr.Image(label="Texture 1", type="pil") + texture2 = gr.Image(label="Texture 2", type="pil") + edit_strength = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.5, + step=0.1, + label="Blend Strength", + ) + with gr.Column(): + with gr.Row(): + target_image = gr.Image(label="Target Image", type="pil") + blend_output = gr.Image(label="Blended Result") + blend_btn = gr.Button("Blend Textures") + + # Add examples for blending + gr.Examples( + examples=[ + [ + Image.open(EXAMPLE_IMAGES["blend"]["target"]), + Image.open(EXAMPLE_IMAGES["blend"]["texture1"]), + Image.open(EXAMPLE_IMAGES["blend"]["texture2"]), + 0.5, + ] + ], + inputs=[target_image, texture1, texture2, edit_strength], + outputs=blend_output, + fn=blend_images, + cache_examples=True, + ) + + blend_btn.click( + fn=blend_images, + inputs=[target_image, texture1, texture2, edit_strength], + outputs=blend_output, + ) + + with gr.TabItem("Parametric Control"): + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + target_image_pc = gr.Image(label="Target Image", type="pil") + texture_image_pc = gr.Image( + label="Texture Image (Optional - uses target image if not provided)", + type="pil", + ) + control_type = gr.Dropdown( + choices=["Roughness + Metallic", "Transparency", "Glow"], + value="Roughness + 
Metallic", + label="Control Type", + ) + + metallic_strength = gr.Slider( + minimum=-20, + maximum=20, + value=0, + step=0.1, + label="Metallic Strength", + visible=True, + ) + roughness_strength = gr.Slider( + minimum=-1, + maximum=1, + value=0, + step=0.1, + label="Roughness Strength", + visible=True, + ) + transparency_strength = gr.Slider( + minimum=0, + maximum=4, + value=0, + step=0.1, + label="Transparency Strength", + visible=False, + ) + glow_strength = gr.Slider( + minimum=0, + maximum=3, + value=0, + step=0.1, + label="Glow Strength", + visible=False, + ) + control_btn = gr.Button("Apply Control") + + with gr.Column(): + control_output = gr.Image(label="Result") + + def update_slider_visibility(control_type): + return [ + gr.update(visible=control_type == "Roughness + Metallic"), + gr.update(visible=control_type == "Roughness + Metallic"), + gr.update(visible=control_type == "Transparency"), + gr.update(visible=control_type == "Glow"), + ] + + control_type.change( + fn=update_slider_visibility, + inputs=[control_type], + outputs=[ + metallic_strength, + roughness_strength, + transparency_strength, + glow_strength, + ], + show_progress=False, + ) + + # Add examples for parametric control + gr.Examples( + examples=[ + [ + Image.open(EXAMPLE_IMAGES["parametric"]["target"]), + Image.open(EXAMPLE_IMAGES["parametric"]["texture"]), + "Roughness + Metallic", + 0, # metallic_strength + 0, # roughness_strength + 0, # transparency_strength + 0, # glow_strength + ], + [ + Image.open(EXAMPLE_IMAGES["parametric"]["target"]), + Image.open(EXAMPLE_IMAGES["parametric"]["texture"]), + "Roughness + Metallic", + 20, # metallic_strength + 0, # roughness_strength + 0, # transparency_strength + 0, # glow_strength + ], + [ + Image.open(EXAMPLE_IMAGES["parametric"]["target"]), + Image.open(EXAMPLE_IMAGES["parametric"]["texture"]), + "Roughness + Metallic", + 0, # metallic_strength + 1, # roughness_strength + 0, # transparency_strength + 0, # glow_strength + ], + ], + inputs=[ + 
target_image_pc, + texture_image_pc, + control_type, + metallic_strength, + roughness_strength, + transparency_strength, + glow_strength, + ], + outputs=control_output, + fn=parametric_control, + cache_examples=True, + ) + + control_btn.click( + fn=parametric_control, + inputs=[ + target_image_pc, + texture_image_pc, + control_type, + metallic_strength, + roughness_strength, + transparency_strength, + glow_strength, + ], + outputs=control_output, + ) + +if __name__ == "__main__": + demo.launch() diff --git a/input_images/context_image/beetle.png b/input_images/context_image/beetle.png new file mode 100644 index 0000000..eef4c87 Binary files /dev/null and b/input_images/context_image/beetle.png differ diff --git a/input_images/context_image/genart_teapot.jpg b/input_images/context_image/genart_teapot.jpg new file mode 100644 index 0000000..2e616ac Binary files /dev/null and b/input_images/context_image/genart_teapot.jpg differ diff --git a/input_images/context_image/toy_car.png b/input_images/context_image/toy_car.png new file mode 100644 index 0000000..b52e880 Binary files /dev/null and b/input_images/context_image/toy_car.png differ diff --git a/input_images/context_image/white_car_night.jpg b/input_images/context_image/white_car_night.jpg new file mode 100644 index 0000000..4eb0a71 Binary files /dev/null and b/input_images/context_image/white_car_night.jpg differ diff --git a/input_images/depth/beetle.png b/input_images/depth/beetle.png new file mode 100644 index 0000000..7f1a9d7 Binary files /dev/null and b/input_images/depth/beetle.png differ diff --git a/input_images/depth/toy_car.png b/input_images/depth/toy_car.png new file mode 100644 index 0000000..8c70a8d Binary files /dev/null and b/input_images/depth/toy_car.png differ diff --git a/input_images/texture/high_roughness.png b/input_images/texture/high_roughness.png new file mode 100644 index 0000000..4eb22be Binary files /dev/null and b/input_images/texture/high_roughness.png differ diff --git 
a/input_images/texture/low_roughness.png b/input_images/texture/low_roughness.png new file mode 100644 index 0000000..b3555a6 Binary files /dev/null and b/input_images/texture/low_roughness.png differ diff --git a/input_images/texture/metal_bowl.png b/input_images/texture/metal_bowl.png new file mode 100644 index 0000000..42fa36b Binary files /dev/null and b/input_images/texture/metal_bowl.png differ diff --git a/ip_adapter_instantstyle/__init__.py b/ip_adapter_instantstyle/__init__.py new file mode 100644 index 0000000..3b1f1ff --- /dev/null +++ b/ip_adapter_instantstyle/__init__.py @@ -0,0 +1,9 @@ +from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull + +__all__ = [ + "IPAdapter", + "IPAdapterPlus", + "IPAdapterPlusXL", + "IPAdapterXL", + "IPAdapterFull", +] diff --git a/ip_adapter_instantstyle/attention_processor.py b/ip_adapter_instantstyle/attention_processor.py new file mode 100644 index 0000000..6d30ffa --- /dev/null +++ b/ip_adapter_instantstyle/attention_processor.py @@ -0,0 +1,562 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. 
+ Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + self.skip = skip + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + 
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if not self.skip: + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale 
when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + self.skip = skip + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + 
encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if not self.skip: + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +## for controlnet +class CNAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, num_tokens=4): + self.num_tokens = num_tokens + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = 
attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CNAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, num_tokens=4): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.num_tokens = num_tokens + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = 
encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/ip_adapter_instantstyle/ip_adapter.py b/ip_adapter_instantstyle/ip_adapter.py new file mode 100644 index 0000000..dd3eaa5 --- /dev/null +++ b/ip_adapter_instantstyle/ip_adapter.py @@ -0,0 +1,858 @@ +import os +import glob +from typing import List + +import torch +import torch.nn as nn +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +L = 4 + + +def pos_encode(x, L): + pos_encode = [] 
+ + for freq in range(L): + pos_encode.append(torch.cos(2**freq * torch.pi * x)) + pos_encode.append(torch.sin(2**freq * torch.pi * x)) + pos_encode = torch.cat(pos_encode, dim=1) + return pos_encode + + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__( + self, + cross_attention_dim=1024, + clip_embeddings_dim=1024, + clip_extra_context_tokens=4, + ): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim), + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__( + self, + sd_pipe, + image_encoder_path, + 
ip_ckpt, + device, + num_tokens=4, + target_blocks=["block"], + ): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.target_blocks = target_blocks + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + self.image_encoder_path + ).to(self.device, dtype=torch.float16) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + selected = False + for block_name in self.target_blocks: + if block_name in name: + selected = True + break + if selected: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + 
cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + skip=True, + ).to(self.device, dtype=torch.float16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor( + CNAttnProcessor(num_tokens=self.num_tokens) + ) + else: + self.pipe.controlnet.set_attn_processor( + CNAttnProcessor(num_tokens=self.num_tokens) + ) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = ( + f.get_tensor(key) + ) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = ( + f.get_tensor(key) + ) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds( + self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None + ): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor( + images=pil_image, return_tensors="pt" + ).pixel_values + clip_image_embeds = self.image_encoder( + clip_image.to(self.device, dtype=torch.float16) + ).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) + + if content_prompt_embeds is not None: + print(clip_image_embeds.shape) + print(content_prompt_embeds.shape) + clip_image_embeds = clip_image_embeds - content_prompt_embeds + + image_prompt_embeds = 
self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model( + torch.zeros_like(clip_image_embeds) + ) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def generate_image_edit_dir( + self, + pil_image=None, + content_prompt_embeds=None, + edit_mlps: dict[torch.nn.Module, float] = None, + ): + print("Combining multiple MLPs!") + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor( + images=pil_image, return_tensors="pt" + ).pixel_values + clip_image_embeds = self.image_encoder( + clip_image.to(self.device, dtype=torch.float16) + ).image_embeds + pred_editing_dirs = [ + net( + clip_image_embeds, + torch.Tensor([strength]).to(self.device, dtype=torch.float16), + ) + for net, strength in edit_mlps.items() + ] + + clip_image_embeds = clip_image_embeds + sum(pred_editing_dirs) + + if content_prompt_embeds is not None: + clip_image_embeds = clip_image_embeds - content_prompt_embeds + + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model( + torch.zeros_like(clip_image_embeds) + ) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_edit_dir( + self, + start_image=None, + pil_image=None, + pil_image2=None, + content_prompt_embeds=None, + edit_strength=1.0, + ): + print("Blending Two Materials!") + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor( + images=pil_image, return_tensors="pt" + ).pixel_values + clip_image_embeds = self.image_encoder( + clip_image.to(self.device, dtype=torch.float16) + ).image_embeds + + if pil_image2 is not None: + if isinstance(pil_image2, Image.Image): + pil_image2 = [pil_image2] + clip_image2 = self.clip_image_processor( + images=pil_image2, return_tensors="pt" + ).pixel_values + clip_image_embeds2 = 
self.image_encoder( + clip_image2.to(self.device, dtype=torch.float16) + ).image_embeds + + if start_image is not None: + if isinstance(start_image, Image.Image): + start_image = [start_image] + clip_image_start = self.clip_image_processor( + images=start_image, return_tensors="pt" + ).pixel_values + clip_image_embeds_start = self.image_encoder( + clip_image_start.to(self.device, dtype=torch.float16) + ).image_embeds + + if content_prompt_embeds is not None: + clip_image_embeds = clip_image_embeds - content_prompt_embeds + clip_image_embeds2 = clip_image_embeds2 - content_prompt_embeds + + # clip_image_embeds += edit_strength * (clip_image_embeds2 - clip_image_embeds) + clip_image_embeds = clip_image_embeds_start + edit_strength * ( + clip_image_embeds2 - clip_image_embeds + ) + + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model( + torch.zeros_like(clip_image_embeds) + ) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + neg_content_emb=None, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = ( + "monochrome, lowres, bad anatomy, worst quality, low quality" + ) + + if not isinstance(prompt, 
List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, + clip_image_embeds=clip_image_embeds, + content_prompt_embeds=neg_content_emb, + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( + 1, num_samples, 1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1 + ) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + clip_strength=1.0, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = ( + "monochrome, lowres, 
bad anatomy, worst quality, low quality" + ) + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_emb is None: + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = neg_content_emb + else: + pooled_prompt_embeds_ = None + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image, content_prompt_embeds=pooled_prompt_embeds_ + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( + 1, num_samples, 1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + print("CLIP Strength is {}".format(clip_strength)) + image_prompt_embeds *= clip_strength + uncond_image_prompt_embeds *= clip_strength + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1 + ) + + self.generator = get_generator(seed, self.device) + + images = 
self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + def generate_parametric_edits( + self, + pil_image, + edit_mlps: dict[torch.nn.Module, float], + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = ( + "monochrome, lowres, bad anatomy, worst quality, low quality" + ) + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_emb is None: + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = neg_content_emb + else: + pooled_prompt_embeds_ = None + image_prompt_embeds, uncond_image_prompt_embeds = self.generate_image_edit_dir( + pil_image, content_prompt_embeds=pooled_prompt_embeds_, edit_mlps=edit_mlps + ) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + 
) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( + 1, num_samples, 1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1 + ) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + def generate_edit( + self, + start_image, + pil_image, + pil_image2, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + edit_strength=1.0, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = ( + "monochrome, lowres, bad anatomy, worst quality, low quality" + ) + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_emb is None: + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # 
torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = neg_content_emb + else: + pooled_prompt_embeds_ = None + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_edit_dir( + start_image, + pil_image, + pil_image2, + content_prompt_embeds=pooled_prompt_embeds_, + edit_strength=edit_strength, + ) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( + 1, num_samples, 1 + ) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( + bs_embed * num_samples, seq_len, -1 + ) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat( + [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1 + ) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + 
class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        """Create the Resampler that turns CLIP hidden states into UNet tokens."""
        resampler = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return (conditional, unconditional) image-prompt embeddings.

        Conditioning uses the penultimate hidden state of the CLIP vision
        tower; the unconditional branch encodes an all-zero image of the
        same shape. ``clip_image_embeds`` is accepted for interface
        compatibility but is recomputed from ``pil_image``.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        pixels = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values.to(self.device, dtype=torch.float16)

        cond_hidden = self.image_encoder(
            pixels, output_hidden_states=True
        ).hidden_states[-2]
        cond_embeds = self.image_proj_model(cond_hidden)

        uncond_hidden = self.image_encoder(
            torch.zeros_like(pixels), output_hidden_states=True
        ).hidden_states[-2]
        uncond_embeds = self.image_proj_model(uncond_hidden)
        return cond_embeds, uncond_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        """Create the plain-MLP projection used instead of a Resampler."""
        return MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        """Create the SDXL Resampler (wider trunk: dim=1280, 20 heads)."""
        resampler = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Return (conditional, unconditional) image-prompt embeddings."""
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        pixels = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values.to(self.device, dtype=torch.float16)

        cond_hidden = self.image_encoder(
            pixels, output_hidden_states=True
        ).hidden_states[-2]
        cond_embeds = self.image_proj_model(cond_hidden)

        uncond_hidden = self.image_encoder(
            torch.zeros_like(pixels), output_hidden_states=True
        ).hidden_states[-2]
        uncond_embeds = self.image_proj_model(uncond_hidden)
        return cond_embeds, uncond_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate ``num_samples`` images conditioned on ``pil_image`` and prompt."""
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        cond_embeds, uncond_embeds = self.get_image_embeds(pil_image)

        def _tile(embeds):
            # one copy of each embedding per requested sample
            b, s, _ = embeds.shape
            return embeds.repeat(1, num_samples, 1).view(b * num_samples, s, -1)

        cond_embeds = _tile(cond_embeds)
        uncond_embeds = _tile(uncond_embeds)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # IP-Adapter layout: image tokens are appended after text tokens.
            prompt_embeds = torch.cat([prompt_embeds, cond_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, uncond_embeds], dim=1
            )

        generator = get_generator(seed, self.device)

        return self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images


# ip_adapter_instantstyle/resampler.py
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back (no biases)."""
    hidden = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    )


def reshape_tensor(x, heads):
    """Reshape (bs, length, heads*dim_head) to (bs, heads, length, dim_head)."""
    bs, length, _ = x.shape
    # split the channel axis into heads
    x = x.view(bs, length, heads, -1)
    # move heads next to batch: (bs, heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keeps batch and heads as separate axes (no bs*heads flatten)
    x = x.reshape(bs, heads, length, -1)
    return x
class PerceiverAttention(nn.Module):
    """Cross-attention from learned latents to (image features + latents)."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """Attend latents over the concatenation of x and latents.

        Args:
            x (torch.Tensor): image features, shape (b, n1, D).
            latents (torch.Tensor): latent features, shape (b, n2, D).

        Returns:
            Updated latents, shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, n_latents, _ = latents.shape

        query = self.to_q(latents)
        key, value = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)

        query = reshape_tensor(query, self.heads)
        key = reshape_tensor(key, self.heads)
        value = reshape_tensor(value, self.heads)

        # Split the 1/sqrt(d) factor across q and k: more stable in fp16
        # than dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (query * scale) @ (key * scale).transpose(-2, -1)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = probs @ value

        out = out.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(out)


class Resampler(nn.Module):
    """Perceiver-style resampler mapping a variable-length token sequence
    to a fixed number of query tokens via repeated cross-attention."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # latents derived from mean-pooled sequence
    ):
        super().__init__()
        self.pos_emb = (
            nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
        )

        # learned queries, scaled down at init
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        """Project x into `num_queries` output tokens of size output_dim."""
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            pooled = masked_mean(
                x,
                dim=1,
                mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool),
            )
            latents = torch.cat(
                (self.to_latents_from_mean_pooled_seq(pooled), latents), dim=-2
            )

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))


def masked_mean(t, *, dim, mask=None):
    """Mean over `dim`, counting only positions where mask is True."""
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    masked_t = t.masked_fill(~mask.unsqueeze(-1), 0.0)
    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)


# ip_adapter_instantstyle/utils.py
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

# name -> latest cross-attention map captured by the forward hooks below
attn_maps = {}


def hook_fn(name):
    """Return a forward hook that stashes the processor's attn_map under `name`."""

    def forward_hook(module, input, output):
        # only processors that recorded an attention map participate
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook


def register_cross_attention_hook(unet):
    """Attach attn-map capture hooks to every attn2 (cross-attention) module."""
    for name, module in unet.named_modules():
        if name.split(".")[-1].startswith("attn2"):
            module.register_forward_hook(hook_fn(name))
    return unet
def upscale(attn_map, target_size):
    """Average an attention map over heads and upsample it to target_size.

    Args:
        attn_map: tensor shaped (heads, query_tokens, text_tokens); the
            query axis is assumed to be a flattened latent grid whose
            resolution is target_size downscaled by 8 * 2**i for some
            i in 0..4 — TODO confirm against the UNet attention levels.
        target_size: (height, width) of the output map.

    Returns:
        Tensor of shape (text_tokens, *target_size), softmaxed over tokens.
    """
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    # Find which downscale factor reproduces the stored token count.
    for i in range(0, 5):
        scale = 2**i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # BUGFIX: previous message was ungrammatical ("temp_size cannot is None").
    assert temp_size is not None, "could not infer latent resolution for target_size"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Normalize across tokens so each pixel's attention sums to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average all captured cross-attention maps, upscaled to image_size.

    Args:
        image_size: (height, width) target resolution.
        batch_size: chunks in the CFG batch (default 2: uncond + cond).
        instance_or_negative: True selects the first (negative/uncond)
            chunk, False the conditional one.
        detach: move maps to CPU before processing.

    Returns:
        Mean attention map over all hooked layers.
    """
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    return torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)


def attnmaps2images(net_attn_maps):
    """Render each per-token attention map as a grayscale PIL image.

    NOTE(review): a constant map makes max == min and divides by zero
    below — confirm inputs always have spread before relying on this.
    """
    images = []
    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        # min-max normalize into the 0-255 byte range
        normalized = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        images.append(Image.fromarray(normalized.astype(np.uint8)))
    return images


def is_torch2_available():
    """Return True when torch provides F.scaled_dot_product_attention (torch >= 2.0)."""
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    """Build a seeded torch.Generator (or list of them) on `device`.

    Returns None when seed is None so callers can pass the result
    straight to diffusers pipelines.
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(s) for s in seed]
    return torch.Generator(device).manual_seed(seed)
import os
from typing import Dict

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image, ImageChops, ImageEnhance
from rembg import new_session, remove
from transformers import DPTForDepthEstimation, DPTImageProcessor

from ip_adapter_instantstyle import IPAdapterXL
from ip_adapter_instantstyle.utils import register_cross_attention_hook
from parametric_control_mlp import control_mlp

file_dir = os.path.dirname(os.path.abspath(__file__))
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

# Cache for rembg sessions
_session_cache = None
CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]


def get_session():
    """Return a lazily-created, process-wide rembg session."""
    global _session_cache
    if _session_cache is None:
        _session_cache = new_session()
    return _session_cache


def setup_control_mlps(
    features: int = 1024, device: str = "cuda", dtype: torch.dtype = torch.float16
) -> Dict[str, torch.nn.Module]:
    """Load every parametric-control MLP listed in CONTROL_MLPS."""
    return {
        name: setup_control_mlp(name, features, device, dtype)
        for name in CONTROL_MLPS
    }


def setup_control_mlp(
    material_parameter: str,
    features: int = 1024,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Load the control MLP for one material parameter, in eval mode.

    Args:
        material_parameter: one of CONTROL_MLPS; selects the checkpoint
            model_weights/<name>.pt.
        features: input embedding width of the MLP.
        device / dtype: where and how to place the weights.

    Returns:
        The loaded control_mlp module.
    """
    net = control_mlp(features)
    # BUGFIX: map_location lets CUDA-saved checkpoints load on CPU-only
    # machines; the module is moved to `device` afterwards anyway.
    net.load_state_dict(
        torch.load(
            os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
            map_location="cpu",
        )
    )
    net.to(device, dtype=dtype)
    net.eval()
    return net


def download_ip_adapter():
    """Fetch IP-Adapter weights from the HF Hub unless already present locally."""
    repo_id = "h94/IP-Adapter"
    target_folders = ["models/", "sdxl_models/"]
    local_dir = file_dir

    # Check if folders exist and contain files
    folders_exist = all(
        os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
    )
    if folders_exist:
        # Check if any of the target folders are empty
        folders_empty = any(
            len(os.listdir(os.path.join(local_dir, folder))) == 0
            for folder in target_folders
        )
        if not folders_empty:
            print("IP-Adapter files already downloaded. Skipping download.")
            return

    # List all files in the repo, keep those under the desired folders.
    all_files = list_repo_files(repo_id)
    filtered_files = [
        f for f in all_files if any(f.startswith(folder) for folder in target_folders)
    ]

    for file_path in filtered_files:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print(f"Downloaded: {file_path} to {local_path}")
def setup_pipeline(
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Build the SDXL ControlNet-inpaint pipeline wrapped in IPAdapterXL.

    Downloads the IP-Adapter weights if needed and restricts style
    injection to a single UNet cross-attention block (InstantStyle).

    Args:
        device: device for all models.
        dtype: weight precision for ControlNet and the SDXL pipeline.

    Returns:
        An IPAdapterXL ready for generation.
    """
    download_ip_adapter()

    # InstantStyle: only this UNet block receives the image-prompt tokens.
    cur_block = ("up", 0, 1)
    block_name = "{}_blocks.{}.attentions.{}".format(*cur_block)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
    ).to(device)

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=dtype,
        add_watermarker=False,
    ).to(device)

    # Hook cross-attention so attention maps can be inspected later.
    pipe.unet = register_cross_attention_hook(pipe.unet)

    print("Testing block {}".format(block_name))

    return IPAdapterXL(
        pipe,
        os.path.join(file_dir, image_encoder_path),
        os.path.join(file_dir, ip_ckpt),
        device,
        target_blocks=[block_name],
    )
def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
    """Load the DPT-hybrid MiDaS depth estimator and its image processor."""
    processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
    model.to(device, dtype=dtype)
    model.eval()
    return model, processor


def run_dpt_depth(
    image: Image.Image, model, processor, device: str = "cuda"
) -> Image.Image:
    """Run DPT depth estimation on an image.

    Returns the depth as an 8-bit grayscale PIL image resized to 1024x1024.
    """
    inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)

    with torch.no_grad():
        depth = model(**inputs).predicted_depth

    # Min-max normalize to [0, 1]; epsilon guards a constant depth map.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-7)
    depth = depth.clip(0, 1) * 255

    byte_map = depth.squeeze().cpu().numpy().astype(np.uint8)
    return Image.fromarray(byte_map).resize((1024, 1024))


def prepare_mask(image: Image.Image) -> Image.Image:
    """Foreground mask via rembg: white where the subject is, black elsewhere."""
    cutout = remove(image, session=get_session())
    # Threshold: any non-zero pixel becomes 255.
    binary = (
        cutout.convert("RGB")
        .point(lambda v: 0 if v < 1 else 255)
        .convert("L")
        .convert("RGB")
    )
    return binary.resize((1024, 1024))


def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Prepare the inpainting init image: grayscale subject, original background."""
    # Grayscale rendition of the whole picture; the brightness factor of
    # 1.0 is an identity transform kept as a tuning knob.
    desaturated = ImageEnhance.Brightness(image.convert("L").convert("RGB")).enhance(1.0)

    background_mask = ImageChops.invert(mask)

    # Keep grayscale only inside the mask, original pixels outside it.
    masked_gray = ImageChops.darker(desaturated, mask)
    masked_background = ImageChops.darker(image, background_mask)
    combined = ImageChops.lighter(masked_background, masked_gray)

    return combined.resize((1024, 1024))
def _resolve_depth_and_mask(target_image, depth_map, mask):
    """Compute (or resize) the depth map and foreground mask for an edit."""
    if depth_map is None:
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    if mask is None:
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))

    return depth_map, mask


def run_parametric_control(
    ip_model,
    target_image: Image.Image,
    edit_mlps: dict[torch.nn.Module, float],
    texture_image: Image.Image = None,
    num_inference_steps: int = 30,
    seed: int = 42,
    depth_map: Image.Image = None,
    mask: Image.Image = None,
) -> Image.Image:
    """Run parametric control with metallic and roughness adjustments.

    Args:
        ip_model: the IPAdapterXL wrapper built by setup_pipeline.
        target_image: image whose material is edited.
        edit_mlps: mapping of control MLP -> edit strength.
        texture_image: material reference; defaults to the target itself.
        num_inference_steps / seed: diffusion sampling parameters.
        depth_map / mask: optional precomputed conditioning inputs.

    Returns:
        The single edited PIL image.
    """
    depth_map, mask = _resolve_depth_and_mask(target_image, depth_map, mask)

    if texture_image is None:
        texture_image = target_image

    init_img = prepare_init_image(target_image, mask)

    outputs = ip_model.generate_parametric_edits(
        texture_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_mlps=edit_mlps,
        strength=1.0,
    )
    return outputs[0]


def run_blend(
    ip_model,
    target_image: Image.Image,
    texture_image1: Image.Image,
    texture_image2: Image.Image,
    edit_strength: float = 0.0,
    num_inference_steps: int = 20,
    seed: int = 1,
    depth_map: Image.Image = None,
    mask: Image.Image = None,
) -> Image.Image:
    """Run blending between two texture images.

    Args:
        ip_model: the IPAdapterXL wrapper built by setup_pipeline.
        target_image: image whose material is edited.
        texture_image1 / texture_image2: the two blend endpoints.
        edit_strength: interpolation weight between the endpoints.
        num_inference_steps / seed: diffusion sampling parameters.
        depth_map / mask: optional precomputed conditioning inputs.

    Returns:
        The single blended PIL image.
    """
    depth_map, mask = _resolve_depth_and_mask(target_image, depth_map, mask)
    init_img = prepare_init_image(target_image, mask)

    outputs = ip_model.generate_edit(
        start_image=texture_image1,
        pil_image=texture_image1,
        pil_image2=texture_image2,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_strength=edit_strength,
        clip_strength=1.0,
        strength=1.0,
    )
    return outputs[0]
start_image=texture_image1, + pil_image=texture_image1, + pil_image2=texture_image2, + image=init_img, + control_image=depth_map, + mask_image=mask, + controlnet_conditioning_scale=1.0, + num_samples=1, + num_inference_steps=num_inference_steps, + seed=seed, + edit_strength=edit_strength, + clip_strength=1.0, + strength=1.0, + ) + + return images[0] diff --git a/model_weights/glow.pt b/model_weights/glow.pt new file mode 100644 index 0000000..3b7fc51 Binary files /dev/null and b/model_weights/glow.pt differ diff --git a/model_weights/metallic.pt b/model_weights/metallic.pt new file mode 100644 index 0000000..e5cb99f Binary files /dev/null and b/model_weights/metallic.pt differ diff --git a/model_weights/roughness.pt b/model_weights/roughness.pt new file mode 100644 index 0000000..9041586 Binary files /dev/null and b/model_weights/roughness.pt differ diff --git a/model_weights/transparency.pt b/model_weights/transparency.pt new file mode 100644 index 0000000..dea3c3b Binary files /dev/null and b/model_weights/transparency.pt differ diff --git a/parametric_control.ipynb b/parametric_control.ipynb new file mode 100644 index 0000000..25a3f81 --- /dev/null +++ b/parametric_control.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3b1976c7-5491-4124-94c1-dd99b5fcd016", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel\n", + "from rembg import remove\n", + "from PIL import Image, ImageFilter\n", + "import torch\n", + "import torch.nn as nn\n", + "from ip_adapter_instantstyle import IPAdapterXL\n", + "from ip_adapter_instantstyle.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images\n", + "from PIL import Image, ImageChops\n", + "import numpy as np\n", + "import glob\n", + "import os\n", + "\n", + "\n", + "\"\"\"Import DPT for Depth Model\"\"\"\n", + "import DPT.util.io\n", + "\n", + "from torchvision.transforms 
import Compose\n", + "\n", + "from DPT.dpt.models import DPTDepthModel\n", + "from DPT.dpt.midas_net import MidasNet_large\n", + "from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet\n", + "\n", + "from parametric_control_mlp import control_mlp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2306a004-ac0e-4714-9868-ae77b699fef1", + "metadata": {}, + "outputs": [], + "source": [ + "# Metallic MLP\n", + "mlp = control_mlp(1024)\n", + "mlp.load_state_dict(torch.load('model_weights/metallic.pt'))\n", + "mlp = mlp.to(\"cuda\", dtype=torch.float16)\n", + "mlp.eval()\n", + "\n", + "# Roughness MLP\n", + "mlp2 = control_mlp(1024)\n", + "mlp2.load_state_dict(torch.load('model_weights/roughness.pt'))\n", + "mlp2 = mlp2.to(\"cuda\", dtype=torch.float16)\n", + "mlp2.eval()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f14b02bd-42af-4139-8c8e-6a5366be7733", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Get MARBLE Model ready\"\"\"\n", + "base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", + "image_encoder_path = \"models/image_encoder\"\n", + "ip_ckpt = \"sdxl_models/ip-adapter_sdxl_vit-h.bin\"\n", + "controlnet_path = \"diffusers/controlnet-depth-sdxl-1.0\"\n", + "device = \"cuda\"\n", + "\n", + "\"\"\"Load IP-Adapter + Instant Style + Editing MLP\"\"\"\n", + "cur_block = ('up', 0, 1)\n", + "torch.cuda.empty_cache()\n", + "\n", + "controlnet = ControlNetModel.from_pretrained(controlnet_path, variant=\"fp16\", use_safetensors=True, torch_dtype=torch.float16).to(device)\n", + "pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(\n", + " base_model_path,\n", + " controlnet=controlnet,\n", + " use_safetensors=True,\n", + " torch_dtype=torch.float16,\n", + " add_watermarker=False,\n", + ").to(device)\n", + "\n", + "pipe.unet = register_cross_attention_hook(pipe.unet)\n", + "block_name = cur_block[0] + \"_blocks.\" + str(cur_block[1])+ \".attentions.\" + 
str(cur_block[2])\n", + "print(\"Testing block {}\".format(block_name))\n", + "ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=[block_name], edit_mlp=mlp, edit_mlp2=mlp2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a744532-efde-4949-bffc-9b8d73c88aa5", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "Get Depth Model Ready\n", + "\"\"\"\n", + "import cv2\n", + "model_path = \"DPT/dpt_weights/dpt_hybrid-midas-501f0c75.pt\"\n", + "net_w = net_h = 384\n", + "model = DPTDepthModel(\n", + " path=model_path,\n", + " backbone=\"vitb_rn50_384\",\n", + " non_negative=True,\n", + " enable_attention_hooks=False,\n", + ")\n", + "normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n", + "\n", + "transform = Compose(\n", + " [\n", + " Resize(\n", + " net_w,\n", + " net_h,\n", + " resize_target=None,\n", + " keep_aspect_ratio=True,\n", + " ensure_multiple_of=32,\n", + " resize_method=\"minimal\",\n", + " image_interpolation_method=cv2.INTER_CUBIC,\n", + " ),\n", + " normalization,\n", + " PrepareForNet(),\n", + " ]\n", + " )\n", + "\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94851e56-fe11-4aa4-8574-e12d0433508d", + "metadata": {}, + "outputs": [], + "source": [ + "# Edit strengths for metallic. More negative = more metallic, best results between range -20 to 20\n", + "edit_strengths1 = [-20, 0, 20]\n", + "\n", + "# Edit strengths for roughness. 
More positive = more roughness, best results between range -1 to 1\n", + "edit_strengths2 = [-1, 0, 1]\n", + "\n", + "\n", + "all_images = []\n", + "for edit_strength1 in edit_strengths1:\n", + " for edit_strength2 in edit_strengths2:\n", + " \n", + " target_image_path = 'input_images/context_image/toy_car.png'\n", + " target_image = Image.open(target_image_path).convert('RGB')\n", + " \n", + " \"\"\"\n", + " Compute depth map from input_image\n", + " \"\"\"\n", + "\n", + " img = np.array(target_image)\n", + "\n", + " img_input = transform({\"image\": img})[\"image\"]\n", + "\n", + " # compute\n", + " with torch.no_grad():\n", + " sample = torch.from_numpy(img_input).unsqueeze(0)\n", + "\n", + " # if optimize == True and device == torch.device(\"cuda\"):\n", + " # sample = sample.to(memory_format=torch.channels_last)\n", + " # sample = sample.half()\n", + "\n", + " prediction = model.forward(sample)\n", + " prediction = (\n", + " torch.nn.functional.interpolate(\n", + " prediction.unsqueeze(1),\n", + " size=img.shape[:2],\n", + " mode=\"bicubic\",\n", + " align_corners=False,\n", + " )\n", + " .squeeze()\n", + " .cpu()\n", + " .numpy()\n", + " )\n", + "\n", + " depth_min = prediction.min()\n", + " depth_max = prediction.max()\n", + " bits = 2\n", + " max_val = (2 ** (8 * bits)) - 1\n", + "\n", + " if depth_max - depth_min > np.finfo(\"float\").eps:\n", + " out = max_val * (prediction - depth_min) / (depth_max - depth_min)\n", + " else:\n", + " out = np.zeros(prediction.shape, dtype=depth.dtype)\n", + "\n", + " out = (out / 256).astype('uint8')\n", + " depth_map = Image.fromarray(out).resize((1024, 1024))\n", + " \n", + " \n", + " \"\"\"Preprocessing data for MARBLE\"\"\"\n", + " rm_bg = remove(target_image)\n", + " target_mask = rm_bg.convert(\"RGB\").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')# Convert mask to grayscale\n", + "\n", + " noise = np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)\n", + " noise_image = 
Image.fromarray(noise)\n", + " mask_target_img = ImageChops.lighter(target_image, target_mask)\n", + " invert_target_mask = ImageChops.invert(target_mask)\n", + "\n", + " from PIL import ImageEnhance\n", + " gray_target_image = target_image.convert('L').convert('RGB')\n", + " gray_target_image = ImageEnhance.Brightness(gray_target_image)\n", + "\n", + " # Adjust brightness\n", + " # The factor 1.0 means original brightness, greater than 1.0 makes the image brighter. Adjust this if the image is too dim\n", + " factor = 1.0 # Try adjusting this to get the desired brightness\n", + "\n", + " gray_target_image = gray_target_image.enhance(factor)\n", + " grayscale_img = ImageChops.darker(gray_target_image, target_mask)\n", + " img_black_mask = ImageChops.darker(target_image, invert_target_mask)\n", + " grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)\n", + " init_img = grayscale_init_img\n", + " \n", + " # The texture to be applied onto car\n", + " ip_image = Image.open('input_images/texture/metal_bowl.png')\n", + "\n", + "\n", + " init_img = target_image\n", + " init_img = init_img.resize((1024,1024))\n", + " mask = target_mask.resize((1024, 1024))\n", + "\n", + "\n", + " cur_seed = 42\n", + " images = ip_model.generate_edit_mlp_lr_multi(pil_image = ip_image, image=init_img, control_image=depth_map, \\\n", + " mask_image=mask, controlnet_conditioning_scale=1., num_samples=1, \\\n", + " num_inference_steps=30, seed=cur_seed, edit_strength=edit_strength1, \\\n", + " edit_strength2=edit_strength2, strength=1)\n", + " all_images.append(images[0].resize((512,512)))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d61d15b6-efe6-49fb-a33c-589de9781a08", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "def show_image_grid(images, x, y, figsize=(10, 10)):\n", + " \"\"\"\n", + " Display a list of images in an x by y grid.\n", + "\n", + " Args:\n", + " images (list of np.array): List 
of images (e.g., numpy arrays).\n", + " x (int): Number of columns.\n", + " y (int): Number of rows.\n", + " figsize (tuple): Size of the figure.\n", + " \"\"\"\n", + " fig, axes = plt.subplots(y, x, figsize=figsize)\n", + " axes = axes.flatten()\n", + "\n", + " for i in range(x * y):\n", + " ax = axes[i]\n", + " if i < len(images):\n", + " ax.imshow(images[i])\n", + " ax.axis('off')\n", + " else:\n", + " ax.axis('off') # Hide unused subplots\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "show_image_grid(all_images, len(edit_strengths1), len(edit_strengths2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10b21813-7d4c-4f65-964b-c01ea1b21a38", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/parametric_control_mlp.py b/parametric_control_mlp.py new file mode 100644 index 0000000..1f31b05 --- /dev/null +++ b/parametric_control_mlp.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + + +class control_mlp(nn.Module): + def __init__(self, embedding_size): + super(control_mlp, self).__init__() + + self.fc1 = nn.Linear(embedding_size, 1024) + self.fc2 = nn.Linear(1024, 2048) + self.relu = nn.ReLU() + + self.edit_strength_fc1 = nn.Linear(1, 128) + self.edit_strength_fc2 = nn.Linear(128, 2) + + def forward(self, x, edit_strength): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + + edit_strength = self.relu(self.edit_strength_fc1(edit_strength.unsqueeze(1))) + edit_strength = self.edit_strength_fc2(edit_strength) + + edit_strength1, edit_strength2 = edit_strength[:, 0], edit_strength[:, 1] + # 
print(edit_strength1.shape) + # exit() + + output = ( + edit_strength1.unsqueeze(1) * x[:, :1024] + + edit_strength2.unsqueeze(1) * x[:, 1024:] + ) + + return output diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5bc438a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +diffusers==0.30 +rembg[gpu] +einops==0.7.0 +transformers==4.27.4 +opencv-python==4.8.1.78 +gradio==5.29.0 +accelerate==0.26.1 +timm==0.6.12 +torch==2.3.0 +torchvision==0.18.0 +huggingface_hub==0.30.2 diff --git a/static/teaser.jpg b/static/teaser.jpg new file mode 100644 index 0000000..4ea8ed6 Binary files /dev/null and b/static/teaser.jpg differ diff --git a/try_blend.ipynb b/try_blend.ipynb new file mode 100644 index 0000000..ded3f80 --- /dev/null +++ b/try_blend.ipynb @@ -0,0 +1,58 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from marble import run_blend, setup_pipeline\n", + "from PIL import Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ip_adapter = setup_pipeline()\n", + "target_image = Image.open(\"input_images/context_image/beetle.png\")\n", + "ip_image = Image.open(\"input_images/texture/low_roughness.png\")\n", + "ip_image2 = Image.open(\"input_images/texture/high_roughness.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for edit_strength in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:\n", + " blended_image = run_blend(ip_adapter, target_image, ip_image, ip_image2, edit_strength=edit_strength)\n", + " blended_image.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}