diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3b022f1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,179 @@
+IP-Adapter/
+models/
+sdxl_models/
+.gradio/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
diff --git a/README.md b/README.md
index 3b31fa1..65a7f5d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,70 @@
-# marble
-This is the official code base for MARBLE
+# MARBLE
+
+
+
+
+
+This is the official implementation of MARBLE: Material Recomposition and Blending in CLIP-Space. Given an input image, MARBLE enables edits from material transfer, blending, to parametric control.
+
+
+
+
+## Installation
+MARBLE material blocks are built from the code base of [InstantStyle](https://github.com/instantX-research/InstantStyle). Additional functions are added in `ip_adapter_instantstyle/ip_adapter.py`; please cite their papers accordingly.
+
+The current code base is tested on Python 3.9.7.
+
+We will begin by cloning this repo:
+
+```
+git clone https://github.com/Stability-AI/marble.git
+```
+
+Then, install the required libraries with:
+
+```
+cd marble
+pip install -r requirements.txt
+```
+
+## Usage
+
+After installation and downloading the models, you can use the two demos `try_blend.ipynb` for material blending and `parametric_control.ipynb` for material transfer + multi-attribute parametric control.
+
+We also provide a gradio demo which can be run with `python gradio_demo.py`.
+
+
+### Using your own materials
+For material transfer, you can add your own images to `input_images/texture/`.
+
+
+### ComfyUI extension
+
+Custom nodes and an [example workflow](./example_workflow.json) are provided for [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
+
+To install:
+
+* Clone this repo into ```custom_nodes```:
+ ```shell
+ $ cd ComfyUI/custom_nodes
+ $ git clone https://github.com/Stability-AI/marble
+ ```
+* Install dependencies:
+ ```shell
+ $ cd marble
+ $ pip install -r requirements.txt
+ ```
+* Restart ComfyUI
+
+
+## Citation
+If you find MARBLE helpful in your research/applications, please cite using this BibTeX:
+
+```bibtex
+@article{cheng2024marble,
+ title={MARBLE: Material Recomposition and Blending in CLIP-Space},
+ author={Cheng, Ta-Ying and Sharma, Prafull and Boss, Mark and Jampani, Varun},
+ journal={CVPR},
+ year={2025}
+}
+```
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..1168fd9
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,247 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(__file__))
+
+import comfy.model_management
+import torch
+from PIL import Image
+import numpy as np
+from .marble import (
+ setup_control_mlps,
+ setup_pipeline,
+ run_blend,
+ run_parametric_control,
+)
+
+
+# Add conversion functions
+def tensor_to_pil(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.squeeze(0)
+ # Convert to numpy and scale to 0-255
+ image = (tensor.cpu().numpy() * 255).astype(np.uint8)
+ return Image.fromarray(image)
+ return tensor
+
+
+def pil_to_tensor(pil_image):
+ if isinstance(pil_image, Image.Image):
+ # Convert PIL to numpy array
+ image = np.array(pil_image)
+ # Convert to tensor and normalize to 0-1
+ tensor = torch.from_numpy(image).float() / 255.0
+ tensor = tensor.unsqueeze(0)
+ device = comfy.model_management.get_torch_device()
+ tensor = tensor.to(device)
+ return tensor
+ return pil_image
+
+
+MARBLE_CATEGORY = "marble"
+
+
+class MarbleControlMLPLoader:
+ CATEGORY = MARBLE_CATEGORY
+ FUNCTION = "load"
+ RETURN_NAMES = ["control_mlp"]
+ RETURN_TYPES = ["CONTROL_MLP"]
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {}
+
+ def load(self):
+ device = comfy.model_management.get_torch_device()
+ mlps = setup_control_mlps(device=device)
+ return (mlps,)
+
+
+class MarbleIPAdapterLoader:
+ CATEGORY = MARBLE_CATEGORY
+ FUNCTION = "load"
+ RETURN_NAMES = ["ip_adapter"]
+ RETURN_TYPES = ["IP_ADAPTER"]
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {}
+
+ def load(self):
+ device = comfy.model_management.get_torch_device()
+ ip_adapter = setup_pipeline(device=device)
+ return (ip_adapter,)
+
+
+class MarbleBlendNode:
+ CATEGORY = MARBLE_CATEGORY
+ FUNCTION = "blend"
+ RETURN_NAMES = ["image"]
+ RETURN_TYPES = ["IMAGE"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "ip_adapter": ("IP_ADAPTER",),
+ "image": ("IMAGE",),
+ "texture_image1": ("IMAGE",),
+ "texture_image2": ("IMAGE",),
+ "edit_strength": (
+ "FLOAT",
+ {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01},
+ ),
+ "num_inference_steps": (
+ "INT",
+ {"default": 20, "min": 1, "max": 100, "step": 1},
+ ),
+ "seed": (
+ "INT",
+ {"default": 42, "min": 0, "max": 2147483647, "step": 1},
+ ),
+ },
+ "optional": {
+ "mask": ("MASK", {"default": None}),
+ "depth_map": ("IMAGE", {"default": None}),
+ },
+ }
+
+ def blend(
+ self,
+ ip_adapter,
+ image,
+ texture_image1,
+ texture_image2,
+ edit_strength,
+ num_inference_steps,
+ seed,
+ mask=None,
+ depth_map=None,
+ ):
+ # Convert all inputs to PIL
+ pil_image = tensor_to_pil(image)
+ pil_texture1 = tensor_to_pil(texture_image1)
+ pil_texture2 = tensor_to_pil(texture_image2)
+ pil_depth_map = tensor_to_pil(depth_map) if depth_map is not None else None
+
+ result = run_blend(
+ ip_adapter,
+ pil_image,
+ pil_texture1,
+ pil_texture2,
+ edit_strength=edit_strength,
+ num_inference_steps=num_inference_steps,
+ seed=seed,
+ depth_map=pil_depth_map,
+ mask=mask,
+ )
+ # Convert result back to tensor
+ return (pil_to_tensor(result),)
+
+
+class MarbleParametricControl:
+ CATEGORY = MARBLE_CATEGORY
+ FUNCTION = "parametric_control"
+ RETURN_NAMES = ["image"]
+ RETURN_TYPES = ["IMAGE"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "ip_adapter": ("IP_ADAPTER",),
+ "image": ("IMAGE",),
+ "control_mlps": ("CONTROL_MLP",),
+ "num_inference_steps": (
+ "INT",
+ {"default": 30, "min": 1, "max": 100, "step": 1},
+ ),
+ "seed": (
+ "INT",
+ {"default": 42, "min": 0, "max": 2147483647, "step": 1},
+ ),
+ },
+ "optional": {
+ "mask": ("MASK", {"default": None}),
+ "texture_image": ("IMAGE", {"default": None}),
+ "depth_map": ("IMAGE", {"default": None}),
+ "metallic_strength": (
+ "FLOAT",
+ {"default": 0.0, "min": -20.0, "max": 20.0, "step": 0.1},
+ ),
+ "roughness_strength": (
+ "FLOAT",
+ {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.05},
+ ),
+ "transparency_strength": (
+ "FLOAT",
+ {"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.1},
+ ),
+ "glow_strength": (
+ "FLOAT",
+ {"default": 0.0, "min": -1.0, "max": 3.0, "step": 0.1},
+ ),
+ },
+ }
+
+ def parametric_control(
+ self,
+ ip_adapter,
+ image,
+ control_mlps,
+ num_inference_steps,
+ seed,
+ mask=None,
+ texture_image=None,
+ depth_map=None,
+ metallic_strength=0.0,
+ roughness_strength=0.0,
+ transparency_strength=0.0,
+ glow_strength=0.0,
+ ):
+ # Convert inputs to PIL
+ pil_image = tensor_to_pil(image)
+ pil_texture = (
+ tensor_to_pil(texture_image) if texture_image is not None else None
+ )
+ pil_depth_map = tensor_to_pil(depth_map) if depth_map is not None else None
+
+ edit_mlps = {}
+ for mlp_name, strength in [
+ ("metallic", metallic_strength),
+ ("roughness", roughness_strength),
+ ("transparency", transparency_strength),
+ ("glow", glow_strength),
+ ]:
+ if mlp_name in control_mlps and strength != 0.0:
+ edit_mlps[control_mlps[mlp_name]] = strength
+
+ result = run_parametric_control(
+ ip_adapter,
+ pil_image,
+ edit_mlps,
+ texture_image=pil_texture,
+ num_inference_steps=num_inference_steps,
+ seed=seed,
+ depth_map=pil_depth_map,
+ mask=mask,
+ )
+ # Convert result back to tensor
+ return (pil_to_tensor(result),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "MarbleControlMLPLoader": MarbleControlMLPLoader,
+ "MarbleIPAdapterLoader": MarbleIPAdapterLoader,
+ "MarbleBlendNode": MarbleBlendNode,
+ "MarbleParametricControl": MarbleParametricControl,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "MarbleControlMLPLoader": "Marble Control MLP Loader",
+ "MarbleIPAdapterLoader": "Marble IP Adapter Loader",
+ "MarbleBlendNode": "Marble Blend Node",
+ "MarbleParametricControl": "Marble Parametric Control",
+}
+
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/example_workflow.json b/example_workflow.json
new file mode 100644
index 0000000..671235f
--- /dev/null
+++ b/example_workflow.json
@@ -0,0 +1,475 @@
+{
+ "id": "dac030d6-0516-49af-a557-37e0cabcfecc",
+ "revision": 0,
+ "last_node_id": 26,
+ "last_link_id": 31,
+ "nodes": [
+ {
+ "id": 13,
+ "type": "LoadImage",
+ "pos": [
+ 26.386796951293945,
+ 336.0611572265625
+ ],
+ "size": [
+ 315,
+ 314.0000305175781
+ ],
+ "flags": {},
+ "order": 0,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "slot_index": 0,
+ "links": [
+ 24,
+ 30
+ ]
+ },
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": null
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LoadImage"
+ },
+ "widgets_values": [
+ "beetle.png",
+ "image"
+ ]
+ },
+ {
+ "id": 11,
+ "type": "MarbleIPAdapterLoader",
+ "pos": [
+ 53.263065338134766,
+ -7.059659957885742
+ ],
+ "size": [
+ 302.4000244140625,
+ 26
+ ],
+ "flags": {},
+ "order": 1,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "ip_adapter",
+ "type": "IP_ADAPTER",
+ "slot_index": 0,
+ "links": [
+ 25,
+ 29
+ ]
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MarbleIPAdapterLoader"
+ },
+ "widgets_values": []
+ },
+ {
+ "id": 24,
+ "type": "LoadImage",
+ "pos": [
+ 30.925029754638672,
+ 717.4205322265625
+ ],
+ "size": [
+ 274.080078125,
+ 314
+ ],
+ "flags": {},
+ "order": 2,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 26
+ ]
+ },
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": null
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LoadImage"
+ },
+ "widgets_values": [
+ "low_roughness.png",
+ "image"
+ ]
+ },
+ {
+ "id": 25,
+ "type": "LoadImage",
+ "pos": [
+ 33.45697021484375,
+ 1090.3607177734375
+ ],
+ "size": [
+ 274.080078125,
+ 314
+ ],
+ "flags": {},
+ "order": 3,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 27
+ ]
+ },
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": null
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LoadImage"
+ },
+ "widgets_values": [
+ "high_roughness.png",
+ "image"
+ ]
+ },
+ {
+ "id": 23,
+ "type": "PreviewImage",
+ "pos": [
+ 1525.1435546875,
+ 782.8226928710938
+ ],
+ "size": [
+ 646.029541015625,
+ 524.1215209960938
+ ],
+ "flags": {},
+ "order": 7,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 23
+ }
+ ],
+ "outputs": [],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ },
+ "widgets_values": []
+ },
+ {
+ "id": 10,
+ "type": "MarbleControlMLPLoader",
+ "pos": [
+ 43.088321685791016,
+ 110.31293487548828
+ ],
+ "size": [
+ 315,
+ 58
+ ],
+ "flags": {},
+ "order": 4,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "control_mlp",
+ "type": "CONTROL_MLP",
+ "slot_index": 0,
+ "links": [
+ 28
+ ]
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MarbleControlMLPLoader"
+ },
+ "widgets_values": []
+ },
+ {
+ "id": 26,
+ "type": "MarbleParametricControl",
+ "pos": [
+ 890.8665771484375,
+ -15.984257698059082
+ ],
+ "size": [
+ 283.40234375,
+ 302
+ ],
+ "flags": {},
+ "order": 6,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "ip_adapter",
+ "type": "IP_ADAPTER",
+ "link": 29
+ },
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 30
+ },
+ {
+ "name": "control_mlps",
+ "type": "CONTROL_MLP",
+ "link": 28
+ },
+ {
+ "name": "mask",
+ "shape": 7,
+ "type": "MASK",
+ "link": null
+ },
+ {
+ "name": "texture_image",
+ "shape": 7,
+ "type": "IMAGE",
+ "link": null
+ },
+ {
+ "name": "depth_map",
+ "shape": 7,
+ "type": "IMAGE",
+ "link": null
+ }
+ ],
+ "outputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "links": [
+ 31
+ ]
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MarbleParametricControl"
+ },
+ "widgets_values": [
+ 30,
+ 1678732183,
+ "randomize",
+ 0,
+ -0.39000000000000007,
+ 0,
+ 0
+ ]
+ },
+ {
+ "id": 22,
+ "type": "MarbleBlendNode",
+ "pos": [
+ 902.6680908203125,
+ 776.9732055664062
+ ],
+ "size": [
+ 278.73828125,
+ 230
+ ],
+ "flags": {},
+ "order": 5,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "ip_adapter",
+ "type": "IP_ADAPTER",
+ "link": 25
+ },
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 24
+ },
+ {
+ "name": "texture_image1",
+ "type": "IMAGE",
+ "link": 26
+ },
+ {
+ "name": "texture_image2",
+ "type": "IMAGE",
+ "link": 27
+ },
+ {
+ "name": "mask",
+ "shape": 7,
+ "type": "MASK",
+ "link": null
+ },
+ {
+ "name": "depth_map",
+ "shape": 7,
+ "type": "IMAGE",
+ "link": null
+ }
+ ],
+ "outputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "links": [
+ 23
+ ]
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MarbleBlendNode"
+ },
+ "widgets_values": [
+ 1,
+ 20,
+ 1012675178,
+ "randomize"
+ ]
+ },
+ {
+ "id": 18,
+ "type": "PreviewImage",
+ "pos": [
+ 1529.1248779296875,
+ -12.8727445602417
+ ],
+ "size": [
+ 634.3290405273438,
+ 617.0318603515625
+ ],
+ "flags": {},
+ "order": 8,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 31
+ }
+ ],
+ "outputs": [],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ },
+ "widgets_values": []
+ }
+ ],
+ "links": [
+ [
+ 23,
+ 22,
+ 0,
+ 23,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 24,
+ 13,
+ 0,
+ 22,
+ 1,
+ "IMAGE"
+ ],
+ [
+ 25,
+ 11,
+ 0,
+ 22,
+ 0,
+ "IP_ADAPTER"
+ ],
+ [
+ 26,
+ 24,
+ 0,
+ 22,
+ 2,
+ "IMAGE"
+ ],
+ [
+ 27,
+ 25,
+ 0,
+ 22,
+ 3,
+ "IMAGE"
+ ],
+ [
+ 28,
+ 10,
+ 0,
+ 26,
+ 2,
+ "CONTROL_MLP"
+ ],
+ [
+ 29,
+ 11,
+ 0,
+ 26,
+ 0,
+ "IP_ADAPTER"
+ ],
+ [
+ 30,
+ 13,
+ 0,
+ 26,
+ 1,
+ "IMAGE"
+ ],
+ [
+ 31,
+ 26,
+ 0,
+ 18,
+ 0,
+ "IMAGE"
+ ]
+ ],
+ "groups": [
+ {
+ "id": 1,
+ "title": "Loaders",
+ "bounding": [
+ -3.8137731552124023,
+ -104.17147064208984,
+ 413.9270935058594,
+ 313.8103332519531
+ ],
+ "color": "#3f789e",
+ "font_size": 24,
+ "flags": {}
+ }
+ ],
+ "config": {},
+ "extra": {
+ "ds": {
+ "scale": 0.7513148009015782,
+ "offset": [
+ 211.50521766916663,
+ 238.1879270976759
+ ]
+ },
+ "frontendVersion": "1.18.9"
+ },
+ "version": 0.4
+}
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
new file mode 100644
index 0000000..5a26910
--- /dev/null
+++ b/gradio_demo.py
@@ -0,0 +1,272 @@
+import gradio as gr
+from PIL import Image
+
+from marble import (
+ get_session,
+ run_blend,
+ run_parametric_control,
+ setup_control_mlps,
+ setup_pipeline,
+)
+
+# Setup the pipeline and control MLPs
+control_mlps = setup_control_mlps()
+ip_adapter = setup_pipeline()
+get_session()
+
+# Load example images
+EXAMPLE_IMAGES = {
+ "blend": {
+ "target": "input_images/context_image/beetle.png",
+ "texture1": "input_images/texture/low_roughness.png",
+ "texture2": "input_images/texture/high_roughness.png",
+ },
+ "parametric": {
+ "target": "input_images/context_image/toy_car.png",
+ "texture": "input_images/texture/metal_bowl.png",
+ },
+}
+
+
+def blend_images(target_image, texture1, texture2, edit_strength):
+ """Blend between two texture images"""
+ result = run_blend(
+ ip_adapter, target_image, texture1, texture2, edit_strength=edit_strength
+ )
+ return result
+
+
+def parametric_control(
+ target_image,
+ texture_image,
+ control_type,
+ metallic_strength,
+ roughness_strength,
+ transparency_strength,
+ glow_strength,
+):
+ """Apply parametric control based on selected control type"""
+ edit_mlps = {}
+
+ if control_type == "Roughness + Metallic":
+ edit_mlps = {
+ control_mlps["metallic"]: metallic_strength,
+ control_mlps["roughness"]: roughness_strength,
+ }
+ elif control_type == "Transparency":
+ edit_mlps = {
+ control_mlps["transparency"]: transparency_strength,
+ }
+ elif control_type == "Glow":
+ edit_mlps = {
+ control_mlps["glow"]: glow_strength,
+ }
+
+ # Use target image as texture if no texture is provided
+ texture_to_use = texture_image if texture_image is not None else target_image
+
+ result = run_parametric_control(
+ ip_adapter,
+ target_image,
+ edit_mlps,
+ texture_to_use,
+ )
+ return result
+
+
+# Create the Gradio interface
+with gr.Blocks(
+ title="MARBLE: Material Recomposition and Blending in CLIP-Space"
+) as demo:
+ gr.Markdown(
+ """
+ # MARBLE: Material Recomposition and Blending in CLIP-Space
+
+

+

+
+
+ MARBLE is a tool for material recomposition and blending in CLIP-Space.
+ We provide two modes of operation:
+ - **Texture Blending**: Blend the material properties of two texture images and apply it to a target image.
+ - **Parametric Control**: Apply parametric material control to a target image. You can either provide a texture image, transferring the material properties of the texture to the original image, or you can just provide a target image, and edit the material properties of the original image.
+ """
+ )
+
+ with gr.Row(variant="panel"):
+ with gr.Tabs():
+ with gr.TabItem("Texture Blending"):
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ with gr.Row():
+ texture1 = gr.Image(label="Texture 1", type="pil")
+ texture2 = gr.Image(label="Texture 2", type="pil")
+ edit_strength = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.5,
+ step=0.1,
+ label="Blend Strength",
+ )
+ with gr.Column():
+ with gr.Row():
+ target_image = gr.Image(label="Target Image", type="pil")
+ blend_output = gr.Image(label="Blended Result")
+ blend_btn = gr.Button("Blend Textures")
+
+ # Add examples for blending
+ gr.Examples(
+ examples=[
+ [
+ Image.open(EXAMPLE_IMAGES["blend"]["target"]),
+ Image.open(EXAMPLE_IMAGES["blend"]["texture1"]),
+ Image.open(EXAMPLE_IMAGES["blend"]["texture2"]),
+ 0.5,
+ ]
+ ],
+ inputs=[target_image, texture1, texture2, edit_strength],
+ outputs=blend_output,
+ fn=blend_images,
+ cache_examples=True,
+ )
+
+ blend_btn.click(
+ fn=blend_images,
+ inputs=[target_image, texture1, texture2, edit_strength],
+ outputs=blend_output,
+ )
+
+ with gr.TabItem("Parametric Control"):
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ with gr.Row():
+ target_image_pc = gr.Image(label="Target Image", type="pil")
+ texture_image_pc = gr.Image(
+ label="Texture Image (Optional - uses target image if not provided)",
+ type="pil",
+ )
+ control_type = gr.Dropdown(
+ choices=["Roughness + Metallic", "Transparency", "Glow"],
+ value="Roughness + Metallic",
+ label="Control Type",
+ )
+
+ metallic_strength = gr.Slider(
+ minimum=-20,
+ maximum=20,
+ value=0,
+ step=0.1,
+ label="Metallic Strength",
+ visible=True,
+ )
+ roughness_strength = gr.Slider(
+ minimum=-1,
+ maximum=1,
+ value=0,
+ step=0.1,
+ label="Roughness Strength",
+ visible=True,
+ )
+ transparency_strength = gr.Slider(
+ minimum=0,
+ maximum=4,
+ value=0,
+ step=0.1,
+ label="Transparency Strength",
+ visible=False,
+ )
+ glow_strength = gr.Slider(
+ minimum=0,
+ maximum=3,
+ value=0,
+ step=0.1,
+ label="Glow Strength",
+ visible=False,
+ )
+ control_btn = gr.Button("Apply Control")
+
+ with gr.Column():
+ control_output = gr.Image(label="Result")
+
+ def update_slider_visibility(control_type):
+ return [
+ gr.update(visible=control_type == "Roughness + Metallic"),
+ gr.update(visible=control_type == "Roughness + Metallic"),
+ gr.update(visible=control_type == "Transparency"),
+ gr.update(visible=control_type == "Glow"),
+ ]
+
+ control_type.change(
+ fn=update_slider_visibility,
+ inputs=[control_type],
+ outputs=[
+ metallic_strength,
+ roughness_strength,
+ transparency_strength,
+ glow_strength,
+ ],
+ show_progress=False,
+ )
+
+ # Add examples for parametric control
+ gr.Examples(
+ examples=[
+ [
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
+ "Roughness + Metallic",
+ 0, # metallic_strength
+ 0, # roughness_strength
+ 0, # transparency_strength
+ 0, # glow_strength
+ ],
+ [
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
+ "Roughness + Metallic",
+ 20, # metallic_strength
+ 0, # roughness_strength
+ 0, # transparency_strength
+ 0, # glow_strength
+ ],
+ [
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
+ "Roughness + Metallic",
+ 0, # metallic_strength
+ 1, # roughness_strength
+ 0, # transparency_strength
+ 0, # glow_strength
+ ],
+ ],
+ inputs=[
+ target_image_pc,
+ texture_image_pc,
+ control_type,
+ metallic_strength,
+ roughness_strength,
+ transparency_strength,
+ glow_strength,
+ ],
+ outputs=control_output,
+ fn=parametric_control,
+ cache_examples=True,
+ )
+
+ control_btn.click(
+ fn=parametric_control,
+ inputs=[
+ target_image_pc,
+ texture_image_pc,
+ control_type,
+ metallic_strength,
+ roughness_strength,
+ transparency_strength,
+ glow_strength,
+ ],
+ outputs=control_output,
+ )
+
+if __name__ == "__main__":
+ demo.launch()
diff --git a/input_images/context_image/beetle.png b/input_images/context_image/beetle.png
new file mode 100644
index 0000000..eef4c87
Binary files /dev/null and b/input_images/context_image/beetle.png differ
diff --git a/input_images/context_image/genart_teapot.jpg b/input_images/context_image/genart_teapot.jpg
new file mode 100644
index 0000000..2e616ac
Binary files /dev/null and b/input_images/context_image/genart_teapot.jpg differ
diff --git a/input_images/context_image/toy_car.png b/input_images/context_image/toy_car.png
new file mode 100644
index 0000000..b52e880
Binary files /dev/null and b/input_images/context_image/toy_car.png differ
diff --git a/input_images/context_image/white_car_night.jpg b/input_images/context_image/white_car_night.jpg
new file mode 100644
index 0000000..4eb0a71
Binary files /dev/null and b/input_images/context_image/white_car_night.jpg differ
diff --git a/input_images/depth/beetle.png b/input_images/depth/beetle.png
new file mode 100644
index 0000000..7f1a9d7
Binary files /dev/null and b/input_images/depth/beetle.png differ
diff --git a/input_images/depth/toy_car.png b/input_images/depth/toy_car.png
new file mode 100644
index 0000000..8c70a8d
Binary files /dev/null and b/input_images/depth/toy_car.png differ
diff --git a/input_images/texture/high_roughness.png b/input_images/texture/high_roughness.png
new file mode 100644
index 0000000..4eb22be
Binary files /dev/null and b/input_images/texture/high_roughness.png differ
diff --git a/input_images/texture/low_roughness.png b/input_images/texture/low_roughness.png
new file mode 100644
index 0000000..b3555a6
Binary files /dev/null and b/input_images/texture/low_roughness.png differ
diff --git a/input_images/texture/metal_bowl.png b/input_images/texture/metal_bowl.png
new file mode 100644
index 0000000..42fa36b
Binary files /dev/null and b/input_images/texture/metal_bowl.png differ
diff --git a/ip_adapter_instantstyle/__init__.py b/ip_adapter_instantstyle/__init__.py
new file mode 100644
index 0000000..3b1f1ff
--- /dev/null
+++ b/ip_adapter_instantstyle/__init__.py
@@ -0,0 +1,9 @@
+from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
+
+__all__ = [
+ "IPAdapter",
+ "IPAdapterPlus",
+ "IPAdapterPlusXL",
+ "IPAdapterXL",
+ "IPAdapterFull",
+]
diff --git a/ip_adapter_instantstyle/attention_processor.py b/ip_adapter_instantstyle/attention_processor.py
new file mode 100644
index 0000000..6d30ffa
--- /dev/null
+++ b/ip_adapter_instantstyle/attention_processor.py
@@ -0,0 +1,562 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.

    Runs plain (non-IP) attention using the projection/norm layers owned by
    the ``attn`` module passed to ``__call__``. Holds no parameters of its
    own; it subclasses ``nn.Module`` only for API symmetry with the
    IP-Adapter processors.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Arguments are accepted only for signature compatibility with
        # IPAttnProcessor; they are intentionally unused.
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        # `attn` is expected to be a diffusers `Attention` module; all
        # projections, norms and head-reshape helpers are read from it.
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4D (B, C, H, W) feature maps are flattened to (B, H*W, C) tokens.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no encoder states are supplied.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original (B, C, H, W) layout if the input was 4D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+
+
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to False):
            When True, the image-prompt branch is bypassed and only the text
            tokens attend (used for non-target blocks in InstantStyle).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        # Separate key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4D (B, C, H, W) feature maps are flattened to (B, H*W, C) tokens.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            # The last `num_tokens` entries are the image-prompt tokens that
            # were concatenated after the text embeddings by the caller.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Text-token attention.
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip:
            # for ip-adapter: a second attention pass over the image tokens,
            # added to the text result weighted by `self.scale`.
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            # Stored for later visualization/inspection.
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+
+
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Same contract as :class:`AttnProcessor`, but dispatches to
    ``F.scaled_dot_product_attention`` instead of explicit bmm/softmax.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Arguments kept only for signature compatibility; unused.
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4D (B, C, H, W) feature maps are flattened to (B, H*W, C) tokens.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no encoder states are supplied.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, seq, inner) -> (B, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+
+
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to False):
            When True, the image-prompt branch is bypassed and only the text
            tokens attend (used for non-target blocks in InstantStyle).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        # Separate key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4D (B, C, H, W) feature maps are flattened to (B, H*W, C) tokens.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            # The last `num_tokens` entries are the image-prompt tokens that
            # were concatenated after the text embeddings by the caller.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip:
            # for ip-adapter: a second attention pass over the image tokens,
            # added to the text result weighted by `self.scale`.
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                # BUG FIX: without the parentheses, `.softmax` bound to the
                # transposed key (attribute access binds tighter than `@`),
                # so the stored map was query @ softmax(key^T) rather than
                # softmax(query @ key^T). Visualization-only; the generated
                # output is unaffected. NOTE(review): logits are unscaled
                # (no 1/sqrt(head_dim)), unlike the SDPA call above.
                self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+
+
+## for controlnet
+class CNAttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(self, num_tokens=4):
+ self.num_tokens = num_tokens
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
class CNAttnProcessor2_0:
    r"""
    ControlNet attention processor backed by ``F.scaled_dot_product_attention``
    (requires PyTorch 2.0). Drops the trailing ``num_tokens`` image-prompt
    entries from ``encoder_hidden_states`` so the ControlNet conditions on
    text tokens only.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # Remember whether the input arrived as a (B, C, H, W) feature map.
        was_spatial = hidden_states.ndim == 4
        if was_spatial:
            bsz, ch, h, w = hidden_states.shape
            hidden_states = hidden_states.view(bsz, ch, h * w).transpose(1, 2)

        if encoder_hidden_states is None:
            bsz, seq_len, _ = hidden_states.shape
        else:
            bsz, seq_len, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, seq_len, bsz)
            # SDPA expects the mask shaped (batch, heads, q_len, kv_len).
            attention_mask = attention_mask.view(bsz, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Keep only the text tokens; the image-prompt tokens sit at the end.
            keep = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :keep]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = key.shape[-1] // attn.heads

        def to_heads(t):
            # (B, seq, inner) -> (B, heads, seq, head_dim)
            return t.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # SDPA output: (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        sdpa_out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        sdpa_out = sdpa_out.transpose(1, 2).reshape(bsz, -1, attn.heads * head_dim)
        hidden_states = sdpa_out.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if was_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(bsz, ch, h, w)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
diff --git a/ip_adapter_instantstyle/ip_adapter.py b/ip_adapter_instantstyle/ip_adapter.py
new file mode 100644
index 0000000..dd3eaa5
--- /dev/null
+++ b/ip_adapter_instantstyle/ip_adapter.py
@@ -0,0 +1,858 @@
+import os
+import glob
+from typing import List
+
+import torch
+import torch.nn as nn
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.controlnet import MultiControlNetModel
+from PIL import Image
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from .utils import is_torch2_available, get_generator
+
+L = 4
+
+
def pos_encode(x, L):
    """Fourier-feature positional encoding of `x`.

    For each frequency k in [0, L), appends cos(2**k * pi * x) and
    sin(2**k * pi * x); the 2*L feature tensors are concatenated along dim 1.

    Args:
        x: tensor of shape (B, D); output has shape (B, 2 * L * D).
        L: number of frequency octaves.

    Returns:
        The concatenated cos/sin features.
    """
    # Renamed the accumulator: the original shadowed the function's own name.
    features = []
    for freq in range(L):
        features.append(torch.cos(2**freq * torch.pi * x))
        features.append(torch.sin(2**freq * torch.pi * x))
    return torch.cat(features, dim=1)
+
+
+if is_torch2_available():
+ from .attention_processor import (
+ AttnProcessor2_0 as AttnProcessor,
+ )
+ from .attention_processor import (
+ CNAttnProcessor2_0 as CNAttnProcessor,
+ )
+ from .attention_processor import (
+ IPAttnProcessor2_0 as IPAttnProcessor,
+ )
+else:
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
+from .resampler import Resampler
+
+
class ImageProjModel(torch.nn.Module):
    """Projection Model.

    Expands a single CLIP image embedding into `clip_extra_context_tokens`
    context tokens of width `cross_attention_dim`, followed by LayerNorm.
    Attribute names (`proj`, `norm`) are part of the checkpoint layout and
    must not change.
    """

    def __init__(
        self,
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    ):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # A single linear layer produces all tokens' features at once.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Map (B, clip_embeddings_dim) -> (B, num_tokens, cross_attention_dim)."""
        tokens = self.proj(image_embeds)
        tokens = tokens.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
+
+
class MLPProjModel(torch.nn.Module):
    """SD model with image prompt.

    A two-layer GELU MLP (plus LayerNorm) that maps CLIP image embeddings
    into the UNet's cross-attention space. The `proj` Sequential layout is
    part of the checkpoint layout and must not change.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )

    def forward(self, image_embeds):
        """Map (..., clip_embeddings_dim) -> (..., cross_attention_dim)."""
        return self.proj(image_embeds)
+
+
class IPAdapter:
    """IP-Adapter wrapper around a Stable Diffusion pipeline.

    Installs IP-Adapter attention processors into the pipeline's UNet (and
    ControlNet, if present), loads the CLIP image encoder and the adapter
    checkpoint, and exposes helpers that turn reference images into prompt
    embeddings for generation.

    Changes from the original:
    - removed a duplicate, identical definition of ``set_scale`` (the second
      definition silently shadowed the first);
    - replaced the mutable default argument ``target_blocks=["block"]`` with
      a ``None`` sentinel (same effective default).
    """

    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        ip_ckpt,
        device,
        num_tokens=4,
        target_blocks=None,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        # Avoid a shared mutable default; ["block"] matches every attention
        # layer name, i.e. the full-IP-Adapter behavior.
        self.target_blocks = ["block"] if target_blocks is None else target_blocks

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        """Build the projection model mapping CLIP embeddings to IP tokens."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """Replace the UNet's attention processors with IP-aware ones.

        Cross-attention layers whose name contains one of ``target_blocks``
        get an active IP processor; all other cross-attention layers get a
        skipping one (text-only). Self-attention keeps the default processor.
        """
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # attn1 is self-attention (no encoder states), attn2 is cross.
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            # NOTE(review): if a processor name matches none of these
            # prefixes, hidden_size keeps its previous value — relies on
            # diffusers' naming staying mid_block/up_blocks/down_blocks.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                selected = False
                for block_name in self.target_blocks:
                    if block_name in name:
                        selected = True
                        break
                if selected:
                    attn_procs[name] = IPAttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.num_tokens,
                    ).to(self.device, dtype=torch.float16)
                else:
                    attn_procs[name] = IPAttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.num_tokens,
                        skip=True,
                    ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(
                        CNAttnProcessor(num_tokens=self.num_tokens)
                    )
            else:
                self.pipe.controlnet.set_attn_processor(
                    CNAttnProcessor(num_tokens=self.num_tokens)
                )

    def load_ip_adapter(self):
        """Load the adapter checkpoint (.safetensors or torch .bin/.pth)."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = (
                            f.get_tensor(key)
                        )
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = (
                            f.get_tensor(key)
                        )
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        # strict=False: skipping processors only need the to_k_ip/to_v_ip
        # weights that exist in the checkpoint.
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(
        self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None
    ):
        """Return (conditional, unconditional) IP prompt embeddings.

        Either `pil_image` or `clip_image_embeds` must be provided. If
        `content_prompt_embeds` is given it is subtracted from the image
        embedding (InstantStyle content removal).
        """
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=torch.float16)
            ).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        if content_prompt_embeds is not None:
            print(clip_image_embeds.shape)
            print(content_prompt_embeds.shape)
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # Unconditional branch: zero image embedding through the same projector.
        uncond_image_prompt_embeds = self.image_proj_model(
            torch.zeros_like(clip_image_embeds)
        )
        return image_prompt_embeds, uncond_image_prompt_embeds

    @torch.inference_mode()
    def generate_image_edit_dir(
        self,
        pil_image=None,
        content_prompt_embeds=None,
        edit_mlps: dict[torch.nn.Module, float] = None,
    ):
        """Apply learned parametric edit directions to an image embedding.

        `edit_mlps` maps an edit network to its strength; each network is
        called as net(clip_embeds, strength) and the predicted directions are
        summed onto the embedding.
        """
        print("Combining multiple MLPs!")
        # NOTE(review): if pil_image is None, clip_image_embeds is never
        # assigned and this raises — callers appear to always pass an image.
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=torch.float16)
            ).image_embeds
            pred_editing_dirs = [
                net(
                    clip_image_embeds,
                    torch.Tensor([strength]).to(self.device, dtype=torch.float16),
                )
                for net, strength in edit_mlps.items()
            ]

            clip_image_embeds = clip_image_embeds + sum(pred_editing_dirs)

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(
            torch.zeros_like(clip_image_embeds)
        )
        return image_prompt_embeds, uncond_image_prompt_embeds

    @torch.inference_mode()
    def get_image_edit_dir(
        self,
        start_image=None,
        pil_image=None,
        pil_image2=None,
        content_prompt_embeds=None,
        edit_strength=1.0,
    ):
        """Interpolate from `pil_image` toward `pil_image2` in CLIP space,
        starting from `start_image`'s embedding.

        NOTE(review): `clip_image_embeds_start` is only defined when
        `start_image` is given; all three images appear to be required in
        practice — confirm against callers.
        """
        print("Blending Two Materials!")
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=torch.float16)
            ).image_embeds

        if pil_image2 is not None:
            if isinstance(pil_image2, Image.Image):
                pil_image2 = [pil_image2]
            clip_image2 = self.clip_image_processor(
                images=pil_image2, return_tensors="pt"
            ).pixel_values
            clip_image_embeds2 = self.image_encoder(
                clip_image2.to(self.device, dtype=torch.float16)
            ).image_embeds

        if start_image is not None:
            if isinstance(start_image, Image.Image):
                start_image = [start_image]
            clip_image_start = self.clip_image_processor(
                images=start_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds_start = self.image_encoder(
                clip_image_start.to(self.device, dtype=torch.float16)
            ).image_embeds

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds
            clip_image_embeds2 = clip_image_embeds2 - content_prompt_embeds

        # clip_image_embeds += edit_strength * (clip_image_embeds2 - clip_image_embeds)
        clip_image_embeds = clip_image_embeds_start + edit_strength * (
            clip_image_embeds2 - clip_image_embeds
        )

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(
            torch.zeros_like(clip_image_embeds)
        )
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set the image-prompt weight on every active IP attention processor."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        neg_content_emb=None,
        **kwargs,
    ):
        """Run the pipeline with the image prompt appended to the text prompt.

        Returns the list of generated PIL images. Extra kwargs are forwarded
        to the underlying pipeline call.
        """
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image,
            clip_image_embeds=clip_image_embeds,
            content_prompt_embeds=neg_content_emb,
        )
        # Tile the image tokens once per sample: (B, T, D) -> (B*S, T, D).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
            1, num_samples, 1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are concatenated after the text tokens; the IP
            # attention processors split them back off via num_tokens.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
            )

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
+
+
class IPAdapterXL(IPAdapter):
    """IP-Adapter variant for SDXL pipelines (dual text encoders, pooled embeds)."""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        clip_strength=1.0,
        **kwargs,
    ):
        """Generate images conditioned on a reference image.

        pil_image: reference PIL image (or list of images).
        scale: IP-Adapter attention scale applied via set_scale().
        neg_content_emb / neg_content_prompt / neg_content_scale: optional
            "negative content" signal subtracted from the image embedding
            (handled inside get_image_embeds).
        clip_strength: multiplier applied to the image-prompt embeddings.
        kwargs: forwarded verbatim to the diffusers pipeline call.
        """
        self.set_scale(scale)

        # One text prompt per reference image.
        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Resolve the pooled "negative content" embedding.
        # NOTE(review): when an explicit neg_content_emb is supplied, the outer
        # else sets pooled_prompt_embeds_ to None, discarding it — upstream
        # InstantStyle passes neg_content_emb through instead; confirm intended.
        if neg_content_emb is None:
            if neg_content_prompt is not None:
                with torch.inference_mode():
                    (
                        prompt_embeds_,  # torch.Size([1, 77, 2048])
                        negative_prompt_embeds_,
                        pooled_prompt_embeds_,  # torch.Size([1, 1280])
                        negative_pooled_prompt_embeds_,
                    ) = self.pipe.encode_prompt(
                        neg_content_prompt,
                        num_images_per_prompt=num_samples,
                        do_classifier_free_guidance=True,
                        negative_prompt=negative_prompt,
                    )
                    pooled_prompt_embeds_ *= neg_content_scale
            else:
                # neg_content_emb is None here, so this assigns None.
                pooled_prompt_embeds_ = neg_content_emb
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image, content_prompt_embeds=pooled_prompt_embeds_
        )
        # Tile the image embeddings across samples:
        # (bs, seq, dim) -> (bs * num_samples, seq, dim)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
            1, num_samples, 1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        print("CLIP Strength is {}".format(clip_strength))
        # Scale both the conditional and unconditional image embeddings.
        image_prompt_embeds *= clip_strength
        uncond_image_prompt_embeds *= clip_strength

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Append the image tokens after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
            )

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images

    def generate_parametric_edits(
        self,
        pil_image,
        edit_mlps: dict[torch.nn.Module, float],
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        """Generate images after applying parametric edits to the image embedding.

        edit_mlps: mapping of control MLP module -> edit strength; consumed by
            self.generate_image_edit_dir (defined elsewhere in the project).
        Otherwise follows the same flow as generate().
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Same negative-content resolution as generate(); see NOTE there.
        if neg_content_emb is None:
            if neg_content_prompt is not None:
                with torch.inference_mode():
                    (
                        prompt_embeds_,  # torch.Size([1, 77, 2048])
                        negative_prompt_embeds_,
                        pooled_prompt_embeds_,  # torch.Size([1, 1280])
                        negative_pooled_prompt_embeds_,
                    ) = self.pipe.encode_prompt(
                        neg_content_prompt,
                        num_images_per_prompt=num_samples,
                        do_classifier_free_guidance=True,
                        negative_prompt=negative_prompt,
                    )
                    pooled_prompt_embeds_ *= neg_content_scale
            else:
                pooled_prompt_embeds_ = neg_content_emb
        else:
            pooled_prompt_embeds_ = None
        image_prompt_embeds, uncond_image_prompt_embeds = self.generate_image_edit_dir(
            pil_image, content_prompt_embeds=pooled_prompt_embeds_, edit_mlps=edit_mlps
        )

        # Tile image embeddings across samples (see generate()).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
            1, num_samples, 1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
            )

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images

    def generate_edit(
        self,
        start_image,
        pil_image,
        pil_image2,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        edit_strength=1.0,
        **kwargs,
    ):
        """Generate images conditioned on an embedding edited between two references.

        start_image: image whose embedding is the starting point.
        pil_image / pil_image2: image pair defining the edit direction.
        edit_strength: scalar passed to self.get_image_edit_dir (defined
            elsewhere in the project) — presumably how far to move along the
            direction; confirm against that helper.
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Same negative-content resolution as generate(); see NOTE there.
        if neg_content_emb is None:
            if neg_content_prompt is not None:
                with torch.inference_mode():
                    (
                        prompt_embeds_,  # torch.Size([1, 77, 2048])
                        negative_prompt_embeds_,
                        pooled_prompt_embeds_,  # torch.Size([1, 1280])
                        negative_pooled_prompt_embeds_,
                    ) = self.pipe.encode_prompt(
                        neg_content_prompt,
                        num_images_per_prompt=num_samples,
                        do_classifier_free_guidance=True,
                        negative_prompt=negative_prompt,
                    )
                    pooled_prompt_embeds_ *= neg_content_scale
            else:
                pooled_prompt_embeds_ = neg_content_emb
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_edit_dir(
            start_image,
            pil_image,
            pil_image2,
            content_prompt_embeds=pooled_prompt_embeds_,
            edit_strength=edit_strength,
        )

        # Tile image embeddings across samples (see generate()).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
            1, num_samples, 1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
            )

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images
+
+
class IPAdapterPlus(IPAdapter):
    """IP-Adapter "plus" variant: fine-grained CLIP patch features via a Resampler."""

    def init_proj(self):
        """Build the Resampler projecting CLIP patch features to UNet cross-attn tokens."""
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Encode image(s) into conditional / unconditional prompt embeddings.

        Uses the penultimate hidden state of the CLIP vision tower
        (fine-grained patch features) rather than the pooled projection.
        NOTE(review): the clip_image_embeds argument is accepted but always
        overwritten below by re-encoding pil_image — confirm intended.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(
            clip_image, output_hidden_states=True
        ).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # Unconditional branch: an all-zeros image through the same encoder.
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds
+
+
class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter "full" variant: a plain MLP projection instead of a Resampler.

    Inherits the fine-grained get_image_embeds from IPAdapterPlus.
    """

    def init_proj(self):
        """Build the MLP that maps CLIP hidden features to UNet cross-attn tokens."""
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model
+
+
class IPAdapterPlusXL(IPAdapter):
    """IP-Adapter "plus" variant for SDXL (Resampler with dim 1280, 20 heads)."""

    def init_proj(self):
        """Build the Resampler projecting CLIP patch features into the SDXL
        UNet cross-attention token space."""
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Encode image(s) into conditional / unconditional prompt embeddings
        using the penultimate CLIP hidden state (fine-grained features)."""
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(
            clip_image, output_hidden_states=True
        ).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # Unconditional branch: an all-zeros image through the same encoder.
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images conditioned on a reference image.

        kwargs are forwarded verbatim to the diffusers pipeline call.
        """
        self.set_scale(scale)

        # One text prompt per reference image.
        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image
        )
        # Tile image embeddings: (bs, seq, dim) -> (bs * num_samples, seq, dim).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
            1, num_samples, 1
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
            bs_embed * num_samples, seq_len, -1
        )

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Append the image tokens after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
            )

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
diff --git a/ip_adapter_instantstyle/resampler.py b/ip_adapter_instantstyle/resampler.py
new file mode 100644
index 0000000..2426667
--- /dev/null
+++ b/ip_adapter_instantstyle/resampler.py
@@ -0,0 +1,158 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+
+# FFN
# FFN
def FeedForward(dim, mult=4):
    """Pre-norm MLP block: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back."""
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
+
+
def reshape_tensor(x, heads):
    """Split the channel axis into attention heads.

    (bs, length, heads * dim_per_head) -> (bs, heads, length, dim_per_head)
    """
    bs, length, _ = x.shape
    # (bs, length, heads, dim_per_head)
    split = x.view(bs, length, heads, -1)
    # (bs, heads, length, dim_per_head)
    swapped = split.transpose(1, 2)
    return swapped.reshape(bs, heads, length, -1)
+
+
class PerceiverAttention(nn.Module):
    """Cross-attention where latent queries attend to [image features; latents]."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # NOTE: self.scale is stored but forward() uses a split sqrt-scale
        # instead (applied to both q and k for fp16 stability).
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values are computed from image features AND latents, so the
        # latents also attend to themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        # Softmax in float32, then cast back to the working dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # (b, heads, l, dim_head) -> (b, l, heads * dim_head)
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
+
+
class Resampler(nn.Module):
    """Perceiver-style resampler.

    A fixed set of learned latent queries cross-attends to CLIP image
    features and is projected to the UNet token space:
    (b, seq, embedding_dim) -> (b, num_queries, output_dim).
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        # Optional learned positional embedding over the input sequence.
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        # Learned latent queries, scaled down at init for stability.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # Optional extra latents derived from the mean-pooled input sequence.
        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        """x: (b, seq, embedding_dim) image features -> (b, num_queries, output_dim)."""
        # Positional embedding is added in embedding_dim space, before proj_in.
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            # All-ones mask: masked_mean degenerates to a plain mean here.
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        # Alternating cross-attention / feed-forward with residual connections.
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
+
+
def masked_mean(t, *, dim, mask=None):
    """Mean of t along dim; with a (b, n) boolean mask, average only True positions."""
    if mask is None:
        return t.mean(dim=dim)

    counts = mask.sum(dim=dim, keepdim=True)
    zeroed = t.masked_fill(~mask.unsqueeze(-1), 0.0)
    # clamp guards against division by zero when a row is entirely masked out
    return zeroed.sum(dim=dim) / counts.clamp(min=1e-5)
diff --git a/ip_adapter_instantstyle/utils.py b/ip_adapter_instantstyle/utils.py
new file mode 100644
index 0000000..6a27335
--- /dev/null
+++ b/ip_adapter_instantstyle/utils.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
# Global registry of the most recent cross-attention maps, keyed by module name.
attn_maps = {}
def hook_fn(name):
    """Return a forward hook that harvests `attn_map` from the module's processor.

    The attention processor is expected to stash its map on itself; the hook
    moves it into the global attn_maps dict and deletes it from the processor
    so each entry reflects only the latest forward pass.
    """
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook
+
def register_cross_attention_hook(unet):
    """Attach attention-map-collecting hooks to every cross-attention module.

    Cross-attention modules are named '...attn2...' in diffusers UNets.
    Returns the same UNet instance.
    """
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet
+
def upscale(attn_map, target_size):
    """Average an attention map over heads and upsample it to target_size.

    Args:
        attn_map: (heads, num_patches, channels) attention tensor.
        target_size: (H, W) of the output image.

    Returns:
        (channels, H, W) float32 map, softmaxed over the channel axis.

    Raises:
        AssertionError: if no downscale factor matches the patch count.
    """
    attn_map = torch.mean(attn_map, dim=0)  # average heads -> (num_patches, channels)
    attn_map = attn_map.permute(1, 0)  # -> (channels, num_patches)
    temp_size = None

    # Recover the latent spatial resolution: find the downscale factor whose
    # pixel count matches the flattened patch count (the factor 64 accounts
    # for the 8x VAE downsampling, squared).
    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Fixed grammar of the failure message ("cannot is" -> "cannot be").
    assert temp_size is not None, "temp_size cannot be None"

    # Unflatten patches back onto the latent grid: (channels, h, w).
    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    # Bilinear upsample to the full image resolution (in float32 for accuracy).
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Normalize competing channels per pixel.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Aggregate all hooked attention maps into one upscaled (C, H, W) map.

    Args:
        image_size: (H, W) target resolution passed to upscale().
        batch_size: classifier-free-guidance batch size used to chunk maps.
        instance_or_negative: True selects chunk 0, False selects chunk 1.
        detach: move maps to CPU before processing.
    """
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    # attn_maps is the module-level registry filled by the forward hooks.
    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        # Split the CFG batch and keep only the requested half.
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    # Average across all hooked attention layers.
    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps
+
def attnmaps2images(net_attn_maps):
    """Convert per-channel attention maps into grayscale PIL images.

    Each map is min-max normalized to [0, 255] independently.
    NOTE(review): a constant map makes max == min and divides by zero here.
    """
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        image = Image.fromarray(normalized_attn_map)

        images.append(image)

    return images
def is_torch2_available():
    """True when this torch build exposes F.scaled_dot_product_attention (torch >= 2.0)."""
    return getattr(F, "scaled_dot_product_attention", None) is not None
+
def get_generator(seed, device):
    """Build seeded torch.Generator(s) on `device`.

    seed may be None (returns None), an int (one generator), or a list of
    ints (one generator per seed).
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(s) for s in seed]
    return torch.Generator(device).manual_seed(seed)
\ No newline at end of file
diff --git a/marble.py b/marble.py
new file mode 100644
index 0000000..097f8c9
--- /dev/null
+++ b/marble.py
@@ -0,0 +1,287 @@
+import os
+from typing import Dict
+
+import numpy as np
+import torch
+from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
+from huggingface_hub import hf_hub_download, list_repo_files
+from PIL import Image, ImageChops, ImageEnhance
+from rembg import new_session, remove
+from transformers import DPTForDepthEstimation, DPTImageProcessor
+
+from ip_adapter_instantstyle import IPAdapterXL
+from ip_adapter_instantstyle.utils import register_cross_attention_hook
+from parametric_control_mlp import control_mlp
+
# Resolve all paths relative to this file so the module works from any CWD.
file_dir = os.path.dirname(os.path.abspath(__file__))
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# Local folders populated by download_ip_adapter().
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

# Cache for rembg sessions
_session_cache = None
# Material attributes with a trained control MLP checkpoint in model_weights/.
CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]
+
+
def get_session():
    """Return the process-wide rembg session, creating it lazily on first use."""
    global _session_cache
    if _session_cache is not None:
        return _session_cache
    _session_cache = new_session()
    return _session_cache
+
+
def setup_control_mlps(
    features: int = 1024, device: str = "cuda", dtype: torch.dtype = torch.float16
) -> Dict[str, torch.nn.Module]:
    """Load one trained control MLP per material parameter in CONTROL_MLPS."""
    return {
        name: setup_control_mlp(name, features, device, dtype)
        for name in CONTROL_MLPS
    }
+
+
def setup_control_mlp(
    material_parameter: str,
    features: int = 1024,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Load the trained control MLP for one material parameter, in eval mode.

    material_parameter selects the checkpoint model_weights/<name>.pt
    (see CONTROL_MLPS for the supported names).
    """
    net = control_mlp(features)
    # NOTE(review): torch.load without weights_only — acceptable for trusted
    # local checkpoints; consider weights_only=True on torch >= 1.13.
    net.load_state_dict(
        torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
    )
    net.to(device, dtype=dtype)
    net.eval()
    return net
+
+
def download_ip_adapter():
    """Download the IP-Adapter model folders from the Hugging Face Hub.

    Skips the download entirely when both target folders already exist and
    are non-empty.
    """
    repo_id = "h94/IP-Adapter"
    # Trailing slashes matter: they are reused for the startswith() filter below.
    target_folders = ["models/", "sdxl_models/"]
    local_dir = file_dir

    # Check if folders exist and contain files
    folders_exist = all(
        os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
    )

    if folders_exist:
        # Check if any of the target folders are empty
        folders_empty = any(
            len(os.listdir(os.path.join(local_dir, folder))) == 0
            for folder in target_folders
        )
        if not folders_empty:
            print("IP-Adapter files already downloaded. Skipping download.")
            return

    # List all files in the repo
    all_files = list_repo_files(repo_id)

    # Filter for files in the desired folders
    filtered_files = [
        f for f in all_files if any(f.startswith(folder) for folder in target_folders)
    ]

    # Download each file
    for file_path in filtered_files:
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm the pinned version accepts it.
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print(f"Downloaded: {file_path} to {local_path}")
+
+
def setup_pipeline(
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """Build the SDXL ControlNet-inpaint pipeline wrapped in an IPAdapterXL.

    Style conditioning targets a single attention block
    (up_blocks.0.attentions.1), InstantStyle-style.
    """
    download_ip_adapter()

    # (block_type, block_index, attention_index) of the targeted layer.
    cur_block = ("up", 0, 1)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
    ).to(device)

    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=dtype,
        add_watermarker=False,
    ).to(device)

    # Capture cross-attention maps during forward passes.
    pipe.unet = register_cross_attention_hook(pipe.unet)

    # e.g. "up_blocks.0.attentions.1"
    block_name = (
        cur_block[0]
        + "_blocks."
        + str(cur_block[1])
        + ".attentions."
        + str(cur_block[2])
    )

    print("Testing block {}".format(block_name))

    return IPAdapterXL(
        pipe,
        os.path.join(file_dir, image_encoder_path),
        os.path.join(file_dir, ip_ckpt),
        device,
        target_blocks=[block_name],
    )
+
+
def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
    """Load Intel's DPT-hybrid MiDaS depth model and its image processor (eval mode)."""
    image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
    model.to(device, dtype=dtype)
    model.eval()
    return model, image_processor
+
+
def run_dpt_depth(
    image: Image.Image, model, processor, device: str = "cuda"
) -> Image.Image:
    """Run DPT depth estimation on an image.

    Returns an 8-bit grayscale depth map resized to 1024x1024 (the working
    resolution of the SDXL pipeline).
    """
    # Prepare image
    inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)

    # Get depth prediction
    with torch.no_grad():
        depth_map = model(**inputs).predicted_depth

    # Now normalize to 0-1 range
    depth_map = (depth_map - depth_map.min()) / (
        depth_map.max() - depth_map.min() + 1e-7
    )
    # clip is a safety no-op after min-max normalization; then scale to 0-255.
    depth_map = depth_map.clip(0, 1) * 255

    # Convert to PIL Image
    depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
    return Image.fromarray(depth_map).resize((1024, 1024))
+
+
def prepare_mask(image: Image.Image) -> Image.Image:
    """Prepare mask from image using rembg.

    Foreground pixels become white (255), background black (0); returned as
    a 1024x1024 RGB image.
    """
    rm_bg = remove(image, session=get_session())
    # Binarize: any non-zero channel value maps to 255.
    target_mask = (
        rm_bg.convert("RGB")
        .point(lambda x: 0 if x < 1 else 255)
        .convert("L")
        .convert("RGB")
    )
    return target_mask.resize((1024, 1024))
+
+
def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Prepare initial image for inpainting.

    The masked (foreground) region is replaced with its grayscale version
    while the background keeps the original colors; result is 1024x1024.
    """

    # Create grayscale version
    gray_image = image.convert("L").convert("RGB")
    # NOTE: enhance(1.0) leaves brightness unchanged — kept as a tuning knob.
    gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)

    # Create mask inversions
    invert_mask = ImageChops.invert(mask)

    # Combine images: gray inside the mask, original outside.
    grayscale_img = ImageChops.darker(gray_image, mask)
    img_black_mask = ImageChops.darker(image, invert_mask)
    init_img = ImageChops.lighter(img_black_mask, grayscale_img)

    return init_img.resize((1024, 1024))
+
+
def run_parametric_control(
    ip_model,
    target_image: Image.Image,
    edit_mlps: dict[torch.nn.Module, float],
    texture_image: Image.Image = None,
    num_inference_steps: int = 30,
    seed: int = 42,
    depth_map: Image.Image = None,
    mask: Image.Image = None,
) -> Image.Image:
    """Run parametric control with metallic and roughness adjustments.

    edit_mlps: mapping of control MLP module -> edit strength, forwarded to
        ip_model.generate_parametric_edits (presumably applied in embedding
        space — confirm against IPAdapterXL.generate_parametric_edits).
    depth_map / mask are computed from target_image when not supplied.
    """
    # Get depth map
    if depth_map is None:
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare mask and init image
    if mask is None:
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))

    # The texture/style reference defaults to the target itself.
    if texture_image is None:
        texture_image = target_image

    init_img = prepare_init_image(target_image, mask)

    # Generate edit
    images = ip_model.generate_parametric_edits(
        texture_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_mlps=edit_mlps,
        strength=1.0,
    )

    return images[0]
+
+
def run_blend(
    ip_model,
    target_image: Image.Image,
    texture_image1: Image.Image,
    texture_image2: Image.Image,
    edit_strength: float = 0.0,
    num_inference_steps: int = 20,
    seed: int = 1,
    depth_map: Image.Image = None,
    mask: Image.Image = None,
) -> Image.Image:
    """Run blending between two texture images.

    edit_strength is passed through to ip_model.generate_edit (presumably the
    blend amount between the two textures — confirm against
    IPAdapterXL.get_image_edit_dir). start_image is set to texture_image1.
    depth_map / mask are computed from target_image when not supplied.
    """
    # Get depth map
    if depth_map is None:
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare mask and init image
    if mask is None:
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))
    init_img = prepare_init_image(target_image, mask)

    # Generate edit
    # NOTE(review): clip_strength is not a named parameter of generate_edit;
    # it falls into **kwargs and is forwarded to the diffusers pipeline call —
    # confirm the pipeline accepts it (it may raise TypeError otherwise).
    images = ip_model.generate_edit(
        start_image=texture_image1,
        pil_image=texture_image1,
        pil_image2=texture_image2,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_strength=edit_strength,
        clip_strength=1.0,
        strength=1.0,
    )

    return images[0]
diff --git a/model_weights/glow.pt b/model_weights/glow.pt
new file mode 100644
index 0000000..3b7fc51
Binary files /dev/null and b/model_weights/glow.pt differ
diff --git a/model_weights/metallic.pt b/model_weights/metallic.pt
new file mode 100644
index 0000000..e5cb99f
Binary files /dev/null and b/model_weights/metallic.pt differ
diff --git a/model_weights/roughness.pt b/model_weights/roughness.pt
new file mode 100644
index 0000000..9041586
Binary files /dev/null and b/model_weights/roughness.pt differ
diff --git a/model_weights/transparency.pt b/model_weights/transparency.pt
new file mode 100644
index 0000000..dea3c3b
Binary files /dev/null and b/model_weights/transparency.pt differ
diff --git a/parametric_control.ipynb b/parametric_control.ipynb
new file mode 100644
index 0000000..25a3f81
--- /dev/null
+++ b/parametric_control.ipynb
@@ -0,0 +1,303 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b1976c7-5491-4124-94c1-dd99b5fcd016",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel\n",
+ "from rembg import remove\n",
+ "from PIL import Image, ImageFilter\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from ip_adapter_instantstyle import IPAdapterXL\n",
+ "from ip_adapter_instantstyle.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images\n",
+ "from PIL import Image, ImageChops\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import os\n",
+ "\n",
+ "\n",
+ "\"\"\"Import DPT for Depth Model\"\"\"\n",
+ "import DPT.util.io\n",
+ "\n",
+ "from torchvision.transforms import Compose\n",
+ "\n",
+ "from DPT.dpt.models import DPTDepthModel\n",
+ "from DPT.dpt.midas_net import MidasNet_large\n",
+ "from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet\n",
+ "\n",
+ "from parametric_control_mlp import control_mlp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2306a004-ac0e-4714-9868-ae77b699fef1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Metallic MLP\n",
+ "mlp = control_mlp(1024)\n",
+ "mlp.load_state_dict(torch.load('model_weights/metallic.pt'))\n",
+ "mlp = mlp.to(\"cuda\", dtype=torch.float16)\n",
+ "mlp.eval()\n",
+ "\n",
+ "# Roughness MLP\n",
+ "mlp2 = control_mlp(1024)\n",
+ "mlp2.load_state_dict(torch.load('model_weights/roughness.pt'))\n",
+ "mlp2 = mlp2.to(\"cuda\", dtype=torch.float16)\n",
+ "mlp2.eval()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f14b02bd-42af-4139-8c8e-6a5366be7733",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"Get MARBLE Model ready\"\"\"\n",
+ "base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
+ "image_encoder_path = \"models/image_encoder\"\n",
+ "ip_ckpt = \"sdxl_models/ip-adapter_sdxl_vit-h.bin\"\n",
+ "controlnet_path = \"diffusers/controlnet-depth-sdxl-1.0\"\n",
+ "device = \"cuda\"\n",
+ "\n",
+ "\"\"\"Load IP-Adapter + Instant Style + Editing MLP\"\"\"\n",
+ "cur_block = ('up', 0, 1)\n",
+ "torch.cuda.empty_cache()\n",
+ "\n",
+ "controlnet = ControlNetModel.from_pretrained(controlnet_path, variant=\"fp16\", use_safetensors=True, torch_dtype=torch.float16).to(device)\n",
+ "pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(\n",
+ " base_model_path,\n",
+ " controlnet=controlnet,\n",
+ " use_safetensors=True,\n",
+ " torch_dtype=torch.float16,\n",
+ " add_watermarker=False,\n",
+ ").to(device)\n",
+ "\n",
+ "pipe.unet = register_cross_attention_hook(pipe.unet)\n",
+ "block_name = cur_block[0] + \"_blocks.\" + str(cur_block[1])+ \".attentions.\" + str(cur_block[2])\n",
+ "print(\"Testing block {}\".format(block_name))\n",
+ "ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=[block_name], edit_mlp=mlp, edit_mlp2=mlp2)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a744532-efde-4949-bffc-9b8d73c88aa5",
+ "metadata": {
+ "collapsed": true,
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "\"\"\"\n",
+ "Get Depth Model Ready\n",
+ "\"\"\"\n",
+ "import cv2\n",
+ "model_path = \"DPT/dpt_weights/dpt_hybrid-midas-501f0c75.pt\"\n",
+ "net_w = net_h = 384\n",
+ "model = DPTDepthModel(\n",
+ " path=model_path,\n",
+ " backbone=\"vitb_rn50_384\",\n",
+ " non_negative=True,\n",
+ " enable_attention_hooks=False,\n",
+ ")\n",
+ "normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
+ "\n",
+ "transform = Compose(\n",
+ " [\n",
+ " Resize(\n",
+ " net_w,\n",
+ " net_h,\n",
+ " resize_target=None,\n",
+ " keep_aspect_ratio=True,\n",
+ " ensure_multiple_of=32,\n",
+ " resize_method=\"minimal\",\n",
+ " image_interpolation_method=cv2.INTER_CUBIC,\n",
+ " ),\n",
+ " normalization,\n",
+ " PrepareForNet(),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "94851e56-fe11-4aa4-8574-e12d0433508d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Edit strengths for metallic. More negative = more metallic; best results in the range -20 to 20\n",
+ "edit_strengths1 = [-20, 0, 20]\n",
+ "\n",
+ "# Edit strengths for roughness. More positive = more roughness; best results in the range -1 to 1\n",
+ "edit_strengths2 = [-1, 0, 1]\n",
+ "\n",
+ "\n",
+ "all_images = []\n",
+ "for edit_strength1 in edit_strengths1:\n",
+ " for edit_strength2 in edit_strengths2:\n",
+ " \n",
+ " target_image_path = 'input_images/context_image/toy_car.png'\n",
+ " target_image = Image.open(target_image_path).convert('RGB')\n",
+ " \n",
+ " \"\"\"\n",
+ " Compute depth map from input_image\n",
+ " \"\"\"\n",
+ "\n",
+ " img = np.array(target_image)\n",
+ "\n",
+ " img_input = transform({\"image\": img})[\"image\"]\n",
+ "\n",
+ " # compute\n",
+ " with torch.no_grad():\n",
+ " sample = torch.from_numpy(img_input).unsqueeze(0)\n",
+ "\n",
+ " # if optimize == True and device == torch.device(\"cuda\"):\n",
+ " # sample = sample.to(memory_format=torch.channels_last)\n",
+ " # sample = sample.half()\n",
+ "\n",
+ " prediction = model.forward(sample)\n",
+ " prediction = (\n",
+ " torch.nn.functional.interpolate(\n",
+ " prediction.unsqueeze(1),\n",
+ " size=img.shape[:2],\n",
+ " mode=\"bicubic\",\n",
+ " align_corners=False,\n",
+ " )\n",
+ " .squeeze()\n",
+ " .cpu()\n",
+ " .numpy()\n",
+ " )\n",
+ "\n",
+ " depth_min = prediction.min()\n",
+ " depth_max = prediction.max()\n",
+ " bits = 2\n",
+ " max_val = (2 ** (8 * bits)) - 1\n",
+ "\n",
+ " if depth_max - depth_min > np.finfo(\"float\").eps:\n",
+ " out = max_val * (prediction - depth_min) / (depth_max - depth_min)\n",
+ " else:\n",
+ "            out = np.zeros(prediction.shape, dtype=prediction.dtype)\n",
+ "\n",
+ " out = (out / 256).astype('uint8')\n",
+ " depth_map = Image.fromarray(out).resize((1024, 1024))\n",
+ " \n",
+ " \n",
+ " \"\"\"Preprocessing data for MARBLE\"\"\"\n",
+ " rm_bg = remove(target_image)\n",
+ " target_mask = rm_bg.convert(\"RGB\").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')# Convert mask to grayscale\n",
+ "\n",
+ " noise = np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)\n",
+ " noise_image = Image.fromarray(noise)\n",
+ " mask_target_img = ImageChops.lighter(target_image, target_mask)\n",
+ " invert_target_mask = ImageChops.invert(target_mask)\n",
+ "\n",
+ " from PIL import ImageEnhance\n",
+ " gray_target_image = target_image.convert('L').convert('RGB')\n",
+ " gray_target_image = ImageEnhance.Brightness(gray_target_image)\n",
+ "\n",
+ " # Adjust brightness\n",
+ " # The factor 1.0 means original brightness, greater than 1.0 makes the image brighter. Adjust this if the image is too dim\n",
+ " factor = 1.0 # Try adjusting this to get the desired brightness\n",
+ "\n",
+ " gray_target_image = gray_target_image.enhance(factor)\n",
+ " grayscale_img = ImageChops.darker(gray_target_image, target_mask)\n",
+ " img_black_mask = ImageChops.darker(target_image, invert_target_mask)\n",
+ " grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)\n",
+ " init_img = grayscale_init_img\n",
+ " \n",
+ " # The texture to be applied onto car\n",
+ " ip_image = Image.open('input_images/texture/metal_bowl.png')\n",
+ "\n",
+ "\n",
+ " init_img = target_image\n",
+ " init_img = init_img.resize((1024,1024))\n",
+ " mask = target_mask.resize((1024, 1024))\n",
+ "\n",
+ "\n",
+ " cur_seed = 42\n",
+ " images = ip_model.generate_edit_mlp_lr_multi(pil_image = ip_image, image=init_img, control_image=depth_map, \\\n",
+ " mask_image=mask, controlnet_conditioning_scale=1., num_samples=1, \\\n",
+ " num_inference_steps=30, seed=cur_seed, edit_strength=edit_strength1, \\\n",
+ " edit_strength2=edit_strength2, strength=1)\n",
+ " all_images.append(images[0].resize((512,512)))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d61d15b6-efe6-49fb-a33c-589de9781a08",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "def show_image_grid(images, x, y, figsize=(10, 10)):\n",
+ " \"\"\"\n",
+ " Display a list of images in an x by y grid.\n",
+ "\n",
+ " Args:\n",
+ " images (list of np.array): List of images (e.g., numpy arrays).\n",
+ " x (int): Number of columns.\n",
+ " y (int): Number of rows.\n",
+ " figsize (tuple): Size of the figure.\n",
+ " \"\"\"\n",
+ " fig, axes = plt.subplots(y, x, figsize=figsize)\n",
+ " axes = axes.flatten()\n",
+ "\n",
+ " for i in range(x * y):\n",
+ " ax = axes[i]\n",
+ " if i < len(images):\n",
+ " ax.imshow(images[i])\n",
+ " ax.axis('off')\n",
+ " else:\n",
+ " ax.axis('off') # Hide unused subplots\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()\n",
+ "show_image_grid(all_images, len(edit_strengths1), len(edit_strengths2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10b21813-7d4c-4f65-964b-c01ea1b21a38",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.17"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/parametric_control_mlp.py b/parametric_control_mlp.py
new file mode 100644
index 0000000..1f31b05
--- /dev/null
+++ b/parametric_control_mlp.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+
+class control_mlp(nn.Module):
+    """MLP that turns a feature embedding plus a scalar edit strength into a
+    1024-dim edit output.
+
+    The embedding is expanded to 2048 features (interpreted as two 1024-dim
+    halves), while the scalar edit strength is mapped to two per-sample
+    coefficients; the output is the coefficient-weighted sum of the halves.
+    """
+
+    def __init__(self, embedding_size):
+        super(control_mlp, self).__init__()
+
+        # Feature branch: embedding -> 1024 -> 2048 (two 1024-dim halves).
+        self.fc1 = nn.Linear(embedding_size, 1024)
+        self.fc2 = nn.Linear(1024, 2048)
+        self.relu = nn.ReLU()
+
+        # Strength branch: scalar edit strength -> two mixing coefficients.
+        self.edit_strength_fc1 = nn.Linear(1, 128)
+        self.edit_strength_fc2 = nn.Linear(128, 2)
+
+    def forward(self, x, edit_strength):
+        # Assumes x is (batch, embedding_size) and edit_strength is a (batch,)
+        # per-sample scalar (unsqueezed to (batch, 1) below) — TODO confirm
+        # shapes against the IP-Adapter caller.
+        x = self.relu(self.fc1(x))
+        x = self.fc2(x)
+
+        # Predict the two mixing coefficients from the scalar strength.
+        edit_strength = self.relu(self.edit_strength_fc1(edit_strength.unsqueeze(1)))
+        edit_strength = self.edit_strength_fc2(edit_strength)
+
+        edit_strength1, edit_strength2 = edit_strength[:, 0], edit_strength[:, 1]
+
+        # Coefficient-weighted sum of the two 1024-dim halves of x.
+        output = (
+            edit_strength1.unsqueeze(1) * x[:, :1024]
+            + edit_strength2.unsqueeze(1) * x[:, 1024:]
+        )
+
+        return output
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5bc438a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+diffusers==0.30
+rembg[gpu]
+einops==0.7.0
+transformers==4.27.4
+opencv-python==4.8.1.78
+gradio==5.29.0
+accelerate==0.26.1
+timm==0.6.12
+torch==2.3.0
+torchvision==0.18.0
+huggingface_hub==0.30.2
diff --git a/static/teaser.jpg b/static/teaser.jpg
new file mode 100644
index 0000000..4ea8ed6
Binary files /dev/null and b/static/teaser.jpg differ
diff --git a/try_blend.ipynb b/try_blend.ipynb
new file mode 100644
index 0000000..ded3f80
--- /dev/null
+++ b/try_blend.ipynb
@@ -0,0 +1,58 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from marble import run_blend, setup_pipeline\n",
+ "from PIL import Image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ip_adapter = setup_pipeline()\n",
+ "target_image = Image.open(\"input_images/context_image/beetle.png\")\n",
+ "ip_image = Image.open(\"input_images/texture/low_roughness.png\")\n",
+ "ip_image2 = Image.open(\"input_images/texture/high_roughness.png\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for edit_strength in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:\n",
+ " blended_image = run_blend(ip_adapter, target_image, ip_image, ip_image2, edit_strength=edit_strength)\n",
+ " blended_image.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}