diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index 672a7f22068..fa24a985c7f 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
- default: "121"
+ default: "124"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
- default: "2"
+ default: "3"
# push:
# branches:
# - master
@@ -49,7 +49,7 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
- python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+ python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
diff --git a/README.md b/README.md
index ba1e844b3f9..de0c062ae56 100644
--- a/README.md
+++ b/README.md
@@ -41,29 +41,32 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Shortcuts
-| Keybind | Explanation |
-|---------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Delete/Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
-| Double-Click LMB | Open node quick search palette |
+| Keybind | Explanation |
+|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
+| Ctrl + Enter | Queue up current graph for generation |
+| Ctrl + Shift + Enter | Queue up current graph as first for generation |
+| Ctrl + Z/Ctrl + Y | Undo/Redo |
+| Ctrl + S | Save workflow |
+| Ctrl + O | Load workflow |
+| Ctrl + A | Select all nodes |
+| Alt + C | Collapse/uncollapse selected nodes |
+| Ctrl + M | Mute/unmute selected nodes |
+| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| Delete/Backspace | Delete selected nodes |
+| Ctrl + Delete/Backspace            | Delete the current graph                                                                                             |
+| Space | Move the canvas around when held and moving the cursor |
+| Ctrl/Shift + Click | Add clicked node to selection |
+| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| Shift + Drag | Move multiple selected nodes at the same time |
+| Ctrl + D | Load default graph |
+| Alt + `+` | Canvas Zoom in |
+| Alt + `-` | Canvas Zoom out |
+| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
+| Q | Toggle visibility of the queue |
+| H | Toggle visibility of history |
+| R | Refresh graph |
+| Double-Click LMB | Open node quick search palette |
Ctrl can also be replaced with Cmd for macOS users
@@ -99,11 +102,11 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have them already installed. This is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
-This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.1, which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
### NVIDIA
@@ -113,7 +116,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
#### Troubleshooting
@@ -133,7 +136,16 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
-#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
+#### Intel GPUs
+
+Intel GPU support is available for all Intel GPUs supported by Intel's Extension for PyTorch (IPEX), with the support requirements listed on the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and install method and follow the instructions. The steps are as follows:
+
+1. Start by installing the drivers or kernel listed on the IPEX Installation page linked above (or newer) for Windows or Linux, if needed.
+1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
+1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
+1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
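+
+Once IPEX is installed, you can optionally sanity-check that PyTorch can see your Intel GPU with a short snippet (a minimal check using the `torch.xpu` API that IPEX registers):
+
+```
+import torch
+import intel_extension_for_pytorch as ipex  # importing IPEX registers the torch.xpu backend
+
+print("XPU available:", torch.xpu.is_available())
+```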
+
+Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
#### Apple Mac silicon
@@ -195,20 +207,20 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt```
-## How to increase generation speed?
-
-Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
-
-You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
-
-```--dont-upcast-attention```
-
## How to show high-quality previews?
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+## How to use TLS/SSL?
+Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
+
+Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL; the app will then be accessible with `https://...` instead of `http://...`.
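+
+For example, assuming `key.pem` and `cert.pem` are in the directory you start ComfyUI from, you could then launch it with:
+
+```python main.py --tls-keyfile key.pem --tls-certfile cert.pem```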
+
+> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
+
If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
+
## Support and dev channel
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index 5eee5a51c95..28076dd9251 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -52,6 +52,7 @@ def __init__(
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
+ attn_precision=None,
device=None,
operations=comfy.ops.disable_weight_init,
**kwargs,
@@ -202,7 +203,7 @@ def __init__(
SpatialTransformer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -262,7 +263,7 @@ def __init__(
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index d117989f6bc..7def320191a 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -35,6 +35,8 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
+parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
@@ -50,7 +52,6 @@ def __call__(self, parser, namespace, values, option_string=None):
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
-parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
@@ -96,6 +97,11 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+upcast = parser.add_mutually_exclusive_group()
+upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
+upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+
+
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 7af016829d3..f9b281894d4 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -527,6 +527,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
+ if len(sigmas) <= 1:
+ return x
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
@@ -595,6 +598,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
+ if len(sigmas) <= 1:
+ return x
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
@@ -642,6 +647,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
+ if len(sigmas) <= 1:
+ return x
+
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
@@ -690,18 +698,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ if len(sigmas) <= 1:
+ return x
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
+ if len(sigmas) <= 1:
+ return x
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
+ if len(sigmas) <= 1:
+ return x
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py
index 6d2c2223143..7c3d8feabd8 100644
--- a/comfy/ldm/cascade/stage_b.py
+++ b/comfy/ldm/cascade/stage_b.py
@@ -17,7 +17,6 @@
"""
import math
-import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
index 67c1e52b635..c85da1f01c1 100644
--- a/comfy/ldm/cascade/stage_c.py
+++ b/comfy/ldm/cascade/stage_c.py
@@ -18,7 +18,6 @@
import torch
from torch import nn
-import numpy as np
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer
diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py
index b91ec3249fb..f5f4de28830 100644
--- a/comfy/ldm/models/autoencoder.py
+++ b/comfy/ldm/models/autoencoder.py
@@ -1,6 +1,4 @@
import torch
-# import pytorch_lightning as pl
-import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index bf9c65ff0b8..0e39a71456f 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -3,10 +3,10 @@
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
-from typing import Optional, Any
+from typing import Optional
import logging
-from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
+from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
@@ -19,13 +19,14 @@
import comfy.ops
ops = comfy.ops.disable_weight_init
-# CrossAttn precision handling
-if args.dont_upcast_attention:
- logging.info("disabling upcasting of attention")
- _ATTN_PRECISION = "fp16"
-else:
- _ATTN_PRECISION = "fp32"
+FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
+def get_attn_precision(attn_precision):
+ if args.dont_upcast_attention:
+ return None
+ if FORCE_UPCAST_ATTENTION_DTYPE is not None:
+ return FORCE_UPCAST_ATTENTION_DTYPE
+ return attn_precision
def exists(val):
return val is not None
@@ -85,7 +86,9 @@ def forward(self, x):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
-def attention_basic(q, k, v, heads, mask=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+ attn_precision = get_attn_precision(attn_precision)
+
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@@ -101,7 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
)
# force cast to fp32 to avoid overflowing
- if _ATTN_PRECISION =="fp32":
+ if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
@@ -135,7 +138,9 @@ def attention_basic(q, k, v, heads, mask=None):
return out
-def attention_sub_quad(query, key, value, heads, mask=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
+ attn_precision = get_attn_precision(attn_precision)
+
b, _, dim_head = query.shape
dim_head //= heads
@@ -146,7 +151,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
- upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+ upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
@@ -195,7 +200,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
-def attention_split(q, k, v, heads, mask=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None):
+ attn_precision = get_attn_precision(attn_precision)
+
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@@ -214,10 +221,12 @@ def attention_split(q, k, v, heads, mask=None):
mem_free_total = model_management.get_free_memory(q.device)
- if _ATTN_PRECISION =="fp32":
+ if attn_precision == torch.float32:
element_size = 4
+ upcast = True
else:
element_size = q.element_size()
+ upcast = False
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
@@ -256,7 +265,7 @@ def attention_split(q, k, v, heads, mask=None):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
- if _ATTN_PRECISION =="fp32":
+ if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
@@ -304,24 +313,30 @@ def attention_split(q, k, v, heads, mask=None):
BROKEN_XFORMERS = False
try:
x_vers = xformers.__version__
- #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
- BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
+ # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
+ BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except:
pass
-def attention_xformers(q, k, v, heads, mask=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
+
+ disabled_xformers = False
+
if BROKEN_XFORMERS:
if b * heads > 65535:
- return attention_pytorch(q, k, v, heads, mask)
+ disabled_xformers = True
+
+ if not disabled_xformers:
+ if torch.jit.is_tracing() or torch.jit.is_scripting():
+ disabled_xformers = True
+
+ if disabled_xformers:
+ return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
- lambda t: t.unsqueeze(3)
- .reshape(b, -1, heads, dim_head)
- .permute(0, 2, 1, 3)
- .reshape(b * heads, -1, dim_head)
- .contiguous(),
+ lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
@@ -334,14 +349,11 @@ def attention_xformers(q, k, v, heads, mask=None):
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
- out.unsqueeze(0)
- .reshape(b, heads, -1, dim_head)
- .permute(0, 2, 1, 3)
- .reshape(b, -1, heads * dim_head)
+ out.reshape(b, -1, heads * dim_head)
)
return out
-def attention_pytorch(q, k, v, heads, mask=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
@@ -391,10 +403,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
class CrossAttention(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
+ self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
@@ -416,15 +429,15 @@ def forward(self, x, context=None, value=None, mask=None):
v = self.to_v(context)
if mask is None:
- out = optimized_attention(q, k, v, self.heads)
+ out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
- out = optimized_attention_masked(q, k, v, self.heads, mask)
+ out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
- disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
+ disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
@@ -432,6 +445,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
inner_dim = dim
self.is_res = inner_dim == dim
+ self.attn_precision = attn_precision
if self.ff_in:
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
@@ -439,7 +453,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
- context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
+ context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
@@ -453,20 +467,16 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
- heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
+ heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
- self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
- return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-
- def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = transformer_options.get("block", None)
block_index = transformer_options.get("block_index", 0)
@@ -483,6 +493,7 @@ def _forward(self, x, context=None, transformer_options={}):
extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head
+ extra_options["attn_precision"] = self.attn_precision
if self.ff_in:
x_skip = x
@@ -593,7 +604,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
- use_checkpoint=True, dtype=None, device=None, operations=ops):
+ use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
@@ -611,7 +622,7 @@ def __init__(self, in_channels, n_heads, d_head,
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
)
if not use_linear:
@@ -632,7 +643,7 @@ def forward(self, x, context=None, transformer_options={}):
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ x = x.movedim(1, 3).flatten(1, 2).contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
@@ -640,7 +651,7 @@ def forward(self, x, context=None, transformer_options={}):
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
@@ -667,6 +678,7 @@ def __init__(
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
+ attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
@@ -679,6 +691,7 @@ def __init__(
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
+ attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
@@ -708,6 +721,7 @@ def __init__(
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
+ attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index d24ce34ae45..22849c1613a 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import numpy as np
-from einops import rearrange
from typing import Optional, Any
import logging
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index d782eff31d9..ba8fc2c4a06 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -258,7 +258,7 @@ def _forward(self, x, emb):
else:
if emb_out is not None:
if self.exchange_temb_dims:
- emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ emb_out = emb_out.movedim(1, 2)
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
@@ -431,6 +431,7 @@ def __init__(
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
+ attn_precision=None,
device=None,
operations=ops,
):
@@ -550,13 +551,14 @@ def get_attention_layer(
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
+ attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8c89adf5e53..841598b7327 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -162,7 +162,7 @@ def extra_conds(self, **kwargs):
c_concat = kwargs.get("noise_concat", None)
if c_concat is not None:
- out['c_concat'] = comfy.conds.CONDNoiseShape(data)
+ out['c_concat'] = comfy.conds.CONDNoiseShape(c_concat)
return out
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a02bafb6b8a..eb2eab2ca10 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -2,9 +2,9 @@
import logging
from enum import Enum
from comfy.cli_args import args
-import comfy.utils
import torch
import sys
+import platform
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -97,7 +97,7 @@ def get_torch_device():
return torch.device("cpu")
else:
if is_intel_xpu():
- return torch.device("xpu")
+ return torch.device("xpu", torch.xpu.current_device())
else:
return torch.device(torch.cuda.current_device())
@@ -116,8 +116,8 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
- mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
+ mem_total = torch.xpu.get_device_properties(dev).total_memory
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -133,10 +133,11 @@ def get_total_memory(dev=None, torch_total_too=False):
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
-if not args.normalvram and not args.cpu:
- if lowvram_available and total_vram <= 4096:
- logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
- set_vram_to = VRAMState.LOW_VRAM
+
+try:
+ logging.info("pytorch version: {}".format(torch.version.__version__))
+except:
+ pass
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
@@ -299,7 +300,7 @@ def model_memory_required(self, device):
else:
return self.model_memory()
- def model_load(self, lowvram_model_memory=0):
+ def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
@@ -309,7 +310,7 @@ def model_load(self, lowvram_model_memory=0):
try:
if lowvram_model_memory > 0 and load_weights:
- self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+ self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
@@ -318,11 +319,16 @@ def model_load(self, lowvram_model_memory=0):
raise e
if is_intel_xpu() and not args.disable_ipex_optimize:
- self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
+ self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
+ def should_reload_model(self, force_patch_weights=False):
+ if force_patch_weights and self.model.lowvram_patch_counter > 0:
+ return True
+ return False
+
def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
@@ -393,7 +399,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
-def load_models_gpu(models, memory_required=0):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
inference_memory = minimum_inference_memory()
@@ -405,12 +411,22 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
+ loaded = None
- if loaded_model in current_loaded_models:
- index = current_loaded_models.index(loaded_model)
- current_loaded_models.insert(0, current_loaded_models.pop(index))
- models_already_loaded.append(loaded_model)
- else:
+ try:
+ loaded_model_index = current_loaded_models.index(loaded_model)
+ except:
+ loaded_model_index = None
+
+ if loaded_model_index is not None:
+ loaded = current_loaded_models[loaded_model_index]
+ if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
+ current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+ loaded = None
+ else:
+ models_already_loaded.append(loaded)
+
+ if loaded is None:
if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
@@ -450,15 +466,13 @@ def load_models_gpu(models, memory_required=0):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
- if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
- vram_set_state = VRAMState.LOW_VRAM
- else:
+ if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
- cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+ cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
@@ -566,8 +580,6 @@ def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
- if is_intel_xpu():
- return torch.device("cpu")
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
@@ -631,8 +643,18 @@ def supports_dtype(device, dtype): #TODO
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
+ if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
+ return False
+ if directml_enabled:
+ return False
+ return True
+
+def device_should_use_non_blocking(device):
+ if not device_supports_non_blocking(device):
+ return False
return False
- # return True #TODO: figure out why this causes issues
+ # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
+
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
@@ -644,7 +666,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
- non_blocking = device_supports_non_blocking(device)
+ non_blocking = device_should_use_non_blocking(device)
if device_supports_cast:
if copy:
@@ -687,6 +709,18 @@ def pytorch_attention_flash_attention():
return True
return False
+def force_upcast_attention_dtype():
+ upcast = args.force_upcast_attention
+ try:
+ if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+ upcast = True
+ except:
+ pass
+ if upcast:
+ return torch.float32
+ else:
+ return None
+
def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
if dev is None:
@@ -702,10 +736,10 @@ def get_free_memory(dev=None, torch_free_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
- mem_allocated = stats['allocated_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
- mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
+ mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
+ mem_free_total = mem_free_xpu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@@ -859,6 +893,7 @@ def unload_all_models():
def resolve_lowvram_weight(weight, model, key): #TODO: remove
+ print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
return weight
#TODO: might be cleaner to put this somewhere else
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index cf51c4ad86f..2e746d8a9e1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -6,17 +6,29 @@
import comfy.utils
import comfy.model_management
+from comfy.types import UnetWrapperFunction
-def apply_weight_decompose(dora_scale, weight):
+
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
+ dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
+ lora_diff *= alpha
+ weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
- weight.transpose(0, 1)
- .reshape(weight.shape[1], -1)
+ weight_calc.transpose(0, 1)
+ .reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
- .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+ .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
- return weight * (dora_scale / weight_norm)
+ weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+ if strength != 1.0:
+ weight_calc -= weight
+ weight += strength * (weight_calc)
+ else:
+ weight[:] = weight_calc
+ return weight
+
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -58,6 +70,7 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
+ self.lowvram_patch_counter = 0
self.patches_uuid = uuid.uuid4()
def model_size(self):
@@ -116,7 +129,7 @@ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_op
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
- def set_model_unet_function_wrapper(self, unet_wrapper_function):
+ def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_denoise_mask_function(self, denoise_mask_function):
@@ -272,7 +285,7 @@ def patch_model(self, device_to=None, patch_weights=True):
return self.model
- def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
@@ -284,6 +297,7 @@ def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
+ patch_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
@@ -296,9 +310,17 @@ def __call__(self, weight):
if lowvram_weight:
if weight_key in self.patches:
- m.weight_function = LowVramPatch(weight_key, self)
+ if force_patch_weights:
+ self.patch_weight_to_device(weight_key)
+ else:
+ m.weight_function = LowVramPatch(weight_key, self)
+ patch_counter += 1
if bias_key in self.patches:
- m.bias_function = LowVramPatch(bias_key, self)
+ if force_patch_weights:
+ self.patch_weight_to_device(bias_key)
+ else:
+ m.bias_function = LowVramPatch(bias_key, self)
+ patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
@@ -311,11 +333,12 @@ def __call__(self, weight):
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
+ self.lowvram_patch_counter = patch_counter
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
- alpha = p[0]
+ strength = p[0]
v = p[1]
strength_model = p[2]
@@ -333,26 +356,31 @@ def calculate_weight(self, patches, weight, key):
if patch_type == "diff":
w1 = v[0]
- if alpha != 0.0:
+ if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
- weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
+ weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
- alpha *= v[2] / mat2.shape[0]
+ alpha = v[2] / mat2.shape[0]
+ else:
+ alpha = 1.0
+
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
- weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+ lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
- weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+ else:
+ weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
@@ -389,19 +417,26 @@ def calculate_weight(self, patches, weight, key):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
- alpha *= v[2] / dim
+ alpha = v[2] / dim
+ else:
+ alpha = 1.0
try:
- weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+ lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
- weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+ else:
+ weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
- alpha *= v[2] / w1b.shape[0]
+ alpha = v[2] / w1b.shape[0]
+ else:
+ alpha = 1.0
+
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
@@ -424,14 +459,18 @@ def calculate_weight(self, patches, weight, key):
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
- weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+ lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
- weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+ else:
+ weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
- alpha *= v[4] / v[0].shape[0]
+ alpha = v[4] / v[0].shape[0]
+ else:
+ alpha = 1.0
dora_scale = v[5]
@@ -441,9 +480,11 @@ def calculate_weight(self, patches, weight, key):
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
- weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+ lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
- weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+ else:
+ weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
@@ -462,6 +503,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
m.bias_function = None
self.model_lowvram = False
+ self.lowvram_patch_counter = 0
keys = list(self.backup.keys())
diff --git a/comfy/ops.py b/comfy/ops.py
index eb6507682d1..7ebb3dd2fe9 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,7 +21,7 @@
def cast_bias_weight(s, input):
bias = None
- non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+ non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 415a35cc3c5..29962a916b6 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -34,7 +34,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
- mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
+ mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
@@ -539,6 +539,9 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if sampler_name == "dpm_fast":
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
+ if len(sigmas) <= 1:
+ return noise
+
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
@@ -547,6 +550,9 @@ def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
+ if len(sigmas) <= 1:
+ return noise
+
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
diff --git a/comfy/sd.py b/comfy/sd.py
index 8c471dd329d..7dea3f6bfe2 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -14,7 +14,6 @@
from . import clip_vision
from . import gligen
from . import diffusers_convert
-from . import model_base
from . import model_detection
from . import sd1_clip
@@ -210,16 +209,26 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
- else:
+ elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
- if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+ if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ if 'quant_conv.weight' in sd:
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
+ else:
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+ encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+ else:
+ logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
+ self.first_stage_model = None
+ return
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
@@ -438,6 +447,8 @@ def load_gligen(ckpt_path):
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+ logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
+ model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
#TODO: this function is a mess and should be removed eventually
if config is None:
with open(config_path, 'r') as stream:
@@ -445,81 +456,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
- vae_config = model_config_params['first_stage_config']
-
- fp16 = False
- if "unet_config" in model_config_params:
- if "params" in model_config_params["unet_config"]:
- unet_config = model_config_params["unet_config"]["params"]
- if "use_fp16" in unet_config:
- fp16 = unet_config.pop("use_fp16")
- if fp16:
- unet_config["dtype"] = torch.float16
-
- noise_aug_config = None
- if "noise_aug_config" in model_config_params:
- noise_aug_config = model_config_params["noise_aug_config"]
-
- model_type = model_base.ModelType.EPS
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
- model_type = model_base.ModelType.V_PREDICTION
-
- clip = None
- vae = None
-
- class WeightsLoader(torch.nn.Module):
- pass
-
- if state_dict is None:
- state_dict = comfy.utils.load_torch_file(ckpt_path)
-
- class EmptyClass:
- pass
-
- model_config = comfy.supported_models_base.BASE({})
-
- from . import latent_formats
- model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
- model_config.unet_config = model_detection.convert_config(unet_config)
+ m = model.clone()
+ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
+ pass
+ m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
+ model = m
- if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
- model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
- else:
- model = model_base.BaseModel(model_config, model_type=model_type)
-
- if config['model']["target"].endswith("LatentInpaintDiffusion"):
- model.set_inpaint()
-
- if fp16:
- model = model.half()
-
- offload_device = model_management.unet_offload_device()
- model = model.to(offload_device)
- model.load_model_weights(state_dict, "model.diffusion_model.")
-
- if output_vae:
- vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
- vae = VAE(sd=vae_sd, config=vae_config)
-
- if output_clip:
- w = WeightsLoader()
- clip_target = EmptyClass()
- clip_target.params = clip_config.get("params", {})
- if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
- clip_target.clip = sd2_clip.SD2ClipModel
- clip_target.tokenizer = sd2_clip.SD2Tokenizer
- clip = CLIP(clip_target, embedding_directory=embedding_directory)
- w.cond_stage_model = clip.cond_stage_model.clip_h
- elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
- clip_target.clip = sd1_clip.SD1ClipModel
- clip_target.tokenizer = sd1_clip.SD1Tokenizer
- clip = CLIP(clip_target, embedding_directory=embedding_directory)
- w.cond_stage_model = clip.cond_stage_model.clip_l
- load_clip_weights(w, state_dict)
+ layer_idx = clip_config.get("params", {}).get("layer_idx", None)
+ if layer_idx is not None:
+ clip.clip_layer(layer_idx)
- return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+ return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = comfy.utils.load_torch_file(ckpt_path)
@@ -565,7 +515,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip = CLIP(clip_target, embedding_directory=embedding_directory)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
- logging.warning("clip missing: {}".format(m))
+ m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+ if len(m_filter) > 0:
+ logging.warning("clip missing: {}".format(m))
+ else:
+ logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
@@ -637,7 +591,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
- model_management.load_models_gpu(load_models)
+ model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:
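# A minimal sketch (not part of the diff) of the object-patch pattern the reworked
# load_checkpoint above relies on for v-prediction configs: clone the ModelPatcher and
# swap its "model_sampling" object instead of mutating the shared base model. The
# function name and the "patcher" argument are hypothetical.
import comfy.model_sampling

def force_v_prediction(patcher):
    m = patcher.clone()

    class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
        pass

    # Replaces the sampling object on the clone only; the original patcher is untouched.
    m.add_object_patch("model_sampling", ModelSamplingAdvanced(patcher.model.model_config))
    return m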
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index 9c878d54ab6..d14b445441b 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -1,5 +1,4 @@
from comfy import sd1_clip
-import torch
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index b3b69e05b10..6ca32e8eece 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False,
}
+ unet_extra_config = {
+ "num_heads": -1,
+ "num_head_channels": 64,
+ "attn_precision": torch.float32,
+ }
+
latent_format = latent_formats.SD15
def model_type(self, state_dict, prefix=""):
@@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True
}
+ unet_extra_config = {
+ "num_heads": -1,
+ "num_head_channels": 64,
+ "attn_precision": torch.float32,
+ }
+
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
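# A rough illustration (not part of the diff) of what the new unet_extra_config blocks
# are for: per-model keys are folded over the values detected from the checkpoint, which
# here pins SD2.x and SVD attention to 64 head channels and fp32 precision. The
# "detected" dict is hypothetical, and the folding is assumed to happen when the
# supported_models_base.BASE config object is constructed.
import torch

detected = {"num_heads": 8, "num_head_channels": -1}  # hypothetical values inferred from a state dict
extra = {"num_heads": -1, "num_head_channels": 64, "attn_precision": torch.float32}
unet_config = {**detected, **extra}  # the per-model overrides win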
diff --git a/comfy/types.py b/comfy/types.py
new file mode 100644
index 00000000000..70cf4b158e5
--- /dev/null
+++ b/comfy/types.py
@@ -0,0 +1,32 @@
+import torch
+from typing import Callable, Protocol, TypedDict, Optional, List
+
+
+class UnetApplyFunction(Protocol):
+ """Function signature protocol on comfy.model_base.BaseModel.apply_model"""
+
+ def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
+ pass
+
+
+class UnetApplyConds(TypedDict):
+ """Optional conditions for unet apply function."""
+
+ c_concat: Optional[torch.Tensor]
+ c_crossattn: Optional[torch.Tensor]
+ control: Optional[torch.Tensor]
+ transformer_options: Optional[dict]
+
+
+class UnetParams(TypedDict):
+ # Tensor of shape [B, C, H, W]
+ input: torch.Tensor
+ # Tensor of shape [B]
+ timestep: torch.Tensor
+ c: UnetApplyConds
+ # List of [0, 1], [0], [1], ...
+    # 0 means conditional, 1 means unconditional
+ cond_or_uncond: List[int]
+
+
+UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
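# A minimal sketch (not part of the diff) of a wrapper conforming to the new
# UnetWrapperFunction alias above. The function name is hypothetical, and it is
# assumed such a wrapper would be installed via ModelPatcher.set_model_unet_function_wrapper.
def passthrough_unet_wrapper(apply_model, params):
    x = params["input"]     # latent tensor, [B, C, H, W]
    t = params["timestep"]  # timesteps, [B]
    c = params["c"]         # UnetApplyConds: c_concat, c_crossattn, control, transformer_options
    # A real wrapper could tweak x, t or the conditions here before delegating.
    return apply_model(x, t, **c)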
diff --git a/comfy_extras/chainner_models/__init__.py b/comfy_extras/chainner_models/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/comfy_extras/chainner_models/architecture/DAT.py b/comfy_extras/chainner_models/architecture/DAT.py
deleted file mode 100644
index 0bcc26ef422..00000000000
--- a/comfy_extras/chainner_models/architecture/DAT.py
+++ /dev/null
@@ -1,1182 +0,0 @@
-# pylint: skip-file
-import math
-import re
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from torch import Tensor
-from torch.nn import functional as F
-
-from .timm.drop import DropPath
-from .timm.weight_init import trunc_normal_
-
-
-def img2windows(img, H_sp, W_sp):
- """
- Input: Image (B, C, H, W)
- Output: Window Partition (B', N, C)
- """
- B, C, H, W = img.shape
- img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
- img_perm = (
- img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)
- )
- return img_perm
-
-
-def windows2img(img_splits_hw, H_sp, W_sp, H, W):
- """
- Input: Window Partition (B', N, C)
- Output: Image (B, H, W, C)
- """
- B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
-
- img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
- img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return img
-
-
-class SpatialGate(nn.Module):
- """Spatial-Gate.
- Args:
- dim (int): Half of input channels.
- """
-
- def __init__(self, dim):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.conv = nn.Conv2d(
- dim, dim, kernel_size=3, stride=1, padding=1, groups=dim
- ) # DW Conv
-
- def forward(self, x, H, W):
- # Split
- x1, x2 = x.chunk(2, dim=-1)
- B, N, C = x.shape
- x2 = (
- self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C // 2, H, W))
- .flatten(2)
- .transpose(-1, -2)
- .contiguous()
- )
-
- return x1 * x2
-
-
-class SGFN(nn.Module):
- """Spatial-Gate Feed-Forward Network.
- Args:
- in_features (int): Number of input channels.
- hidden_features (int | None): Number of hidden channels. Default: None
- out_features (int | None): Number of output channels. Default: None
- act_layer (nn.Module): Activation layer. Default: nn.GELU
- drop (float): Dropout rate. Default: 0.0
- """
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.0,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.sg = SpatialGate(hidden_features // 2)
- self.fc2 = nn.Linear(hidden_features // 2, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x, H, W):
- """
- Input: x: (B, H*W, C), H, W
- Output: x: (B, H*W, C)
- """
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
-
- x = self.sg(x, H, W)
- x = self.drop(x)
-
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-class DynamicPosBias(nn.Module):
- # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
- """Dynamic Relative Position Bias.
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads.
- residual (bool): If True, use residual strage to connect conv.
- """
-
- def __init__(self, dim, num_heads, residual):
- super().__init__()
- self.residual = residual
- self.num_heads = num_heads
- self.pos_dim = dim // 4
- self.pos_proj = nn.Linear(2, self.pos_dim)
- self.pos1 = nn.Sequential(
- nn.LayerNorm(self.pos_dim),
- nn.ReLU(inplace=True),
- nn.Linear(self.pos_dim, self.pos_dim),
- )
- self.pos2 = nn.Sequential(
- nn.LayerNorm(self.pos_dim),
- nn.ReLU(inplace=True),
- nn.Linear(self.pos_dim, self.pos_dim),
- )
- self.pos3 = nn.Sequential(
- nn.LayerNorm(self.pos_dim),
- nn.ReLU(inplace=True),
- nn.Linear(self.pos_dim, self.num_heads),
- )
-
- def forward(self, biases):
- if self.residual:
- pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
- pos = pos + self.pos1(pos)
- pos = pos + self.pos2(pos)
- pos = self.pos3(pos)
- else:
- pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
- return pos
-
-
-class Spatial_Attention(nn.Module):
- """Spatial Window Self-Attention.
- It supports rectangle window (containing square window).
- Args:
- dim (int): Number of input channels.
- idx (int): The indentix of window. (0/1)
- split_size (tuple(int)): Height and Width of spatial window.
- dim_out (int | None): The dimension of the attention output. Default: None
- num_heads (int): Number of attention heads. Default: 6
- attn_drop (float): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float): Dropout ratio of output. Default: 0.0
- qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set
- position_bias (bool): The dynamic relative position bias. Default: True
- """
-
- def __init__(
- self,
- dim,
- idx,
- split_size=[8, 8],
- dim_out=None,
- num_heads=6,
- attn_drop=0.0,
- proj_drop=0.0,
- qk_scale=None,
- position_bias=True,
- ):
- super().__init__()
- self.dim = dim
- self.dim_out = dim_out or dim
- self.split_size = split_size
- self.num_heads = num_heads
- self.idx = idx
- self.position_bias = position_bias
-
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- if idx == 0:
- H_sp, W_sp = self.split_size[0], self.split_size[1]
- elif idx == 1:
- W_sp, H_sp = self.split_size[0], self.split_size[1]
- else:
- print("ERROR MODE", idx)
- exit(0)
- self.H_sp = H_sp
- self.W_sp = W_sp
-
- if self.position_bias:
- self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
- # generate mother-set
- position_bias_h = torch.arange(1 - self.H_sp, self.H_sp)
- position_bias_w = torch.arange(1 - self.W_sp, self.W_sp)
- biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
- biases = biases.flatten(1).transpose(0, 1).contiguous().float()
- self.register_buffer("rpe_biases", biases)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.H_sp)
- coords_w = torch.arange(self.W_sp)
- coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
- coords_flatten = torch.flatten(coords, 1)
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
- relative_coords[:, :, 0] += self.H_sp - 1
- relative_coords[:, :, 1] += self.W_sp - 1
- relative_coords[:, :, 0] *= 2 * self.W_sp - 1
- relative_position_index = relative_coords.sum(-1)
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.attn_drop = nn.Dropout(attn_drop)
-
- def im2win(self, x, H, W):
- B, N, C = x.shape
- x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
- x = img2windows(x, self.H_sp, self.W_sp)
- x = (
- x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads)
- .permute(0, 2, 1, 3)
- .contiguous()
- )
- return x
-
- def forward(self, qkv, H, W, mask=None):
- """
- Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size
- Output: x (B, H, W, C)
- """
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- B, L, C = q.shape
- assert L == H * W, "flatten img_tokens has wrong size"
-
- # partition the q,k,v, image to window
- q = self.im2win(q, H, W)
- k = self.im2win(k, H, W)
- v = self.im2win(v, H, W)
-
- q = q * self.scale
- attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
-
- # calculate drpe
- if self.position_bias:
- pos = self.pos(self.rpe_biases)
- # select position bias
- relative_position_bias = pos[self.relative_position_index.view(-1)].view(
- self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1
- )
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous()
- attn = attn + relative_position_bias.unsqueeze(0)
-
- N = attn.shape[3]
-
- # use mask for shift window
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
- 0
- )
- attn = attn.view(-1, self.num_heads, N, N)
-
- attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
- attn = self.attn_drop(attn)
-
- x = attn @ v
- x = x.transpose(1, 2).reshape(
- -1, self.H_sp * self.W_sp, C
- ) # B head N N @ B head N C
-
- # merge the window, window to image
- x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C
-
- return x
-
-
-class Adaptive_Spatial_Attention(nn.Module):
- # The implementation builds on CAT code https://github.com/Zhengchen1999/CAT
- """Adaptive Spatial Self-Attention
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads. Default: 6
- split_size (tuple(int)): Height and Width of spatial window.
- shift_size (tuple(int)): Shift size for spatial window.
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
- drop (float): Dropout rate. Default: 0.0
- attn_drop (float): Attention dropout rate. Default: 0.0
- rg_idx (int): The indentix of Residual Group (RG)
- b_idx (int): The indentix of Block in each RG
- """
-
- def __init__(
- self,
- dim,
- num_heads,
- reso=64,
- split_size=[8, 8],
- shift_size=[1, 2],
- qkv_bias=False,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- rg_idx=0,
- b_idx=0,
- ):
- super().__init__()
- self.dim = dim
- self.num_heads = num_heads
- self.split_size = split_size
- self.shift_size = shift_size
- self.b_idx = b_idx
- self.rg_idx = rg_idx
- self.patches_resolution = reso
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-
- assert (
- 0 <= self.shift_size[0] < self.split_size[0]
- ), "shift_size must in 0-split_size0"
- assert (
- 0 <= self.shift_size[1] < self.split_size[1]
- ), "shift_size must in 0-split_size1"
-
- self.branch_num = 2
-
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(drop)
-
- self.attns = nn.ModuleList(
- [
- Spatial_Attention(
- dim // 2,
- idx=i,
- split_size=split_size,
- num_heads=num_heads // 2,
- dim_out=dim // 2,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- position_bias=True,
- )
- for i in range(self.branch_num)
- ]
- )
-
- if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (
- self.rg_idx % 2 != 0 and self.b_idx % 4 == 0
- ):
- attn_mask = self.calculate_mask(
- self.patches_resolution, self.patches_resolution
- )
- self.register_buffer("attn_mask_0", attn_mask[0])
- self.register_buffer("attn_mask_1", attn_mask[1])
- else:
- attn_mask = None
- self.register_buffer("attn_mask_0", None)
- self.register_buffer("attn_mask_1", None)
-
- self.dwconv = nn.Sequential(
- nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim),
- nn.BatchNorm2d(dim),
- nn.GELU(),
- )
- self.channel_interaction = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(dim, dim // 8, kernel_size=1),
- nn.BatchNorm2d(dim // 8),
- nn.GELU(),
- nn.Conv2d(dim // 8, dim, kernel_size=1),
- )
- self.spatial_interaction = nn.Sequential(
- nn.Conv2d(dim, dim // 16, kernel_size=1),
- nn.BatchNorm2d(dim // 16),
- nn.GELU(),
- nn.Conv2d(dim // 16, 1, kernel_size=1),
- )
-
- def calculate_mask(self, H, W):
- # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
- # calculate attention mask for shift window
- img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0
- img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1
- h_slices_0 = (
- slice(0, -self.split_size[0]),
- slice(-self.split_size[0], -self.shift_size[0]),
- slice(-self.shift_size[0], None),
- )
- w_slices_0 = (
- slice(0, -self.split_size[1]),
- slice(-self.split_size[1], -self.shift_size[1]),
- slice(-self.shift_size[1], None),
- )
-
- h_slices_1 = (
- slice(0, -self.split_size[1]),
- slice(-self.split_size[1], -self.shift_size[1]),
- slice(-self.shift_size[1], None),
- )
- w_slices_1 = (
- slice(0, -self.split_size[0]),
- slice(-self.split_size[0], -self.shift_size[0]),
- slice(-self.shift_size[0], None),
- )
- cnt = 0
- for h in h_slices_0:
- for w in w_slices_0:
- img_mask_0[:, h, w, :] = cnt
- cnt += 1
- cnt = 0
- for h in h_slices_1:
- for w in w_slices_1:
- img_mask_1[:, h, w, :] = cnt
- cnt += 1
-
- # calculate mask for window-0
- img_mask_0 = img_mask_0.view(
- 1,
- H // self.split_size[0],
- self.split_size[0],
- W // self.split_size[1],
- self.split_size[1],
- 1,
- )
- img_mask_0 = (
- img_mask_0.permute(0, 1, 3, 2, 4, 5)
- .contiguous()
- .view(-1, self.split_size[0], self.split_size[1], 1)
- ) # nW, sw[0], sw[1], 1
- mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1])
- attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2)
- attn_mask_0 = attn_mask_0.masked_fill(
- attn_mask_0 != 0, float(-100.0)
- ).masked_fill(attn_mask_0 == 0, float(0.0))
-
- # calculate mask for window-1
- img_mask_1 = img_mask_1.view(
- 1,
- H // self.split_size[1],
- self.split_size[1],
- W // self.split_size[0],
- self.split_size[0],
- 1,
- )
- img_mask_1 = (
- img_mask_1.permute(0, 1, 3, 2, 4, 5)
- .contiguous()
- .view(-1, self.split_size[1], self.split_size[0], 1)
- ) # nW, sw[1], sw[0], 1
- mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0])
- attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2)
- attn_mask_1 = attn_mask_1.masked_fill(
- attn_mask_1 != 0, float(-100.0)
- ).masked_fill(attn_mask_1 == 0, float(0.0))
-
- return attn_mask_0, attn_mask_1
-
- def forward(self, x, H, W):
- """
- Input: x: (B, H*W, C), H, W
- Output: x: (B, H*W, C)
- """
- B, L, C = x.shape
- assert L == H * W, "flatten img_tokens has wrong size"
-
- qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C
- # V without partition
- v = qkv[2].transpose(-2, -1).contiguous().view(B, C, H, W)
-
- # image padding
- max_split_size = max(self.split_size[0], self.split_size[1])
- pad_l = pad_t = 0
- pad_r = (max_split_size - W % max_split_size) % max_split_size
- pad_b = (max_split_size - H % max_split_size) % max_split_size
-
- qkv = qkv.reshape(3 * B, H, W, C).permute(0, 3, 1, 2) # 3B C H W
- qkv = (
- F.pad(qkv, (pad_l, pad_r, pad_t, pad_b))
- .reshape(3, B, C, -1)
- .transpose(-2, -1)
- ) # l r t b
- _H = pad_b + H
- _W = pad_r + W
- _L = _H * _W
-
- # window-0 and window-1 on split channels [C/2, C/2]; for square windows (e.g., 8x8), window-0 and window-1 can be merged
- # shift in block: (0, 4, 8, ...), (2, 6, 10, ...), (0, 4, 8, ...), (2, 6, 10, ...), ...
- if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (
- self.rg_idx % 2 != 0 and self.b_idx % 4 == 0
- ):
- qkv = qkv.view(3, B, _H, _W, C)
- qkv_0 = torch.roll(
- qkv[:, :, :, :, : C // 2],
- shifts=(-self.shift_size[0], -self.shift_size[1]),
- dims=(2, 3),
- )
- qkv_0 = qkv_0.view(3, B, _L, C // 2)
- qkv_1 = torch.roll(
- qkv[:, :, :, :, C // 2 :],
- shifts=(-self.shift_size[1], -self.shift_size[0]),
- dims=(2, 3),
- )
- qkv_1 = qkv_1.view(3, B, _L, C // 2)
-
- if self.patches_resolution != _H or self.patches_resolution != _W:
- mask_tmp = self.calculate_mask(_H, _W)
- x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
- x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device))
- else:
- x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0)
- x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1)
-
- x1 = torch.roll(
- x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)
- )
- x2 = torch.roll(
- x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2)
- )
- x1 = x1[:, :H, :W, :].reshape(B, L, C // 2)
- x2 = x2[:, :H, :W, :].reshape(B, L, C // 2)
- # attention output
- attened_x = torch.cat([x1, x2], dim=2)
-
- else:
- x1 = self.attns[0](qkv[:, :, :, : C // 2], _H, _W)[:, :H, :W, :].reshape(
- B, L, C // 2
- )
- x2 = self.attns[1](qkv[:, :, :, C // 2 :], _H, _W)[:, :H, :W, :].reshape(
- B, L, C // 2
- )
- # attention output
- attened_x = torch.cat([x1, x2], dim=2)
-
- # convolution output
- conv_x = self.dwconv(v)
-
- # Adaptive Interaction Module (AIM)
- # C-Map (before sigmoid)
- channel_map = (
- self.channel_interaction(conv_x)
- .permute(0, 2, 3, 1)
- .contiguous()
- .view(B, 1, C)
- )
- # S-Map (before sigmoid)
- attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W)
- spatial_map = self.spatial_interaction(attention_reshape)
-
- # C-I
- attened_x = attened_x * torch.sigmoid(channel_map)
- # S-I
- conv_x = torch.sigmoid(spatial_map) * conv_x
- conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
-
- x = attened_x + conv_x
-
- x = self.proj(x)
- x = self.proj_drop(x)
-
- return x
-
-
-class Adaptive_Channel_Attention(nn.Module):
- # The implementation builds on XCiT code https://github.com/facebookresearch/xcit
- """Adaptive Channel Self-Attention
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads. Default: 6
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
- attn_drop (float): Attention dropout rate. Default: 0.0
- drop_path (float): Stochastic depth rate. Default: 0.0
- """
-
- def __init__(
- self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.0,
- proj_drop=0.0,
- ):
- super().__init__()
- self.num_heads = num_heads
- self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- self.dwconv = nn.Sequential(
- nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim),
- nn.BatchNorm2d(dim),
- nn.GELU(),
- )
- self.channel_interaction = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(dim, dim // 8, kernel_size=1),
- nn.BatchNorm2d(dim // 8),
- nn.GELU(),
- nn.Conv2d(dim // 8, dim, kernel_size=1),
- )
- self.spatial_interaction = nn.Sequential(
- nn.Conv2d(dim, dim // 16, kernel_size=1),
- nn.BatchNorm2d(dim // 16),
- nn.GELU(),
- nn.Conv2d(dim // 16, 1, kernel_size=1),
- )
-
- def forward(self, x, H, W):
- """
- Input: x: (B, H*W, C), H, W
- Output: x: (B, H*W, C)
- """
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
- qkv = qkv.permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q.transpose(-2, -1)
- k = k.transpose(-2, -1)
- v = v.transpose(-2, -1)
-
- v_ = v.reshape(B, C, N).contiguous().view(B, C, H, W)
-
- q = torch.nn.functional.normalize(q, dim=-1)
- k = torch.nn.functional.normalize(k, dim=-1)
-
- attn = (q @ k.transpose(-2, -1)) * self.temperature
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- # attention output
- attened_x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
-
- # convolution output
- conv_x = self.dwconv(v_)
-
- # Adaptive Interaction Module (AIM)
- # C-Map (before sigmoid)
- attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W)
- channel_map = self.channel_interaction(attention_reshape)
- # S-Map (before sigmoid)
- spatial_map = (
- self.spatial_interaction(conv_x)
- .permute(0, 2, 3, 1)
- .contiguous()
- .view(B, N, 1)
- )
-
- # S-I
- attened_x = attened_x * torch.sigmoid(spatial_map)
- # C-I
- conv_x = conv_x * torch.sigmoid(channel_map)
- conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, N, C)
-
- x = attened_x + conv_x
-
- x = self.proj(x)
- x = self.proj_drop(x)
-
- return x
-
-
-class DATB(nn.Module):
- def __init__(
- self,
- dim,
- num_heads,
- reso=64,
- split_size=[2, 4],
- shift_size=[1, 2],
- expansion_factor=4.0,
- qkv_bias=False,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- rg_idx=0,
- b_idx=0,
- ):
- super().__init__()
-
- self.norm1 = norm_layer(dim)
-
- if b_idx % 2 == 0:
- # DSTB
- self.attn = Adaptive_Spatial_Attention(
- dim,
- num_heads=num_heads,
- reso=reso,
- split_size=split_size,
- shift_size=shift_size,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- rg_idx=rg_idx,
- b_idx=b_idx,
- )
- else:
- # DCTB
- self.attn = Adaptive_Channel_Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- )
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- ffn_hidden_dim = int(dim * expansion_factor)
- self.ffn = SGFN(
- in_features=dim,
- hidden_features=ffn_hidden_dim,
- out_features=dim,
- act_layer=act_layer,
- )
- self.norm2 = norm_layer(dim)
-
- def forward(self, x, x_size):
- """
- Input: x: (B, H*W, C), x_size: (H, W)
- Output: x: (B, H*W, C)
- """
- H, W = x_size
- x = x + self.drop_path(self.attn(self.norm1(x), H, W))
- x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
-
- return x
-
-
-class ResidualGroup(nn.Module):
- """ResidualGroup
- Args:
- dim (int): Number of input channels.
- reso (int): Input resolution.
- num_heads (int): Number of attention heads.
- split_size (tuple(int)): Height and Width of spatial window.
- expansion_factor (float): Ratio of ffn hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop (float): Dropout rate. Default: 0
- attn_drop(float): Attention dropout rate. Default: 0
- drop_paths (float | None): Stochastic depth rate.
- act_layer (nn.Module): Activation layer. Default: nn.GELU
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
- depth (int): Number of dual aggregation Transformer blocks in residual group.
- use_chk (bool): Whether to use checkpointing to save memory.
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(
- self,
- dim,
- reso,
- num_heads,
- split_size=[2, 4],
- expansion_factor=4.0,
- qkv_bias=False,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_paths=None,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- depth=2,
- use_chk=False,
- resi_connection="1conv",
- rg_idx=0,
- ):
- super().__init__()
- self.use_chk = use_chk
- self.reso = reso
-
- self.blocks = nn.ModuleList(
- [
- DATB(
- dim=dim,
- num_heads=num_heads,
- reso=reso,
- split_size=split_size,
- shift_size=[split_size[0] // 2, split_size[1] // 2],
- expansion_factor=expansion_factor,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_paths[i],
- act_layer=act_layer,
- norm_layer=norm_layer,
- rg_idx=rg_idx,
- b_idx=i,
- )
- for i in range(depth)
- ]
- )
-
- if resi_connection == "1conv":
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == "3conv":
- self.conv = nn.Sequential(
- nn.Conv2d(dim, dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1),
- )
-
- def forward(self, x, x_size):
- """
- Input: x: (B, H*W, C), x_size: (H, W)
- Output: x: (B, H*W, C)
- """
- H, W = x_size
- res = x
- for blk in self.blocks:
- if self.use_chk:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
- x = self.conv(x)
- x = rearrange(x, "b c h w -> b (h w) c")
- x = res + x
-
- return x
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(
- f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
- )
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- h, w = self.input_resolution
- flops = h * w * self.num_feat * 3 * 9
- return flops
-
-
-class DAT(nn.Module):
- """Dual Aggregation Transformer
- Args:
- img_size (int): Input image size. Default: 64
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 180
- depths (tuple(int)): Depth of each residual group (number of DATB in each RG).
- split_size (tuple(int)): Height and Width of spatial window.
- num_heads (tuple(int)): Number of attention heads in different residual groups.
- expansion_factor (float): Ratio of ffn hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- act_layer (nn.Module): Activation layer. Default: nn.GELU
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
- use_chk (bool): Whether to use checkpointing to save memory.
- upscale: Upscale factor. 2/3/4 for image SR
- img_range: Image range. 1. or 255.
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, state_dict):
- super().__init__()
-
- # defaults
- img_size = 64
- in_chans = 3
- embed_dim = 180
- split_size = [2, 4]
- depth = [2, 2, 2, 2]
- num_heads = [2, 2, 2, 2]
- expansion_factor = 4.0
- qkv_bias = True
- qk_scale = None
- drop_rate = 0.0
- attn_drop_rate = 0.0
- drop_path_rate = 0.1
- act_layer = nn.GELU
- norm_layer = nn.LayerNorm
- use_chk = False
- upscale = 2
- img_range = 1.0
- resi_connection = "1conv"
- upsampler = "pixelshuffle"
-
- self.model_arch = "DAT"
- self.sub_type = "SR"
- self.state = state_dict
-
- state_keys = state_dict.keys()
- if "conv_before_upsample.0.weight" in state_keys:
- if "conv_up1.weight" in state_keys:
- upsampler = "nearest+conv"
- else:
- upsampler = "pixelshuffle"
- supports_fp16 = False
- elif "upsample.0.weight" in state_keys:
- upsampler = "pixelshuffledirect"
- else:
- upsampler = ""
-
- num_feat = (
- state_dict.get("conv_before_upsample.0.weight", None).shape[1]
- if state_dict.get("conv_before_upsample.weight", None)
- else 64
- )
-
- num_in_ch = state_dict["conv_first.weight"].shape[1]
- in_chans = num_in_ch
- if "conv_last.weight" in state_keys:
- num_out_ch = state_dict["conv_last.weight"].shape[0]
- else:
- num_out_ch = num_in_ch
-
- upscale = 1
- if upsampler == "nearest+conv":
- upsample_keys = [
- x for x in state_keys if "conv_up" in x and "bias" not in x
- ]
-
- for upsample_key in upsample_keys:
- upscale *= 2
- elif upsampler == "pixelshuffle":
- upsample_keys = [
- x
- for x in state_keys
- if "upsample" in x and "conv" not in x and "bias" not in x
- ]
- for upsample_key in upsample_keys:
- shape = state_dict[upsample_key].shape[0]
- upscale *= math.sqrt(shape // num_feat)
- upscale = int(upscale)
- elif upsampler == "pixelshuffledirect":
- upscale = int(
- math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch)
- )
-
- max_layer_num = 0
- max_block_num = 0
- for key in state_keys:
- result = re.match(r"layers.(\d*).blocks.(\d*).norm1.weight", key)
- if result:
- layer_num, block_num = result.groups()
- max_layer_num = max(max_layer_num, int(layer_num))
- max_block_num = max(max_block_num, int(block_num))
-
- depth = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
- if "layers.0.blocks.1.attn.temperature" in state_keys:
- num_heads_num = state_dict["layers.0.blocks.1.attn.temperature"].shape[0]
- num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
- else:
- num_heads = depth
-
- embed_dim = state_dict["conv_first.weight"].shape[0]
- expansion_factor = float(
- state_dict["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim
- )
-
- # TODO: could actually count the layers, but this should do
- if "layers.0.conv.4.weight" in state_keys:
- resi_connection = "3conv"
- else:
- resi_connection = "1conv"
-
- if "layers.0.blocks.2.attn.attn_mask_0" in state_keys:
- attn_mask_0_x, attn_mask_0_y, attn_mask_0_z = state_dict[
- "layers.0.blocks.2.attn.attn_mask_0"
- ].shape
-
- img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y))
-
- if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_keys:
- split_sizes = (
- state_dict["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1
- )
- split_size = [int(x) for x in split_sizes]
-
- self.in_nc = num_in_ch
- self.out_nc = num_out_ch
- self.num_feat = num_feat
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.depth = depth
- self.scale = upscale
- self.upsampler = upsampler
- self.img_size = img_size
- self.img_range = img_range
- self.expansion_factor = expansion_factor
- self.resi_connection = resi_connection
- self.split_size = split_size
-
- self.supports_fp16 = False # Too much weirdness to support this at the moment
- self.supports_bfp16 = True
- self.min_size_restriction = 16
-
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
-
- # ------------------------- 1, Shallow Feature Extraction ------------------------- #
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- # ------------------------- 2, Deep Feature Extraction ------------------------- #
- self.num_layers = len(depth)
- self.use_chk = use_chk
- self.num_features = (
- self.embed_dim
- ) = embed_dim # num_features for consistency with other models
- heads = num_heads
-
- self.before_RG = nn.Sequential(
- Rearrange("b c h w -> b (h w) c"), nn.LayerNorm(embed_dim)
- )
-
- curr_dim = embed_dim
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))
- ] # stochastic depth decay rule
-
- self.layers = nn.ModuleList()
- for i in range(self.num_layers):
- layer = ResidualGroup(
- dim=embed_dim,
- num_heads=heads[i],
- reso=img_size,
- split_size=split_size,
- expansion_factor=expansion_factor,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_paths=dpr[sum(depth[:i]) : sum(depth[: i + 1])],
- act_layer=act_layer,
- norm_layer=norm_layer,
- depth=depth[i],
- use_chk=use_chk,
- resi_connection=resi_connection,
- rg_idx=i,
- )
- self.layers.append(layer)
-
- self.norm = norm_layer(curr_dim)
- # build the last conv layer in deep feature extraction
- if resi_connection == "1conv":
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == "3conv":
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(
- nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
- )
-
- # ------------------------- 3, Reconstruction ------------------------- #
- if self.upsampler == "pixelshuffle":
- # for classical SR
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(
- upscale, embed_dim, num_out_ch, (img_size, img_size)
- )
-
- self.apply(self._init_weights)
- self.load_state_dict(state_dict, strict=True)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(
- m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)
- ):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def forward_features(self, x):
- _, _, H, W = x.shape
- x_size = [H, W]
- x = self.before_RG(x)
- for layer in self.layers:
- x = layer(x, x_size)
- x = self.norm(x)
- x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
-
- return x
-
- def forward(self, x):
- """
- Input: x: (B, C, H, W)
- """
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == "pixelshuffle":
- # for image SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
-
- x = x / self.img_range + self.mean
- return x
diff --git a/comfy_extras/chainner_models/architecture/HAT.py b/comfy_extras/chainner_models/architecture/HAT.py
deleted file mode 100644
index 6694742199b..00000000000
--- a/comfy_extras/chainner_models/architecture/HAT.py
+++ /dev/null
@@ -1,1277 +0,0 @@
-# pylint: skip-file
-# HAT from https://github.com/XPixelGroup/HAT/blob/main/hat/archs/hat_arch.py
-import math
-import re
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-
-from .timm.helpers import to_2tuple
-from .timm.weight_init import trunc_normal_
-
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
- """
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (
- x.ndim - 1
- ) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
- return output
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
- """
-
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training) # type: ignore
-
-
-class ChannelAttention(nn.Module):
- """Channel attention used in RCAN.
- Args:
- num_feat (int): Channel number of intermediate features.
- squeeze_factor (int): Channel squeeze factor. Default: 16.
- """
-
- def __init__(self, num_feat, squeeze_factor=16):
- super(ChannelAttention, self).__init__()
- self.attention = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
- nn.ReLU(inplace=True),
- nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
- nn.Sigmoid(),
- )
-
- def forward(self, x):
- y = self.attention(x)
- return x * y
-
-
-class CAB(nn.Module):
- def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
- super(CAB, self).__init__()
-
- self.cab = nn.Sequential(
- nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
- nn.GELU(),
- nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
- ChannelAttention(num_feat, squeeze_factor),
- )
-
- def forward(self, x):
- return self.cab(x)
-
-
-class Mlp(nn.Module):
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.0,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (b, h, w, c)
- window_size (int): window size
- Returns:
- windows: (num_windows*b, window_size, window_size, c)
- """
- b, h, w, c = x.shape
- x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
- windows = (
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
- )
- return windows
-
-
-def window_reverse(windows, window_size, h, w):
- """
- Args:
- windows: (num_windows*b, window_size, window_size, c)
- window_size (int): Window size
- h (int): Height of image
- w (int): Width of image
- Returns:
- x: (b, h, w, c)
- """
- b = int(windows.shape[0] / (h * w / window_size / window_size))
- x = windows.view(
- b, h // window_size, w // window_size, window_size, window_size, -1
- )
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(
- self,
- dim,
- window_size,
- num_heads,
- qkv_bias=True,
- qk_scale=None,
- attn_drop=0.0,
- proj_drop=0.0,
- ):
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter( # type: ignore
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
- ) # 2*Wh-1 * 2*Ww-1, nH
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=0.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, rpi, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*b, n, c)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- b_, n, c = x.shape
- qkv = (
- self.qkv(x)
- .reshape(b_, n, 3, self.num_heads, c // self.num_heads)
- .permute(2, 0, 3, 1, 4)
- )
- q, k, v = (
- qkv[0],
- qkv[1],
- qkv[2],
- ) # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
-
- relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1,
- ) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nw = mask.shape[0]
- attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(
- 1
- ).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, n, n)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class HAB(nn.Module):
- r"""Hybrid Attention Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- num_heads,
- window_size=7,
- shift_size=0,
- compress_ratio=3,
- squeeze_factor=30,
- conv_scale=0.01,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert (
- 0 <= self.shift_size < self.window_size
- ), "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim,
- window_size=to_2tuple(self.window_size),
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- )
-
- self.conv_scale = conv_scale
- self.conv_block = CAB(
- num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor
- )
-
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop,
- )
-
- def forward(self, x, x_size, rpi_sa, attn_mask):
- h, w = x_size
- b, _, c = x.shape
- # assert seq_len == h * w, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(b, h, w, c)
-
- # Conv_X
- conv_x = self.conv_block(x.permute(0, 3, 1, 2))
- conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
- )
- attn_mask = attn_mask
- else:
- shifted_x = x
- attn_mask = None
-
- # partition windows
- x_windows = window_partition(
- shifted_x, self.window_size
- ) # nw*b, window_size, window_size, c
- x_windows = x_windows.view(
- -1, self.window_size * self.window_size, c
- ) # nw*b, window_size*window_size, c
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
- shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
-
- # reverse cyclic shift
- if self.shift_size > 0:
- attn_x = torch.roll(
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
- )
- else:
- attn_x = shifted_x
- attn_x = attn_x.view(b, h * w, c)
-
- # FFN
- x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
-
-class PatchMerging(nn.Module):
- r"""Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: b, h*w, c
- """
- h, w = self.input_resolution
- b, seq_len, c = x.shape
- assert seq_len == h * w, "input feature has wrong size"
- assert h % 2 == 0 and w % 2 == 0, f"x size ({h}*{w}) are not even."
-
- x = x.view(b, h, w, c)
-
- x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
- x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
- x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
- x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
- x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
- x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
-
-class OCAB(nn.Module):
- # overlapping cross-attention block
-
- def __init__(
- self,
- dim,
- input_resolution,
- window_size,
- overlap_ratio,
- num_heads,
- qkv_bias=True,
- qk_scale=None,
- mlp_ratio=2,
- norm_layer=nn.LayerNorm,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.window_size = window_size
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
- self.overlap_win_size = int(window_size * overlap_ratio) + window_size
-
- self.norm1 = norm_layer(dim)
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.unfold = nn.Unfold(
- kernel_size=(self.overlap_win_size, self.overlap_win_size),
- stride=window_size,
- padding=(self.overlap_win_size - window_size) // 2,
- )
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter( # type: ignore
- torch.zeros(
- (window_size + self.overlap_win_size - 1)
- * (window_size + self.overlap_win_size - 1),
- num_heads,
- )
- ) # 2*Wh-1 * 2*Ww-1, nH
-
- trunc_normal_(self.relative_position_bias_table, std=0.02)
- self.softmax = nn.Softmax(dim=-1)
-
- self.proj = nn.Linear(dim, dim)
-
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU
- )
-
- def forward(self, x, x_size, rpi):
- h, w = x_size
- b, _, c = x.shape
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(b, h, w, c)
-
- qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w
- q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c
- kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
-
- # partition windows
- q_windows = window_partition(
- q, self.window_size
- ) # nw*b, window_size, window_size, c
- q_windows = q_windows.view(
- -1, self.window_size * self.window_size, c
- ) # nw*b, window_size*window_size, c
-
- kv_windows = self.unfold(kv) # b, c*w*w, nw
- kv_windows = rearrange(
- kv_windows,
- "b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch",
- nc=2,
- ch=c,
- owh=self.overlap_win_size,
- oww=self.overlap_win_size,
- ).contiguous() # 2, nw*b, ow*ow, c
- # Do the above rearrangement without the rearrange function
- # kv_windows = kv_windows.view(
- # 2, b, self.overlap_win_size, self.overlap_win_size, c, -1
- # )
- # kv_windows = kv_windows.permute(0, 5, 1, 2, 3, 4).contiguous()
- # kv_windows = kv_windows.view(
- # 2, -1, self.overlap_win_size * self.overlap_win_size, c
- # )
-
- k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
-
- b_, nq, _ = q_windows.shape
- _, n, _ = k_windows.shape
- d = self.dim // self.num_heads
- q = q_windows.reshape(b_, nq, self.num_heads, d).permute(
- 0, 2, 1, 3
- ) # nw*b, nH, nq, d
- k = k_windows.reshape(b_, n, self.num_heads, d).permute(
- 0, 2, 1, 3
- ) # nw*b, nH, n, d
- v = v_windows.reshape(b_, n, self.num_heads, d).permute(
- 0, 2, 1, 3
- ) # nw*b, nH, n, d
-
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
-
- relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
- self.window_size * self.window_size,
- self.overlap_win_size * self.overlap_win_size,
- -1,
- ) # ws*ws, wse*wse, nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous() # nH, ws*ws, wse*wse
- attn = attn + relative_position_bias.unsqueeze(0)
-
- attn = self.softmax(attn)
- attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
-
- # merge windows
- attn_windows = attn_windows.view(
- -1, self.window_size, self.window_size, self.dim
- )
- x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
- x = x.view(b, h * w, self.dim)
-
- x = self.proj(x) + shortcut
-
- x = x + self.mlp(self.norm2(x))
- return x
-
-
-class AttenBlocks(nn.Module):
- """A series of attention blocks for one RHAG.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- compress_ratio,
- squeeze_factor,
- conv_scale,
- overlap_ratio,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList(
- [
- HAB(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- compress_ratio=compress_ratio,
- squeeze_factor=squeeze_factor,
- conv_scale=conv_scale,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i]
- if isinstance(drop_path, list)
- else drop_path,
- norm_layer=norm_layer,
- )
- for i in range(depth)
- ]
- )
-
- # OCAB
- self.overlap_attn = OCAB(
- dim=dim,
- input_resolution=input_resolution,
- window_size=window_size,
- overlap_ratio=overlap_ratio,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- mlp_ratio=mlp_ratio, # type: ignore
- norm_layer=norm_layer,
- )
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(
- input_resolution, dim=dim, norm_layer=norm_layer
- )
- else:
- self.downsample = None
-
- def forward(self, x, x_size, params):
- for blk in self.blocks:
- x = blk(x, x_size, params["rpi_sa"], params["attn_mask"])
-
- x = self.overlap_attn(x, x_size, params["rpi_oca"])
-
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
-
-class RHAG(nn.Module):
- """Residual Hybrid Attention Group (RHAG).
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- compress_ratio,
- squeeze_factor,
- conv_scale,
- overlap_ratio,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- img_size=224,
- patch_size=4,
- resi_connection="1conv",
- ):
- super(RHAG, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = AttenBlocks(
- dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- compress_ratio=compress_ratio,
- squeeze_factor=squeeze_factor,
- conv_scale=conv_scale,
- overlap_ratio=overlap_ratio,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint,
- )
-
- if resi_connection == "1conv":
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == "identity":
- self.conv = nn.Identity()
-
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=0,
- embed_dim=dim,
- norm_layer=None,
- )
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=0,
- embed_dim=dim,
- norm_layer=None,
- )
-
- def forward(self, x, x_size, params):
- return (
- self.patch_embed(
- self.conv(
- self.patch_unembed(self.residual_group(x, x_size, params), x_size)
- )
- )
- + x
- )
-
-
-class PatchEmbed(nn.Module):
- r"""Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [
- img_size[0] // patch_size[0], # type: ignore
- img_size[1] // patch_size[1], # type: ignore
- ]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
- if self.norm is not None:
- x = self.norm(x)
- return x
-
-
-class PatchUnEmbed(nn.Module):
- r"""Image to Patch Unembedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [
- img_size[0] // patch_size[0], # type: ignore
- img_size[1] // patch_size[1], # type: ignore
- ]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- x = (
- x.transpose(1, 2)
- .contiguous()
- .view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
-        )  # b c h w
- return x
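PatchEmbed and PatchUnEmbed above are simple shape adapters: the first flattens a (b, c, h, w) feature map into (b, h*w, c) token form for the attention blocks, and the second reverses it given x_size = (h, w). A minimal round-trip sketch of that contract (assuming PyTorch; patch_size is effectively 1 as HAT uses it):

import torch

b, c, h, w = 1, 96, 8, 8
feat = torch.randn(b, c, h, w)

# what PatchEmbed.forward does: b c h w -> b (h*w) c
tokens = feat.flatten(2).transpose(1, 2)

# what PatchUnEmbed.forward does with x_size=(h, w): b (h*w) c -> b c h w
restored = tokens.transpose(1, 2).contiguous().view(b, c, h, w)

assert torch.equal(feat, restored)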
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(
- f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
- )
- super(Upsample, self).__init__(*m)
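Upsample handles power-of-two factors by stacking log2(scale) ×2 stages, where each 3x3 conv expands channels 4× and PixelShuffle(2) trades those channels for spatial resolution; scale 3 gets a single dedicated 9×-channel stage. A rough standalone shape check of the 2^n branch (an illustrative sketch, not part of the deleted file):

import math
import torch
import torch.nn as nn

scale, num_feat = 4, 64
stages = []
for _ in range(int(math.log2(scale))):  # 4 = 2**2 -> two x2 pixel-shuffle stages
    stages += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
up = nn.Sequential(*stages)

x = torch.randn(1, num_feat, 16, 16)
print(up(x).shape)  # torch.Size([1, 64, 64, 64])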
-
-
-class HAT(nn.Module):
- r"""Hybrid Attention Transformer
-    A PyTorch implementation of: `Activating More Pixels in Image Super-Resolution Transformer`.
- Some codes are based on SwinIR.
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
-        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
-        img_range: Image range. 1. or 255.
-        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(
- self,
- state_dict,
- **kwargs,
- ):
- super(HAT, self).__init__()
-
- # Defaults
- img_size = 64
- patch_size = 1
- in_chans = 3
- embed_dim = 96
- depths = (6, 6, 6, 6)
- num_heads = (6, 6, 6, 6)
- window_size = 7
- compress_ratio = 3
- squeeze_factor = 30
- conv_scale = 0.01
- overlap_ratio = 0.5
- mlp_ratio = 4.0
- qkv_bias = True
- qk_scale = None
- drop_rate = 0.0
- attn_drop_rate = 0.0
- drop_path_rate = 0.1
- norm_layer = nn.LayerNorm
- ape = False
- patch_norm = True
- use_checkpoint = False
- upscale = 2
- img_range = 1.0
- upsampler = ""
- resi_connection = "1conv"
-
- self.state = state_dict
- self.model_arch = "HAT"
- self.sub_type = "SR"
- self.supports_fp16 = False
- self.support_bf16 = True
- self.min_size_restriction = 16
-
- state_keys = list(state_dict.keys())
-
- num_feat = state_dict["conv_last.weight"].shape[1]
- in_chans = state_dict["conv_first.weight"].shape[1]
- num_out_ch = state_dict["conv_last.weight"].shape[0]
- embed_dim = state_dict["conv_first.weight"].shape[0]
-
- if "conv_before_upsample.0.weight" in state_keys:
- if "conv_up1.weight" in state_keys:
- upsampler = "nearest+conv"
- else:
- upsampler = "pixelshuffle"
- supports_fp16 = False
- elif "upsample.0.weight" in state_keys:
- upsampler = "pixelshuffledirect"
- else:
- upsampler = ""
- upscale = 1
- if upsampler == "nearest+conv":
- upsample_keys = [
- x for x in state_keys if "conv_up" in x and "bias" not in x
- ]
-
- for upsample_key in upsample_keys:
- upscale *= 2
- elif upsampler == "pixelshuffle":
- upsample_keys = [
- x
- for x in state_keys
- if "upsample" in x and "conv" not in x and "bias" not in x
- ]
- for upsample_key in upsample_keys:
- shape = self.state[upsample_key].shape[0]
- upscale *= math.sqrt(shape // num_feat)
- upscale = int(upscale)
- elif upsampler == "pixelshuffledirect":
- upscale = int(
- math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
- )
-
- max_layer_num = 0
- max_block_num = 0
- for key in state_keys:
- result = re.match(
- r"layers.(\d*).residual_group.blocks.(\d*).conv_block.cab.0.weight", key
- )
- if result:
- layer_num, block_num = result.groups()
- max_layer_num = max(max_layer_num, int(layer_num))
- max_block_num = max(max_block_num, int(block_num))
-
- depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
- if (
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- in state_keys
- ):
- num_heads_num = self.state[
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- ].shape[-1]
- num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
- else:
- num_heads = depths
-
- mlp_ratio = float(
- self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
- / embed_dim
- )
-
- # TODO: could actually count the layers, but this should do
- if "layers.0.conv.4.weight" in state_keys:
- resi_connection = "3conv"
- else:
- resi_connection = "1conv"
-
- window_size = int(math.sqrt(self.state["relative_position_index_SA"].shape[0]))
-
- # Not sure if this is needed or used at all anywhere in HAT's config
- if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
- img_size = int(
- math.sqrt(
- self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
- )
- * window_size
- )
-
- self.window_size = window_size
- self.shift_size = window_size // 2
- self.overlap_ratio = overlap_ratio
-
- self.in_nc = in_chans
- self.out_nc = num_out_ch
- self.num_feat = num_feat
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.depths = depths
- self.window_size = window_size
- self.mlp_ratio = mlp_ratio
- self.scale = upscale
- self.upsampler = upsampler
- self.img_size = img_size
- self.img_range = img_range
- self.resi_connection = resi_connection
-
- num_in_ch = in_chans
- # num_out_ch = in_chans
- # num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
-
- # relative position index
- relative_position_index_SA = self.calculate_rpi_sa()
- relative_position_index_OCA = self.calculate_rpi_oca()
- self.register_buffer("relative_position_index_SA", relative_position_index_SA)
- self.register_buffer("relative_position_index_OCA", relative_position_index_OCA)
-
- # ------------------------- 1, shallow feature extraction ------------------------- #
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- # ------------------------- 2, deep feature extraction ------------------------- #
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter( # type: ignore[arg-type]
- torch.zeros(1, num_patches, embed_dim)
- )
- trunc_normal_(self.absolute_pos_embed, std=0.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
- ] # stochastic depth decay rule
-
- # build Residual Hybrid Attention Groups (RHAG)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RHAG(
- dim=embed_dim,
- input_resolution=(patches_resolution[0], patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- compress_ratio=compress_ratio,
- squeeze_factor=squeeze_factor,
- conv_scale=conv_scale,
- overlap_ratio=overlap_ratio,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[
- sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
- ], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection,
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == "1conv":
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == "identity":
- self.conv_after_body = nn.Identity()
-
- # ------------------------- 3, high quality image reconstruction ------------------------- #
- if self.upsampler == "pixelshuffle":
- # for classical SR
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
- self.load_state_dict(self.state, strict=False)
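Everything in this constructor after the defaults is hyperparameter recovery from the checkpoint itself: channel counts come from the conv_first/conv_last shapes, the upsampler variant from which keys exist, depths from the highest block index per layer, and so on. A toy illustration of that key-sniffing pattern, using a hypothetical minimal state dict (the keys and shapes below are invented for the example, not a real HAT checkpoint):

import math
import re
import torch

sd = {
    "conv_first.weight": torch.zeros(96, 3, 3, 3),
    "conv_last.weight": torch.zeros(3, 64, 3, 3),
    "upsample.0.bias": torch.zeros(48),  # 4**2 * 3 output channels -> x4 pixelshuffledirect
    "layers.0.residual_group.blocks.5.conv_block.cab.0.weight": torch.zeros(1),
}

embed_dim = sd["conv_first.weight"].shape[0]                            # 96
num_out_ch = sd["conv_last.weight"].shape[0]                            # 3
upscale = int(math.sqrt(sd["upsample.0.bias"].shape[0] // num_out_ch))  # 4
depth = 1 + max(
    int(re.match(r"layers\.0\.residual_group\.blocks\.(\d+)\.", k).group(1))
    for k in sd
    if k.startswith("layers.0.residual_group.blocks.")
)                                                                       # 6
print(embed_dim, num_out_ch, upscale, depth)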
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def calculate_rpi_sa(self):
- # calculate relative position index for SA
- coords_h = torch.arange(self.window_size)
- coords_w = torch.arange(self.window_size)
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = (
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
- ) # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(
- 1, 2, 0
- ).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size - 1
- relative_coords[:, :, 0] *= 2 * self.window_size - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- return relative_position_index
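calculate_rpi_sa is the standard Swin-style trick: every ordered pair of positions inside a window maps its 2-D offset to a single index in [0, (2*window_size - 1)**2), later used to look up a learned relative position bias. For window_size = 2 there are 4 tokens and the result is a 4x4 index matrix; a quick sanity check of the shape and index range (an illustrative snippet, not part of the original file):

import torch

ws = 2
coords = torch.stack(torch.meshgrid(torch.arange(ws), torch.arange(ws), indexing="ij"))
flat = torch.flatten(coords, 1)                                   # 2, ws*ws
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
rel[:, :, 0] += ws - 1                                            # shift offsets to start from 0
rel[:, :, 1] += ws - 1
rel[:, :, 0] *= 2 * ws - 1                                        # row-major flatten of the 2-D offset
rpi = rel.sum(-1)                                                 # (ws*ws, ws*ws)
print(rpi.shape, int(rpi.min()), int(rpi.max()))                  # torch.Size([4, 4]) 0 8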
-
- def calculate_rpi_oca(self):
- # calculate relative position index for OCA
- window_size_ori = self.window_size
- window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
-
- coords_h = torch.arange(window_size_ori)
- coords_w = torch.arange(window_size_ori)
- coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
- coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
-
- coords_h = torch.arange(window_size_ext)
- coords_w = torch.arange(window_size_ext)
- coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
- coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
-
- relative_coords = (
- coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
- ) # 2, ws*ws, wse*wse
-
- relative_coords = relative_coords.permute(
- 1, 2, 0
- ).contiguous() # ws*ws, wse*wse, 2
- relative_coords[:, :, 0] += (
- window_size_ori - window_size_ext + 1
- ) # shift to start from 0
- relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
-
- relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
- relative_position_index = relative_coords.sum(-1)
- return relative_position_index
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- h, w = x_size
- img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
- h_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- w_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(
- img_mask, self.window_size
- ) # nw, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
- attn_mask == 0, float(0.0)
- )
-
- return attn_mask
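calculate_mask builds the usual shifted-window attention mask: windows that straddle the cyclic-shift boundary get -100 on cross-region token pairs so softmax effectively zeroes them out. A small check of the output shape and values; window_partition is defined earlier in the deleted file and is re-implemented inline here under the standard Swin assumption:

import torch

ws, shift, h, w = 4, 2, 8, 8
img_mask = torch.zeros(1, h, w, 1)
cnt = 0
for hs in (slice(0, -ws), slice(-ws, -shift), slice(-shift, None)):
    for wsl in (slice(0, -ws), slice(-ws, -shift), slice(-shift, None)):
        img_mask[:, hs, wsl, :] = cnt   # label each shifted region
        cnt += 1

# standard Swin window_partition, inlined for the example
windows = img_mask.view(1, h // ws, ws, w // ws, ws, 1)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws)

attn_mask = windows.unsqueeze(1) - windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape, attn_mask.unique())  # torch.Size([4, 16, 16]) tensor([-100., 0.])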
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay(self):
- return {"absolute_pos_embed"}
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay_keywords(self):
- return {"relative_position_bias_table"}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
- return x
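check_image_size reflect-pads H and W up to the next multiple of window_size so window partitioning always divides evenly; forward later crops the result back to H*upscale x W*upscale. The padding arithmetic, with window_size = 16 as an example value:

import torch
import torch.nn.functional as F

window_size = 16
x = torch.randn(1, 3, 70, 100)
pad_h = (window_size - x.shape[2] % window_size) % window_size  # 10
pad_w = (window_size - x.shape[3] % window_size) % window_size  # 12
x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
print(x.shape)  # torch.Size([1, 3, 80, 112])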
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
-
- # Calculate attention mask and relative position index in advance to speed up inference.
-        # The original code is very time-consuming for large window sizes.
- attn_mask = self.calculate_mask(x_size).to(x.device)
- params = {
- "attn_mask": attn_mask,
- "rpi_sa": self.relative_position_index_SA,
- "rpi_oca": self.relative_position_index_OCA,
- }
-
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size, params)
-
- x = self.norm(x) # b seq_len c
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
- x = self.check_image_size(x)
-
- if self.upsampler == "pixelshuffle":
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
-
- x = x / self.img_range + self.mean
-
- return x[:, :, : H * self.upscale, : W * self.upscale]
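Taken together, forward normalizes by the dataset mean, pads, runs the shallow conv, the RHAG stack, and the reconstruction head, then crops back to H*upscale x W*upscale. A hedged usage sketch of the constructor's state-dict-in, upscaled-image-out contract; the checkpoint filename is a placeholder, and the "params_ema" unwrap is an assumption about BasicSR-style HAT releases rather than something this class does itself:

import torch

sd = torch.load("HAT_SRx4.pth", map_location="cpu")  # placeholder checkpoint name
if "params_ema" in sd:                               # assumed BasicSR-style nesting
    sd = sd["params_ema"]

model = HAT(state_dict=sd).eval()
lr = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # (1, 3, 64 * model.upscale, 64 * model.upscale)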
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-DAT b/comfy_extras/chainner_models/architecture/LICENSE-DAT
deleted file mode 100644
index 261eeb9e9f8..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-DAT
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-ESRGAN b/comfy_extras/chainner_models/architecture/LICENSE-ESRGAN
deleted file mode 100644
index 261eeb9e9f8..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-ESRGAN
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-HAT b/comfy_extras/chainner_models/architecture/LICENSE-HAT
deleted file mode 100644
index 003e97e96cb..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-HAT
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2022 Xiangyu Chen
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN b/comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN
deleted file mode 100644
index 552a1eeaf01..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN
+++ /dev/null
@@ -1,29 +0,0 @@
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived from
- this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-SCUNet b/comfy_extras/chainner_models/architecture/LICENSE-SCUNet
deleted file mode 100644
index ff75c988f34..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-SCUNet
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2022 Kai Zhang (cskaizhang@gmail.com, https://cszn.github.io/). All rights reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-SPSR b/comfy_extras/chainner_models/architecture/LICENSE-SPSR
deleted file mode 100644
index 3245f3f9e4f..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-SPSR
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2018-2022 BasicSR Authors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN b/comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN
deleted file mode 100644
index 0e259d42c99..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN
+++ /dev/null
@@ -1,121 +0,0 @@
-Creative Commons Legal Code
-
-CC0 1.0 Universal
-
- CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
- LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
- ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
- INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
- REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
- PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
- THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
- HEREUNDER.
-
-Statement of Purpose
-
-The laws of most jurisdictions throughout the world automatically confer
-exclusive Copyright and Related Rights (defined below) upon the creator
-and subsequent owner(s) (each and all, an "owner") of an original work of
-authorship and/or a database (each, a "Work").
-
-Certain owners wish to permanently relinquish those rights to a Work for
-the purpose of contributing to a commons of creative, cultural and
-scientific works ("Commons") that the public can reliably and without fear
-of later claims of infringement build upon, modify, incorporate in other
-works, reuse and redistribute as freely as possible in any form whatsoever
-and for any purposes, including without limitation commercial purposes.
-These owners may contribute to the Commons to promote the ideal of a free
-culture and the further production of creative, cultural and scientific
-works, or to gain reputation or greater distribution for their Work in
-part through the use and efforts of others.
-
-For these and/or other purposes and motivations, and without any
-expectation of additional consideration or compensation, the person
-associating CC0 with a Work (the "Affirmer"), to the extent that he or she
-is an owner of Copyright and Related Rights in the Work, voluntarily
-elects to apply CC0 to the Work and publicly distribute the Work under its
-terms, with knowledge of his or her Copyright and Related Rights in the
-Work and the meaning and intended legal effect of CC0 on those rights.
-
-1. Copyright and Related Rights. A Work made available under CC0 may be
-protected by copyright and related or neighboring rights ("Copyright and
-Related Rights"). Copyright and Related Rights include, but are not
-limited to, the following:
-
- i. the right to reproduce, adapt, distribute, perform, display,
- communicate, and translate a Work;
- ii. moral rights retained by the original author(s) and/or performer(s);
-iii. publicity and privacy rights pertaining to a person's image or
- likeness depicted in a Work;
- iv. rights protecting against unfair competition in regards to a Work,
- subject to the limitations in paragraph 4(a), below;
- v. rights protecting the extraction, dissemination, use and reuse of data
- in a Work;
- vi. database rights (such as those arising under Directive 96/9/EC of the
- European Parliament and of the Council of 11 March 1996 on the legal
- protection of databases, and under any national implementation
- thereof, including any amended or successor version of such
- directive); and
-vii. other similar, equivalent or corresponding rights throughout the
- world based on applicable law or treaty, and any national
- implementations thereof.
-
-2. Waiver. To the greatest extent permitted by, but not in contravention
-of, applicable law, Affirmer hereby overtly, fully, permanently,
-irrevocably and unconditionally waives, abandons, and surrenders all of
-Affirmer's Copyright and Related Rights and associated claims and causes
-of action, whether now known or unknown (including existing as well as
-future claims and causes of action), in the Work (i) in all territories
-worldwide, (ii) for the maximum duration provided by applicable law or
-treaty (including future time extensions), (iii) in any current or future
-medium and for any number of copies, and (iv) for any purpose whatsoever,
-including without limitation commercial, advertising or promotional
-purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
-member of the public at large and to the detriment of Affirmer's heirs and
-successors, fully intending that such Waiver shall not be subject to
-revocation, rescission, cancellation, termination, or any other legal or
-equitable action to disrupt the quiet enjoyment of the Work by the public
-as contemplated by Affirmer's express Statement of Purpose.
-
-3. Public License Fallback. Should any part of the Waiver for any reason
-be judged legally invalid or ineffective under applicable law, then the
-Waiver shall be preserved to the maximum extent permitted taking into
-account Affirmer's express Statement of Purpose. In addition, to the
-extent the Waiver is so judged Affirmer hereby grants to each affected
-person a royalty-free, non transferable, non sublicensable, non exclusive,
-irrevocable and unconditional license to exercise Affirmer's Copyright and
-Related Rights in the Work (i) in all territories worldwide, (ii) for the
-maximum duration provided by applicable law or treaty (including future
-time extensions), (iii) in any current or future medium and for any number
-of copies, and (iv) for any purpose whatsoever, including without
-limitation commercial, advertising or promotional purposes (the
-"License"). The License shall be deemed effective as of the date CC0 was
-applied by Affirmer to the Work. Should any part of the License for any
-reason be judged legally invalid or ineffective under applicable law, such
-partial invalidity or ineffectiveness shall not invalidate the remainder
-of the License, and in such case Affirmer hereby affirms that he or she
-will not (i) exercise any of his or her remaining Copyright and Related
-Rights in the Work or (ii) assert any associated claims and causes of
-action with respect to the Work, in either case contrary to Affirmer's
-express Statement of Purpose.
-
-4. Limitations and Disclaimers.
-
- a. No trademark or patent rights held by Affirmer are waived, abandoned,
- surrendered, licensed or otherwise affected by this document.
- b. Affirmer offers the Work as-is and makes no representations or
- warranties of any kind concerning the Work, express, implied,
- statutory or otherwise, including without limitation warranties of
- title, merchantability, fitness for a particular purpose, non
- infringement, or the absence of latent or other defects, accuracy, or
- the present or absence of errors, whether or not discoverable, all to
- the greatest extent permissible under applicable law.
- c. Affirmer disclaims responsibility for clearing rights of other persons
- that may apply to the Work or any use thereof, including without
- limitation any person's Copyright and Related Rights in the Work.
- Further, Affirmer disclaims responsibility for obtaining any necessary
- consents, permissions or other rights required for any use of the
- Work.
- d. Affirmer understands and acknowledges that Creative Commons is not a
- party to this document and has no duty or obligation with respect to
- this CC0 or use of the Work.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-Swin2SR b/comfy_extras/chainner_models/architecture/LICENSE-Swin2SR
deleted file mode 100644
index e5e4ee061a3..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-Swin2SR
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [2021] [SwinIR Authors]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-SwinIR b/comfy_extras/chainner_models/architecture/LICENSE-SwinIR
deleted file mode 100644
index e5e4ee061a3..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-SwinIR
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [2021] [SwinIR Authors]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LICENSE-lama b/comfy_extras/chainner_models/architecture/LICENSE-lama
deleted file mode 100644
index ca822bb5f62..00000000000
--- a/comfy_extras/chainner_models/architecture/LICENSE-lama
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [2021] Samsung Research
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/LaMa.py b/comfy_extras/chainner_models/architecture/LaMa.py
deleted file mode 100644
index a781f3e4dda..00000000000
--- a/comfy_extras/chainner_models/architecture/LaMa.py
+++ /dev/null
@@ -1,694 +0,0 @@
-# pylint: skip-file
-"""
-Model adapted from advimman's lama project: https://github.com/advimman/lama
-"""
-
-# Fast Fourier Convolution NeurIPS 2020
-# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
-# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
-
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision.transforms.functional import InterpolationMode, rotate
-
-
-class LearnableSpatialTransformWrapper(nn.Module):
- def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
- super().__init__()
- self.impl = impl
- self.angle = torch.rand(1) * angle_init_range
- if train_angle:
- self.angle = nn.Parameter(self.angle, requires_grad=True)
- self.pad_coef = pad_coef
-
- def forward(self, x):
- if torch.is_tensor(x):
- return self.inverse_transform(self.impl(self.transform(x)), x)
- elif isinstance(x, tuple):
- x_trans = tuple(self.transform(elem) for elem in x)
- y_trans = self.impl(x_trans)
- return tuple(
- self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
- )
- else:
- raise ValueError(f"Unexpected input type {type(x)}")
-
- def transform(self, x):
- height, width = x.shape[2:]
- pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
- x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
- x_padded_rotated = rotate(
- x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
- )
-
- return x_padded_rotated
-
- def inverse_transform(self, y_padded_rotated, orig_x):
- height, width = orig_x.shape[2:]
- pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
-
- y_padded = rotate(
- y_padded_rotated,
- -self.angle.to(y_padded_rotated),
- InterpolationMode.BILINEAR,
- fill=0,
- )
- y_height, y_width = y_padded.shape[2:]
- y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
- return y
-
-
-class SELayer(nn.Module):
- def __init__(self, channel, reduction=16):
- super(SELayer, self).__init__()
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Sequential(
- nn.Linear(channel, channel // reduction, bias=False),
- nn.ReLU(inplace=True),
- nn.Linear(channel // reduction, channel, bias=False),
- nn.Sigmoid(),
- )
-
- def forward(self, x):
- b, c, _, _ = x.size()
- y = self.avg_pool(x).view(b, c)
- y = self.fc(y).view(b, c, 1, 1)
- res = x * y.expand_as(x)
- return res
-
-
-class FourierUnit(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- groups=1,
- spatial_scale_factor=None,
- spatial_scale_mode="bilinear",
- spectral_pos_encoding=False,
- use_se=False,
- se_kwargs=None,
- ffc3d=False,
- fft_norm="ortho",
- ):
- # bn_layer not used
- super(FourierUnit, self).__init__()
- self.groups = groups
-
- self.conv_layer = torch.nn.Conv2d(
- in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
- out_channels=out_channels * 2,
- kernel_size=1,
- stride=1,
- padding=0,
- groups=self.groups,
- bias=False,
- )
- self.bn = torch.nn.BatchNorm2d(out_channels * 2)
- self.relu = torch.nn.ReLU(inplace=True)
-
- # squeeze and excitation block
- self.use_se = use_se
- if use_se:
- if se_kwargs is None:
- se_kwargs = {}
- self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
-
- self.spatial_scale_factor = spatial_scale_factor
- self.spatial_scale_mode = spatial_scale_mode
- self.spectral_pos_encoding = spectral_pos_encoding
- self.ffc3d = ffc3d
- self.fft_norm = fft_norm
-
- def forward(self, x):
- half_check = False
- if x.type() == "torch.cuda.HalfTensor":
- # half only works on gpu anyway
- half_check = True
-
- batch = x.shape[0]
-
- if self.spatial_scale_factor is not None:
- orig_size = x.shape[-2:]
- x = F.interpolate(
- x,
- scale_factor=self.spatial_scale_factor,
- mode=self.spatial_scale_mode,
- align_corners=False,
- )
-
- # (batch, c, h, w/2+1, 2)
- fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
- if half_check == True:
- ffted = torch.fft.rfftn(
- x.float(), dim=fft_dim, norm=self.fft_norm
- ) # .type(torch.cuda.HalfTensor)
- else:
- ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
-
- ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
- ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
- ffted = ffted.view(
- (
- batch,
- -1,
- )
- + ffted.size()[3:]
- )
-
- if self.spectral_pos_encoding:
- height, width = ffted.shape[-2:]
- coords_vert = (
- torch.linspace(0, 1, height)[None, None, :, None]
- .expand(batch, 1, height, width)
- .to(ffted)
- )
- coords_hor = (
- torch.linspace(0, 1, width)[None, None, None, :]
- .expand(batch, 1, height, width)
- .to(ffted)
- )
- ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
-
- if self.use_se:
- ffted = self.se(ffted)
-
- if half_check == True:
- ffted = self.conv_layer(ffted.half()) # (batch, c*2, h, w/2+1)
- else:
- ffted = self.conv_layer(
- ffted
- ) # .type(torch.cuda.FloatTensor) # (batch, c*2, h, w/2+1)
-
- ffted = self.relu(self.bn(ffted))
- # forcing to be always float
- ffted = ffted.float()
-
- ffted = (
- ffted.view(
- (
- batch,
- -1,
- 2,
- )
- + ffted.size()[2:]
- )
- .permute(0, 1, 3, 4, 2)
- .contiguous()
- ) # (batch,c, t, h, w/2+1, 2)
-
- ffted = torch.complex(ffted[..., 0], ffted[..., 1])
-
- ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
- output = torch.fft.irfftn(
- ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
- )
-
- if half_check == True:
- output = output.half()
-
- if self.spatial_scale_factor is not None:
- output = F.interpolate(
- output,
- size=orig_size,
- mode=self.spatial_scale_mode,
- align_corners=False,
- )
-
- return output
-
-
-class SpectralTransform(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- stride=1,
- groups=1,
- enable_lfu=True,
- separable_fu=False,
- **fu_kwargs,
- ):
- # bn_layer not used
- super(SpectralTransform, self).__init__()
- self.enable_lfu = enable_lfu
- if stride == 2:
- self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
- else:
- self.downsample = nn.Identity()
-
- self.stride = stride
- self.conv1 = nn.Sequential(
- nn.Conv2d(
- in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
- ),
- nn.BatchNorm2d(out_channels // 2),
- nn.ReLU(inplace=True),
- )
- fu_class = FourierUnit
- self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
- if self.enable_lfu:
- self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
- self.conv2 = torch.nn.Conv2d(
- out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
- )
-
- def forward(self, x):
- x = self.downsample(x)
- x = self.conv1(x)
- output = self.fu(x)
-
- if self.enable_lfu:
- _, c, h, _ = x.shape
- split_no = 2
- split_s = h // split_no
- xs = torch.cat(
- torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
- ).contiguous()
- xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
- xs = self.lfu(xs)
- xs = xs.repeat(1, 1, split_no, split_no).contiguous()
- else:
- xs = 0
-
- output = self.conv2(x + output + xs)
-
- return output
-
-
-class FFC(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- ratio_gin,
- ratio_gout,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=False,
- enable_lfu=True,
- padding_type="reflect",
- gated=False,
- **spectral_kwargs,
- ):
- super(FFC, self).__init__()
-
- assert stride == 1 or stride == 2, "Stride should be 1 or 2."
- self.stride = stride
-
- in_cg = int(in_channels * ratio_gin)
- in_cl = in_channels - in_cg
- out_cg = int(out_channels * ratio_gout)
- out_cl = out_channels - out_cg
- # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
- # groups_l = 1 if groups == 1 else groups - groups_g
-
- self.ratio_gin = ratio_gin
- self.ratio_gout = ratio_gout
- self.global_in_num = in_cg
-
- module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
- self.convl2l = module(
- in_cl,
- out_cl,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- bias,
- padding_mode=padding_type,
- )
- module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
- self.convl2g = module(
- in_cl,
- out_cg,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- bias,
- padding_mode=padding_type,
- )
- module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
- self.convg2l = module(
- in_cg,
- out_cl,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- bias,
- padding_mode=padding_type,
- )
- module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
- self.convg2g = module(
- in_cg,
- out_cg,
- stride,
- 1 if groups == 1 else groups // 2,
- enable_lfu,
- **spectral_kwargs,
- )
-
- self.gated = gated
- module = (
- nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
- )
- self.gate = module(in_channels, 2, 1)
-
- def forward(self, x):
- x_l, x_g = x if type(x) is tuple else (x, 0)
- out_xl, out_xg = 0, 0
-
- if self.gated:
- total_input_parts = [x_l]
- if torch.is_tensor(x_g):
- total_input_parts.append(x_g)
- total_input = torch.cat(total_input_parts, dim=1)
-
- gates = torch.sigmoid(self.gate(total_input))
- g2l_gate, l2g_gate = gates.chunk(2, dim=1)
- else:
- g2l_gate, l2g_gate = 1, 1
-
- if self.ratio_gout != 1:
- out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
- if self.ratio_gout != 0:
- out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
-
- return out_xl, out_xg
-
-
-class FFC_BN_ACT(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- ratio_gin,
- ratio_gout,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=False,
- norm_layer=nn.BatchNorm2d,
- activation_layer=nn.Identity,
- padding_type="reflect",
- enable_lfu=True,
- **kwargs,
- ):
- super(FFC_BN_ACT, self).__init__()
- self.ffc = FFC(
- in_channels,
- out_channels,
- kernel_size,
- ratio_gin,
- ratio_gout,
- stride,
- padding,
- dilation,
- groups,
- bias,
- enable_lfu,
- padding_type=padding_type,
- **kwargs,
- )
- lnorm = nn.Identity if ratio_gout == 1 else norm_layer
- gnorm = nn.Identity if ratio_gout == 0 else norm_layer
- global_channels = int(out_channels * ratio_gout)
- self.bn_l = lnorm(out_channels - global_channels)
- self.bn_g = gnorm(global_channels)
-
- lact = nn.Identity if ratio_gout == 1 else activation_layer
- gact = nn.Identity if ratio_gout == 0 else activation_layer
- self.act_l = lact(inplace=True)
- self.act_g = gact(inplace=True)
-
- def forward(self, x):
- x_l, x_g = self.ffc(x)
- x_l = self.act_l(self.bn_l(x_l))
- x_g = self.act_g(self.bn_g(x_g))
- return x_l, x_g
-
-
-class FFCResnetBlock(nn.Module):
- def __init__(
- self,
- dim,
- padding_type,
- norm_layer,
- activation_layer=nn.ReLU,
- dilation=1,
- spatial_transform_kwargs=None,
- inline=False,
- **conv_kwargs,
- ):
- super().__init__()
- self.conv1 = FFC_BN_ACT(
- dim,
- dim,
- kernel_size=3,
- padding=dilation,
- dilation=dilation,
- norm_layer=norm_layer,
- activation_layer=activation_layer,
- padding_type=padding_type,
- **conv_kwargs,
- )
- self.conv2 = FFC_BN_ACT(
- dim,
- dim,
- kernel_size=3,
- padding=dilation,
- dilation=dilation,
- norm_layer=norm_layer,
- activation_layer=activation_layer,
- padding_type=padding_type,
- **conv_kwargs,
- )
- if spatial_transform_kwargs is not None:
- self.conv1 = LearnableSpatialTransformWrapper(
- self.conv1, **spatial_transform_kwargs
- )
- self.conv2 = LearnableSpatialTransformWrapper(
- self.conv2, **spatial_transform_kwargs
- )
- self.inline = inline
-
- def forward(self, x):
- if self.inline:
- x_l, x_g = (
- x[:, : -self.conv1.ffc.global_in_num],
- x[:, -self.conv1.ffc.global_in_num :],
- )
- else:
- x_l, x_g = x if type(x) is tuple else (x, 0)
-
- id_l, id_g = x_l, x_g
-
- x_l, x_g = self.conv1((x_l, x_g))
- x_l, x_g = self.conv2((x_l, x_g))
-
- x_l, x_g = id_l + x_l, id_g + x_g
- out = x_l, x_g
- if self.inline:
- out = torch.cat(out, dim=1)
- return out
-
-
-class ConcatTupleLayer(nn.Module):
- def forward(self, x):
- assert isinstance(x, tuple)
- x_l, x_g = x
- assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
- if not torch.is_tensor(x_g):
- return x_l
- return torch.cat(x, dim=1)
-
-
-class FFCResNetGenerator(nn.Module):
- def __init__(
- self,
- input_nc,
- output_nc,
- ngf=64,
- n_downsampling=3,
- n_blocks=18,
- norm_layer=nn.BatchNorm2d,
- padding_type="reflect",
- activation_layer=nn.ReLU,
- up_norm_layer=nn.BatchNorm2d,
- up_activation=nn.ReLU(True),
- init_conv_kwargs={},
- downsample_conv_kwargs={},
- resnet_conv_kwargs={},
- spatial_transform_layers=None,
- spatial_transform_kwargs={},
- max_features=1024,
- out_ffc=False,
- out_ffc_kwargs={},
- ):
- assert n_blocks >= 0
- super().__init__()
- """
- init_conv_kwargs = {'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False}
- downsample_conv_kwargs = {'ratio_gin': '${generator.init_conv_kwargs.ratio_gout}', 'ratio_gout': '${generator.downsample_conv_kwargs.ratio_gin}', 'enable_lfu': False}
- resnet_conv_kwargs = {'ratio_gin': 0.75, 'ratio_gout': '${generator.resnet_conv_kwargs.ratio_gin}', 'enable_lfu': False}
- spatial_transform_kwargs = {}
- out_ffc_kwargs = {}
- """
- """
- print(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
- padding_type, activation_layer,
- up_norm_layer, up_activation,
- spatial_transform_layers,
- add_out_act, max_features, out_ffc, file=sys.stderr)
-
- 4 3 64 3 18
- reflect
-
- ReLU(inplace=True)
- None sigmoid 1024 False
- """
- init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
- downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
- resnet_conv_kwargs = {
- "ratio_gin": 0.75,
- "ratio_gout": 0.75,
- "enable_lfu": False,
- }
- spatial_transform_kwargs = {}
- out_ffc_kwargs = {}
-
- model = [
- nn.ReflectionPad2d(3),
- FFC_BN_ACT(
- input_nc,
- ngf,
- kernel_size=7,
- padding=0,
- norm_layer=norm_layer,
- activation_layer=activation_layer,
- **init_conv_kwargs,
- ),
- ]
-
- ### downsample
- for i in range(n_downsampling):
- mult = 2**i
- if i == n_downsampling - 1:
- cur_conv_kwargs = dict(downsample_conv_kwargs)
- cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
- else:
- cur_conv_kwargs = downsample_conv_kwargs
- model += [
- FFC_BN_ACT(
- min(max_features, ngf * mult),
- min(max_features, ngf * mult * 2),
- kernel_size=3,
- stride=2,
- padding=1,
- norm_layer=norm_layer,
- activation_layer=activation_layer,
- **cur_conv_kwargs,
- )
- ]
-
- mult = 2**n_downsampling
- feats_num_bottleneck = min(max_features, ngf * mult)
-
- ### resnet blocks
- for i in range(n_blocks):
- cur_resblock = FFCResnetBlock(
- feats_num_bottleneck,
- padding_type=padding_type,
- activation_layer=activation_layer,
- norm_layer=norm_layer,
- **resnet_conv_kwargs,
- )
- if spatial_transform_layers is not None and i in spatial_transform_layers:
- cur_resblock = LearnableSpatialTransformWrapper(
- cur_resblock, **spatial_transform_kwargs
- )
- model += [cur_resblock]
-
- model += [ConcatTupleLayer()]
-
- ### upsample
- for i in range(n_downsampling):
- mult = 2 ** (n_downsampling - i)
- model += [
- nn.ConvTranspose2d(
- min(max_features, ngf * mult),
- min(max_features, int(ngf * mult / 2)),
- kernel_size=3,
- stride=2,
- padding=1,
- output_padding=1,
- ),
- up_norm_layer(min(max_features, int(ngf * mult / 2))),
- up_activation,
- ]
-
- if out_ffc:
- model += [
- FFCResnetBlock(
- ngf,
- padding_type=padding_type,
- activation_layer=activation_layer,
- norm_layer=norm_layer,
- inline=True,
- **out_ffc_kwargs,
- )
- ]
-
- model += [
- nn.ReflectionPad2d(3),
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
- ]
- model.append(nn.Sigmoid())
- self.model = nn.Sequential(*model)
-
- def forward(self, image, mask):
- return self.model(torch.cat([image, mask], dim=1))
-
-
-class LaMa(nn.Module):
- def __init__(self, state_dict) -> None:
- super(LaMa, self).__init__()
- self.model_arch = "LaMa"
- self.sub_type = "Inpaint"
- self.in_nc = 4
- self.out_nc = 3
- self.scale = 1
-
- self.min_size = None
- self.pad_mod = 8
- self.pad_to_square = False
-
- self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
- self.state = {
- k.replace("generator.model", "model.model"): v
- for k, v in state_dict.items()
- }
-
- self.supports_fp16 = False
-        self.supports_bfp16 = True
-        self.support_bf16 = True  # kept so the original attribute name still exists
-
- self.load_state_dict(self.state, strict=False)
-
- def forward(self, img, mask):
- masked_img = img * (1 - mask)
- inpainted_mask = mask * self.model.forward(masked_img, mask)
- result = inpainted_mask + (1 - mask) * img
- return result
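
The deleted `LaMa` wrapper above composites the network's prediction back over the known pixels: the model only fills the masked region, and everything where `mask == 0` is copied straight from the input. A minimal sketch of that compositing step with dummy tensors (shapes are hypothetical, and `predicted` stands in for the generator output):

```python
import torch

img = torch.rand(1, 3, 64, 64)                    # hypothetical RGB input
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()   # 1 = hole to inpaint
predicted = torch.rand(1, 3, 64, 64)              # stands in for self.model(masked_img, mask)

masked_img = img * (1 - mask)                     # zero out the hole before feeding the model
result = mask * predicted + (1 - mask) * img      # keep original pixels outside the hole
assert result.shape == img.shape
```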
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py
deleted file mode 100644
index f4d52aa1e06..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import math
-
-import torch.nn as nn
-
-
-class CA_layer(nn.Module):
- def __init__(self, channel, reduction=16):
- super(CA_layer, self).__init__()
- # global average pooling
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Sequential(
- nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
- nn.GELU(),
- nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
- # nn.Sigmoid()
- )
-
- def forward(self, x):
- y = self.fc(self.gap(x))
- return x * y.expand_as(x)
-
-
-class Simple_CA_layer(nn.Module):
- def __init__(self, channel):
- super(Simple_CA_layer, self).__init__()
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Conv2d(
- in_channels=channel,
- out_channels=channel,
- kernel_size=1,
- padding=0,
- stride=1,
- groups=1,
- bias=True,
- )
-
- def forward(self, x):
- return x * self.fc(self.gap(x))
-
-
-class ECA_layer(nn.Module):
-    """Constructs an ECA module.
-    Args:
-        channel: Number of channels of the input feature map
-        k_size: Kernel size of the 1-D conv, chosen adaptively from `channel`
- """
-
- def __init__(self, channel):
- super(ECA_layer, self).__init__()
-
- b = 1
- gamma = 2
- k_size = int(abs(math.log(channel, 2) + b) / gamma)
- k_size = k_size if k_size % 2 else k_size + 1
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.conv = nn.Conv1d(
- 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
- )
- # self.sigmoid = nn.Sigmoid()
-
- def forward(self, x):
- # x: input features with shape [b, c, h, w]
- # b, c, h, w = x.size()
-
- # feature descriptor on the global spatial information
- y = self.avg_pool(x)
-
- # Two different branches of ECA module
- y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
-
- # Multi-scale information fusion
- # y = self.sigmoid(y)
-
- return x * y.expand_as(x)
-
-
-class ECA_MaxPool_layer(nn.Module):
-    """Constructs an ECA module.
-    Args:
-        channel: Number of channels of the input feature map
-        k_size: Kernel size of the 1-D conv, chosen adaptively from `channel`
- """
-
- def __init__(self, channel):
- super(ECA_MaxPool_layer, self).__init__()
-
- b = 1
- gamma = 2
- k_size = int(abs(math.log(channel, 2) + b) / gamma)
- k_size = k_size if k_size % 2 else k_size + 1
- self.max_pool = nn.AdaptiveMaxPool2d(1)
- self.conv = nn.Conv1d(
- 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
- )
- # self.sigmoid = nn.Sigmoid()
-
- def forward(self, x):
- # x: input features with shape [b, c, h, w]
- # b, c, h, w = x.size()
-
- # feature descriptor on the global spatial information
- y = self.max_pool(x)
-
- # Two different branches of ECA module
- y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
-
- # Multi-scale information fusion
- # y = self.sigmoid(y)
-
- return x * y.expand_as(x)
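
Both deleted ECA variants above derive their 1-D convolution kernel size adaptively from the channel count rather than taking it as an argument. A small sketch of that rule (same formula as in the `__init__` methods above):

```python
import math

def eca_kernel_size(channel: int, b: int = 1, gamma: int = 2) -> int:
    # Odd kernel size derived from log2 of the channel count, as in ECA-Net.
    k = int(abs(math.log(channel, 2) + b) / gamma)
    return k if k % 2 else k + 1

print(eca_kernel_size(64))   # 3
print(eca_kernel_size(256))  # 5
```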
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE
deleted file mode 100644
index 261eeb9e9f8..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py
deleted file mode 100644
index d7a129696b2..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py
+++ /dev/null
@@ -1,577 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: OSA.py
-# Created Date: Tuesday April 28th 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
-# Modified By: Chen Xuanhong
-# Copyright (c) 2020 Shanghai Jiao Tong University
-#############################################################
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange, repeat
-from einops.layers.torch import Rearrange, Reduce
-from torch import einsum, nn
-
-from .layernorm import LayerNorm2d
-
-# helpers
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-
-def cast_tuple(val, length=1):
- return val if isinstance(val, tuple) else ((val,) * length)
-
-
-# helper classes
-
-
-class PreNormResidual(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn
-
- def forward(self, x):
- return self.fn(self.norm(x)) + x
-
-
-class Conv_PreNormResidual(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = LayerNorm2d(dim)
- self.fn = fn
-
- def forward(self, x):
- return self.fn(self.norm(x)) + x
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, mult=2, dropout=0.0):
- super().__init__()
- inner_dim = int(dim * mult)
- self.net = nn.Sequential(
- nn.Linear(dim, inner_dim),
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(inner_dim, dim),
- nn.Dropout(dropout),
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-class Conv_FeedForward(nn.Module):
- def __init__(self, dim, mult=2, dropout=0.0):
- super().__init__()
- inner_dim = int(dim * mult)
- self.net = nn.Sequential(
- nn.Conv2d(dim, inner_dim, 1, 1, 0),
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Conv2d(inner_dim, dim, 1, 1, 0),
- nn.Dropout(dropout),
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-class Gated_Conv_FeedForward(nn.Module):
- def __init__(self, dim, mult=1, bias=False, dropout=0.0):
- super().__init__()
-
- hidden_features = int(dim * mult)
-
- self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
-
- self.dwconv = nn.Conv2d(
- hidden_features * 2,
- hidden_features * 2,
- kernel_size=3,
- stride=1,
- padding=1,
- groups=hidden_features * 2,
- bias=bias,
- )
-
- self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
-
- def forward(self, x):
- x = self.project_in(x)
- x1, x2 = self.dwconv(x).chunk(2, dim=1)
- x = F.gelu(x1) * x2
- x = self.project_out(x)
- return x
-
-
-# MBConv
-
-
-class SqueezeExcitation(nn.Module):
- def __init__(self, dim, shrinkage_rate=0.25):
- super().__init__()
- hidden_dim = int(dim * shrinkage_rate)
-
- self.gate = nn.Sequential(
- Reduce("b c h w -> b c", "mean"),
- nn.Linear(dim, hidden_dim, bias=False),
- nn.SiLU(),
- nn.Linear(hidden_dim, dim, bias=False),
- nn.Sigmoid(),
- Rearrange("b c -> b c 1 1"),
- )
-
- def forward(self, x):
- return x * self.gate(x)
-
-
-class MBConvResidual(nn.Module):
- def __init__(self, fn, dropout=0.0):
- super().__init__()
- self.fn = fn
- self.dropsample = Dropsample(dropout)
-
- def forward(self, x):
- out = self.fn(x)
- out = self.dropsample(out)
- return out + x
-
-
-class Dropsample(nn.Module):
- def __init__(self, prob=0):
- super().__init__()
- self.prob = prob
-
- def forward(self, x):
- device = x.device
-
- if self.prob == 0.0 or (not self.training):
- return x
-
-        # Per-sample keep mask for stochastic depth (drop whole samples).
-        keep_mask = (
-            torch.empty((x.shape[0], 1, 1, 1), device=device).uniform_()
-            > self.prob
-        )
- return x * keep_mask / (1 - self.prob)
-
-
-def MBConv(
- dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
-):
- hidden_dim = int(expansion_rate * dim_out)
- stride = 2 if downsample else 1
-
- net = nn.Sequential(
- nn.Conv2d(dim_in, hidden_dim, 1),
- # nn.BatchNorm2d(hidden_dim),
- nn.GELU(),
- nn.Conv2d(
- hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
- ),
- # nn.BatchNorm2d(hidden_dim),
- nn.GELU(),
- SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
- nn.Conv2d(hidden_dim, dim_out, 1),
- # nn.BatchNorm2d(dim_out)
- )
-
- if dim_in == dim_out and not downsample:
- net = MBConvResidual(net, dropout=dropout)
-
- return net
-
-
-# attention related classes
-class Attention(nn.Module):
- def __init__(
- self,
- dim,
- dim_head=32,
- dropout=0.0,
- window_size=7,
- with_pe=True,
- ):
- super().__init__()
- assert (
- dim % dim_head
- ) == 0, "dimension should be divisible by dimension per head"
-
- self.heads = dim // dim_head
- self.scale = dim_head**-0.5
- self.with_pe = with_pe
-
- self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
-
- self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
-
- self.to_out = nn.Sequential(
- nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
- )
-
- # relative positional bias
- if self.with_pe:
- self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
-
- pos = torch.arange(window_size)
- grid = torch.stack(torch.meshgrid(pos, pos))
- grid = rearrange(grid, "c i j -> (i j) c")
- rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
- grid, "j ... -> 1 j ..."
- )
- rel_pos += window_size - 1
- rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
- dim=-1
- )
-
- self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
-
- def forward(self, x):
- batch, height, width, window_height, window_width, _, device, h = (
- *x.shape,
- x.device,
- self.heads,
- )
-
- # flatten
-
- x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
-
- # project for queries, keys, values
-
- q, k, v = self.to_qkv(x).chunk(3, dim=-1)
-
- # split heads
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-
- # scale
-
- q = q * self.scale
-
- # sim
-
- sim = einsum("b h i d, b h j d -> b h i j", q, k)
-
- # add positional bias
- if self.with_pe:
- bias = self.rel_pos_bias(self.rel_pos_indices)
- sim = sim + rearrange(bias, "i j h -> h i j")
-
- # attention
-
- attn = self.attend(sim)
-
- # aggregate
-
- out = einsum("b h i j, b h j d -> b h i d", attn, v)
-
- # merge heads
-
- out = rearrange(
- out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
- )
-
- # combine heads out
-
- out = self.to_out(out)
- return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
-
-
-class Block_Attention(nn.Module):
- def __init__(
- self,
- dim,
- dim_head=32,
- bias=False,
- dropout=0.0,
- window_size=7,
- with_pe=True,
- ):
- super().__init__()
- assert (
- dim % dim_head
- ) == 0, "dimension should be divisible by dimension per head"
-
- self.heads = dim // dim_head
- self.ps = window_size
- self.scale = dim_head**-0.5
- self.with_pe = with_pe
-
- self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
- self.qkv_dwconv = nn.Conv2d(
- dim * 3,
- dim * 3,
- kernel_size=3,
- stride=1,
- padding=1,
- groups=dim * 3,
- bias=bias,
- )
-
- self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
-
- self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
-
- def forward(self, x):
- # project for queries, keys, values
- b, c, h, w = x.shape
-
- qkv = self.qkv_dwconv(self.qkv(x))
- q, k, v = qkv.chunk(3, dim=1)
-
- # split heads
-
- q, k, v = map(
- lambda t: rearrange(
- t,
- "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
- h=self.heads,
- w1=self.ps,
- w2=self.ps,
- ),
- (q, k, v),
- )
-
- # scale
-
- q = q * self.scale
-
- # sim
-
- sim = einsum("b h i d, b h j d -> b h i j", q, k)
-
- # attention
- attn = self.attend(sim)
-
- # aggregate
-
- out = einsum("b h i j, b h j d -> b h i d", attn, v)
-
- # merge heads
- out = rearrange(
- out,
- "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
- x=h // self.ps,
- y=w // self.ps,
- head=self.heads,
- w1=self.ps,
- w2=self.ps,
- )
-
- out = self.to_out(out)
- return out
-
-
-class Channel_Attention(nn.Module):
- def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
- super(Channel_Attention, self).__init__()
- self.heads = heads
-
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
-
- self.ps = window_size
-
- self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
- self.qkv_dwconv = nn.Conv2d(
- dim * 3,
- dim * 3,
- kernel_size=3,
- stride=1,
- padding=1,
- groups=dim * 3,
- bias=bias,
- )
- self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
-
- def forward(self, x):
- b, c, h, w = x.shape
-
- qkv = self.qkv_dwconv(self.qkv(x))
- qkv = qkv.chunk(3, dim=1)
-
- q, k, v = map(
- lambda t: rearrange(
- t,
- "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
- ph=self.ps,
- pw=self.ps,
- head=self.heads,
- ),
- qkv,
- )
-
- q = F.normalize(q, dim=-1)
- k = F.normalize(k, dim=-1)
-
- attn = (q @ k.transpose(-2, -1)) * self.temperature
- attn = attn.softmax(dim=-1)
- out = attn @ v
-
- out = rearrange(
- out,
- "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
- h=h // self.ps,
- w=w // self.ps,
- ph=self.ps,
- pw=self.ps,
- head=self.heads,
- )
-
- out = self.project_out(out)
-
- return out
-
-
-class Channel_Attention_grid(nn.Module):
- def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
- super(Channel_Attention_grid, self).__init__()
- self.heads = heads
-
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
-
- self.ps = window_size
-
- self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
- self.qkv_dwconv = nn.Conv2d(
- dim * 3,
- dim * 3,
- kernel_size=3,
- stride=1,
- padding=1,
- groups=dim * 3,
- bias=bias,
- )
- self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
-
- def forward(self, x):
- b, c, h, w = x.shape
-
- qkv = self.qkv_dwconv(self.qkv(x))
- qkv = qkv.chunk(3, dim=1)
-
- q, k, v = map(
- lambda t: rearrange(
- t,
- "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
- ph=self.ps,
- pw=self.ps,
- head=self.heads,
- ),
- qkv,
- )
-
- q = F.normalize(q, dim=-1)
- k = F.normalize(k, dim=-1)
-
- attn = (q @ k.transpose(-2, -1)) * self.temperature
- attn = attn.softmax(dim=-1)
- out = attn @ v
-
- out = rearrange(
- out,
- "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
- h=h // self.ps,
- w=w // self.ps,
- ph=self.ps,
- pw=self.ps,
- head=self.heads,
- )
-
- out = self.project_out(out)
-
- return out
-
-
-class OSA_Block(nn.Module):
- def __init__(
- self,
- channel_num=64,
- bias=True,
- ffn_bias=True,
- window_size=8,
- with_pe=False,
- dropout=0.0,
- ):
- super(OSA_Block, self).__init__()
-
- w = window_size
-
- self.layer = nn.Sequential(
- MBConv(
- channel_num,
- channel_num,
- downsample=False,
- expansion_rate=1,
- shrinkage_rate=0.25,
- ),
- Rearrange(
- "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
- ), # block-like attention
- PreNormResidual(
- channel_num,
- Attention(
- dim=channel_num,
- dim_head=channel_num // 4,
- dropout=dropout,
- window_size=window_size,
- with_pe=with_pe,
- ),
- ),
- Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
- Conv_PreNormResidual(
- channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
- ),
- # channel-like attention
- Conv_PreNormResidual(
- channel_num,
- Channel_Attention(
- dim=channel_num, heads=4, dropout=dropout, window_size=window_size
- ),
- ),
- Conv_PreNormResidual(
- channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
- ),
- Rearrange(
- "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
- ), # grid-like attention
- PreNormResidual(
- channel_num,
- Attention(
- dim=channel_num,
- dim_head=channel_num // 4,
- dropout=dropout,
- window_size=window_size,
- with_pe=with_pe,
- ),
- ),
- Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
- Conv_PreNormResidual(
- channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
- ),
- # channel-like attention
- Conv_PreNormResidual(
- channel_num,
- Channel_Attention_grid(
- dim=channel_num, heads=4, dropout=dropout, window_size=window_size
- ),
- ),
- Conv_PreNormResidual(
- channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
- ),
- )
-
- def forward(self, x):
- out = self.layer(x)
- return out
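
The deleted `OSA_Block` alternates "block-like" and "grid-like" window attention purely through the two `Rearrange` patterns above. A short sketch of what those two rearrangements do to a feature map (sizes are hypothetical):

```python
import torch
from einops import rearrange

x = torch.rand(1, 16, 32, 32)  # (b, d, H, W)
w = 8                          # window size

# Block-like: each w x w window becomes one attention group.
block = rearrange(x, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w)
# Grid-like: pixels at the same offset across windows form one group.
grid = rearrange(x, "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w)

print(block.shape, grid.shape)  # both torch.Size([1, 4, 4, 8, 8, 16])
```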
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
deleted file mode 100644
index 477e81f9da4..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: OSAG.py
-# Created Date: Tuesday April 28th 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
-# Modified By: Chen Xuanhong
-# Copyright (c) 2020 Shanghai Jiao Tong University
-#############################################################
-
-
-import torch.nn as nn
-
-from .esa import ESA
-from .OSA import OSA_Block
-
-
-class OSAG(nn.Module):
- def __init__(
- self,
- channel_num=64,
- bias=True,
- block_num=4,
- ffn_bias=False,
- window_size=0,
- pe=False,
- ):
- super(OSAG, self).__init__()
-
- # print("window_size: %d" % (window_size))
- # print("with_pe", pe)
- # print("ffn_bias: %d" % (ffn_bias))
-
- # block_script_name = kwargs.get("block_script_name", "OSA")
- # block_class_name = kwargs.get("block_class_name", "OSA_Block")
-
- # script_name = "." + block_script_name
- # package = __import__(script_name, fromlist=True)
- block_class = OSA_Block # getattr(package, block_class_name)
- group_list = []
- for _ in range(block_num):
- temp_res = block_class(
- channel_num,
- bias,
- ffn_bias=ffn_bias,
- window_size=window_size,
- with_pe=pe,
- )
- group_list.append(temp_res)
- group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
- self.residual_layer = nn.Sequential(*group_list)
- esa_channel = max(channel_num // 4, 16)
- self.esa = ESA(esa_channel, channel_num)
-
- def forward(self, x):
- out = self.residual_layer(x)
- out = out + x
- return self.esa(out)
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
deleted file mode 100644
index 1e1c3f35e65..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
+++ /dev/null
@@ -1,143 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: OmniSR.py
-# Created Date: Tuesday April 28th 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
-# Modified By: Chen Xuanhong
-# Copyright (c) 2020 Shanghai Jiao Tong University
-#############################################################
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .OSAG import OSAG
-from .pixelshuffle import pixelshuffle_block
-
-
-class OmniSR(nn.Module):
- def __init__(
- self,
- state_dict,
- **kwargs,
- ):
- super(OmniSR, self).__init__()
- self.state = state_dict
-
- bias = True # Fine to assume this for now
- block_num = 1 # Fine to assume this for now
- ffn_bias = True
- pe = True
-
- num_feat = state_dict["input.weight"].shape[0] or 64
- num_in_ch = state_dict["input.weight"].shape[1] or 3
-        num_out_ch = num_in_ch  # assume this for now; the pixelshuffle shape alone cannot confirm it
-
- pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
- up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
- if up_scale - int(up_scale) > 0:
- print(
- "out_nc is probably different than in_nc, scale calculation might be wrong"
- )
- up_scale = int(up_scale)
- res_num = 0
- for key in state_dict.keys():
- if "residual_layer" in key:
- temp_res_num = int(key.split(".")[1])
- if temp_res_num > res_num:
- res_num = temp_res_num
-        res_num = res_num + 1  # indices are zero-based, so the count is the max index + 1
-
- residual_layer = []
- self.res_num = res_num
-
- if (
- "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
- in state_dict.keys()
- ):
- rel_pos_bias_weight = state_dict[
- "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
- ].shape[0]
- self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
- else:
- self.window_size = 8
-
- self.up_scale = up_scale
-
- for _ in range(res_num):
- temp_res = OSAG(
- channel_num=num_feat,
- bias=bias,
- block_num=block_num,
- ffn_bias=ffn_bias,
- window_size=self.window_size,
- pe=pe,
- )
- residual_layer.append(temp_res)
- self.residual_layer = nn.Sequential(*residual_layer)
- self.input = nn.Conv2d(
- in_channels=num_in_ch,
- out_channels=num_feat,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=bias,
- )
- self.output = nn.Conv2d(
- in_channels=num_feat,
- out_channels=num_feat,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=bias,
- )
- self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
-
- # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
-
- # for m in self.modules():
- # if isinstance(m, nn.Conv2d):
- # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- # m.weight.data.normal_(0, sqrt(2. / n))
-
- # chaiNNer specific stuff
- self.model_arch = "OmniSR"
- self.sub_type = "SR"
- self.in_nc = num_in_ch
- self.out_nc = num_out_ch
- self.num_feat = num_feat
- self.scale = up_scale
-
- self.supports_fp16 = True # TODO: Test this
- self.supports_bfp16 = True
- self.min_size_restriction = 16
-
- self.load_state_dict(state_dict, strict=False)
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- # import pdb; pdb.set_trace()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- residual = self.input(x)
- out = self.residual_layer(residual)
-
- # origin
- out = torch.add(self.output(out), residual)
- out = self.up(out)
-
- out = out[:, :, : H * self.up_scale, : W * self.up_scale]
- return out
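
The deleted `OmniSR` loader infers its hyperparameters from state-dict tensor shapes instead of taking them as arguments. A sketch of the two shape-based calculations used above, with hypothetical values:

```python
import math

num_out_ch = 3
pixelshuffle_out = 48        # e.g. state_dict["up.0.weight"].shape[0]
up_scale = int(math.sqrt(pixelshuffle_out / num_out_ch))        # 4

rel_pos_rows = 225           # rel_pos_bias.weight rows = (2 * window - 1) ** 2
window_size = int((math.sqrt(rel_pos_rows) + 1) / 2)            # 8

print(up_scale, window_size)
```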
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py
deleted file mode 100644
index f9ce7f7a60b..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/esa.py
+++ /dev/null
@@ -1,294 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: esa.py
-# Created Date: Tuesday April 28th 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 20th April 2023 9:28:06 am
-# Modified By: Chen Xuanhong
-# Copyright (c) 2020 Shanghai Jiao Tong University
-#############################################################
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .layernorm import LayerNorm2d
-
-
-def moment(x, dim=(2, 3), k=2):
- assert len(x.size()) == 4
- mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
- mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
- return mk
-
-
-class ESA(nn.Module):
- """
-    Modification of Enhanced Spatial Attention (ESA), proposed in
-    `Residual Feature Aggregation Network for Image Super-Resolution`.
-    Note: `conv_max` and `conv3_` are NOT used here, so the corresponding
-    code has been removed.
- """
-
- def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
- super(ESA, self).__init__()
- f = esa_channels
- self.conv1 = conv(n_feats, f, kernel_size=1)
- self.conv_f = conv(f, f, kernel_size=1)
- self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
- self.conv3 = conv(f, f, kernel_size=3, padding=1)
- self.conv4 = conv(f, n_feats, kernel_size=1)
- self.sigmoid = nn.Sigmoid()
- self.relu = nn.ReLU(inplace=True)
-
- def forward(self, x):
- c1_ = self.conv1(x)
- c1 = self.conv2(c1_)
- v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
- c3 = self.conv3(v_max)
- c3 = F.interpolate(
- c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
- )
- cf = self.conv_f(c1_)
- c4 = self.conv4(c3 + cf)
- m = self.sigmoid(c4)
- return x * m
-
-
-class LK_ESA(nn.Module):
- def __init__(
- self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
- ):
- super(LK_ESA, self).__init__()
- f = esa_channels
- self.conv1 = conv(n_feats, f, kernel_size=1)
- self.conv_f = conv(f, f, kernel_size=1)
-
- kernel_size = 17
- kernel_expand = kernel_expand
- padding = kernel_size // 2
-
- self.vec_conv = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(1, kernel_size),
- padding=(0, padding),
- groups=2,
- bias=bias,
- )
- self.vec_conv3x1 = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(1, 3),
- padding=(0, 1),
- groups=2,
- bias=bias,
- )
-
- self.hor_conv = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(kernel_size, 1),
- padding=(padding, 0),
- groups=2,
- bias=bias,
- )
- self.hor_conv1x3 = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(3, 1),
- padding=(1, 0),
- groups=2,
- bias=bias,
- )
-
- self.conv4 = conv(f, n_feats, kernel_size=1)
- self.sigmoid = nn.Sigmoid()
- self.relu = nn.ReLU(inplace=True)
-
- def forward(self, x):
- c1_ = self.conv1(x)
-
- res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
- res = self.hor_conv(res) + self.hor_conv1x3(res)
-
- cf = self.conv_f(c1_)
- c4 = self.conv4(res + cf)
- m = self.sigmoid(c4)
- return x * m
-
-
-class LK_ESA_LN(nn.Module):
- def __init__(
- self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
- ):
- super(LK_ESA_LN, self).__init__()
- f = esa_channels
- self.conv1 = conv(n_feats, f, kernel_size=1)
- self.conv_f = conv(f, f, kernel_size=1)
-
- kernel_size = 17
- kernel_expand = kernel_expand
- padding = kernel_size // 2
-
- self.norm = LayerNorm2d(n_feats)
-
- self.vec_conv = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(1, kernel_size),
- padding=(0, padding),
- groups=2,
- bias=bias,
- )
- self.vec_conv3x1 = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(1, 3),
- padding=(0, 1),
- groups=2,
- bias=bias,
- )
-
- self.hor_conv = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(kernel_size, 1),
- padding=(padding, 0),
- groups=2,
- bias=bias,
- )
- self.hor_conv1x3 = nn.Conv2d(
- in_channels=f * kernel_expand,
- out_channels=f * kernel_expand,
- kernel_size=(3, 1),
- padding=(1, 0),
- groups=2,
- bias=bias,
- )
-
- self.conv4 = conv(f, n_feats, kernel_size=1)
- self.sigmoid = nn.Sigmoid()
- self.relu = nn.ReLU(inplace=True)
-
- def forward(self, x):
- c1_ = self.norm(x)
- c1_ = self.conv1(c1_)
-
- res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
- res = self.hor_conv(res) + self.hor_conv1x3(res)
-
- cf = self.conv_f(c1_)
- c4 = self.conv4(res + cf)
- m = self.sigmoid(c4)
- return x * m
-
-
-class AdaGuidedFilter(nn.Module):
- def __init__(
- self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
- ):
- super(AdaGuidedFilter, self).__init__()
-
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Conv2d(
- in_channels=n_feats,
- out_channels=1,
- kernel_size=1,
- padding=0,
- stride=1,
- groups=1,
- bias=True,
- )
-
- self.r = 5
-
- def box_filter(self, x, r):
- channel = x.shape[1]
- kernel_size = 2 * r + 1
- weight = 1.0 / (kernel_size**2)
- box_kernel = weight * torch.ones(
- (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
- )
- output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
- return output
-
- def forward(self, x):
- _, _, H, W = x.shape
- N = self.box_filter(
- torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
- )
-
- # epsilon = self.fc(self.gap(x))
- # epsilon = torch.pow(epsilon, 2)
- epsilon = 1e-2
-
- mean_x = self.box_filter(x, self.r) / N
- var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
-
- A = var_x / (var_x + epsilon)
- b = (1 - A) * mean_x
- m = A * x + b
-
- # mean_A = self.box_filter(A, self.r) / N
- # mean_b = self.box_filter(b, self.r) / N
- # m = mean_A * x + mean_b
- return x * m
-
-
-class AdaConvGuidedFilter(nn.Module):
- def __init__(
- self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
- ):
- super(AdaConvGuidedFilter, self).__init__()
- f = esa_channels
-
- self.conv_f = conv(f, f, kernel_size=1)
-
- kernel_size = 17
- kernel_expand = kernel_expand
- padding = kernel_size // 2
-
- self.vec_conv = nn.Conv2d(
- in_channels=f,
- out_channels=f,
- kernel_size=(1, kernel_size),
- padding=(0, padding),
- groups=f,
- bias=bias,
- )
-
- self.hor_conv = nn.Conv2d(
- in_channels=f,
- out_channels=f,
- kernel_size=(kernel_size, 1),
- padding=(padding, 0),
- groups=f,
- bias=bias,
- )
-
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Conv2d(
- in_channels=f,
- out_channels=f,
- kernel_size=1,
- padding=0,
- stride=1,
- groups=1,
- bias=True,
- )
-
- def forward(self, x):
- y = self.vec_conv(x)
- y = self.hor_conv(y)
-
- sigma = torch.pow(y, 2)
- epsilon = self.fc(self.gap(y))
-
- weight = sigma / (sigma + epsilon)
-
- m = weight * x + (1 - weight)
-
- return x * m
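
`AdaGuidedFilter` above builds its local means with a depthwise convolution whose kernel is a constant box. A minimal sketch of that box filter on its own (tensor sizes are hypothetical):

```python
import torch
import torch.nn.functional as F

def box_filter(x: torch.Tensor, r: int) -> torch.Tensor:
    # Depthwise conv with a constant 1/(k*k) kernel: a zero-padded local mean.
    c, k = x.shape[1], 2 * r + 1
    kernel = torch.full((c, 1, k, k), 1.0 / (k * k), dtype=x.dtype, device=x.device)
    return F.conv2d(x, kernel, stride=1, padding=r, groups=c)

x = torch.rand(1, 3, 16, 16)
print(box_filter(x, r=2).shape)  # torch.Size([1, 3, 16, 16])
```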
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py
deleted file mode 100644
index 731a25f7542..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: layernorm.py
-# Created Date: Tuesday April 28th 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 20th April 2023 9:28:20 am
-# Modified By: Chen Xuanhong
-# Copyright (c) 2020 Shanghai Jiao Tong University
-#############################################################
-
-import torch
-import torch.nn as nn
-
-
-class LayerNormFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x, weight, bias, eps):
- ctx.eps = eps
- N, C, H, W = x.size()
- mu = x.mean(1, keepdim=True)
- var = (x - mu).pow(2).mean(1, keepdim=True)
- y = (x - mu) / (var + eps).sqrt()
- ctx.save_for_backward(y, var, weight)
- y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
- return y
-
- @staticmethod
- def backward(ctx, grad_output):
- eps = ctx.eps
-
- N, C, H, W = grad_output.size()
- y, var, weight = ctx.saved_variables
- g = grad_output * weight.view(1, C, 1, 1)
- mean_g = g.mean(dim=1, keepdim=True)
-
- mean_gy = (g * y).mean(dim=1, keepdim=True)
- gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
- return (
- gx,
- (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
- grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
- None,
- )
-
-
-class LayerNorm2d(nn.Module):
- def __init__(self, channels, eps=1e-6):
- super(LayerNorm2d, self).__init__()
- self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
- self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
- self.eps = eps
-
- def forward(self, x):
- return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
-
-
-class GRN(nn.Module):
- """GRN (Global Response Normalization) layer"""
-
- def __init__(self, dim):
- super().__init__()
- self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
- self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
- def forward(self, x):
- Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
- Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
- return self.gamma * (x * Nx) + self.beta + x
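
`LayerNorm2d` above wraps a custom autograd function, but its forward pass is just a channel-wise normalization of NCHW tensors with a per-channel affine. A sketch of the equivalent forward computation (hypothetical shapes):

```python
import torch

def layer_norm_2d(x, weight, bias, eps=1e-6):
    # Normalize each pixel across the channel dimension, then apply
    # a per-channel scale and shift.
    mu = x.mean(1, keepdim=True)
    var = (x - mu).pow(2).mean(1, keepdim=True)
    y = (x - mu) / (var + eps).sqrt()
    return weight.view(1, -1, 1, 1) * y + bias.view(1, -1, 1, 1)

x = torch.rand(2, 8, 4, 4)
print(layer_norm_2d(x, torch.ones(8), torch.zeros(8)).shape)  # torch.Size([2, 8, 4, 4])
```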
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py
deleted file mode 100644
index 4260fb7c9d8..00000000000
--- a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: pixelshuffle.py
-# Created Date: Friday July 1st 2022
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Friday, 1st July 2022 10:18:39 am
-# Modified By: Chen Xuanhong
-# Copyright (c) 2022 Shanghai Jiao Tong University
-#############################################################
-
-import torch.nn as nn
-
-
-def pixelshuffle_block(
- in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
-):
- """
- Upsample features according to `upscale_factor`.
- """
- padding = kernel_size // 2
- conv = nn.Conv2d(
- in_channels,
- out_channels * (upscale_factor**2),
- kernel_size,
-        padding=padding,  # kernel_size // 2, computed above
- bias=bias,
- )
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
- return nn.Sequential(*[conv, pixel_shuffle])
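
The block above implements the standard sub-pixel upsampling pattern: a convolution produces `upscale_factor ** 2` times the target channels, and `nn.PixelShuffle` trades those extra channels for spatial resolution. A quick usage sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn

r = 2
up = nn.Sequential(
    nn.Conv2d(64, 3 * r * r, kernel_size=3, padding=1),  # 64 -> 12 channels
    nn.PixelShuffle(r),                                   # 12 channels -> 3 channels, 2x size
)
x = torch.rand(1, 64, 16, 16)
print(up(x).shape)  # torch.Size([1, 3, 32, 32])
```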
diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py
deleted file mode 100644
index b50db7c24a8..00000000000
--- a/comfy_extras/chainner_models/architecture/RRDB.py
+++ /dev/null
@@ -1,296 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import functools
-import math
-import re
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from . import block as B
-
-
-# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
-# which built on the code that was already here
-class RRDBNet(nn.Module):
- def __init__(
- self,
- state_dict,
- norm=None,
- act: str = "leakyrelu",
- upsampler: str = "upconv",
- mode: B.ConvMode = "CNA",
- ) -> None:
- """
- ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
- By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
- and Chen Change Loy.
-        This is the old-arch Residual in Residual Dense Block Network and is
-        not the newest revision available at github.com/xinntao/ESRGAN. That
-        is intentional: the newer network severely limits how it can be used,
-        with no benefit.
-        This class supports model files from both the new and old architectures.
- Args:
- norm: Normalization layer
- act: Activation layer
-            upsampler: Upsample layer, either "upconv" or "pixel_shuffle"
- mode: Convolution mode
- """
- super(RRDBNet, self).__init__()
- self.model_arch = "ESRGAN"
- self.sub_type = "SR"
-
- self.state = state_dict
- self.norm = norm
- self.act = act
- self.upsampler = upsampler
- self.mode = mode
-
- self.state_map = {
- # currently supports old, new, and newer RRDBNet arch models
- # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
- "model.0.weight": ("conv_first.weight",),
- "model.0.bias": ("conv_first.bias",),
- "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
- "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
- r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
- r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
- r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
- ),
- }
- if "params_ema" in self.state:
- self.state = self.state["params_ema"]
- # self.model_arch = "RealESRGAN"
- self.num_blocks = self.get_num_blocks()
- self.plus = any("conv1x1" in k for k in self.state.keys())
- if self.plus:
- self.model_arch = "ESRGAN+"
-
- self.state = self.new_to_old_arch(self.state)
-
- self.key_arr = list(self.state.keys())
-
- self.in_nc: int = self.state[self.key_arr[0]].shape[1]
- self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
-
- self.scale: int = self.get_scale()
- self.num_filters: int = self.state[self.key_arr[0]].shape[0]
-
- c2x2 = False
- if self.state["model.0.weight"].shape[-2] == 2:
- c2x2 = True
- self.scale = round(math.sqrt(self.scale / 4))
- self.model_arch = "ESRGAN-2c2"
-
- self.supports_fp16 = True
- self.supports_bfp16 = True
- self.min_size_restriction = None
-
- # Detect if pixelunshuffle was used (Real-ESRGAN)
- if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
- self.in_nc / 4,
- self.in_nc / 16,
- ):
- self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
- else:
- self.shuffle_factor = None
-
- upsample_block = {
- "upconv": B.upconv_block,
- "pixel_shuffle": B.pixelshuffle_block,
- }.get(self.upsampler)
- if upsample_block is None:
- raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
-
- if self.scale == 3:
- upsample_blocks = upsample_block(
- in_nc=self.num_filters,
- out_nc=self.num_filters,
- upscale_factor=3,
- act_type=self.act,
- c2x2=c2x2,
- )
- else:
- upsample_blocks = [
- upsample_block(
- in_nc=self.num_filters,
- out_nc=self.num_filters,
- act_type=self.act,
- c2x2=c2x2,
- )
- for _ in range(int(math.log(self.scale, 2)))
- ]
-
- self.model = B.sequential(
- # fea conv
- B.conv_block(
- in_nc=self.in_nc,
- out_nc=self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- c2x2=c2x2,
- ),
- B.ShortcutBlock(
- B.sequential(
- # rrdb blocks
- *[
- B.RRDB(
- nf=self.num_filters,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=self.norm,
- act_type=self.act,
- mode="CNA",
- plus=self.plus,
- c2x2=c2x2,
- )
- for _ in range(self.num_blocks)
- ],
- # lr conv
- B.conv_block(
- in_nc=self.num_filters,
- out_nc=self.num_filters,
- kernel_size=3,
- norm_type=self.norm,
- act_type=None,
- mode=self.mode,
- c2x2=c2x2,
- ),
- )
- ),
- *upsample_blocks,
- # hr_conv0
- B.conv_block(
- in_nc=self.num_filters,
- out_nc=self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=self.act,
- c2x2=c2x2,
- ),
- # hr_conv1
- B.conv_block(
- in_nc=self.num_filters,
- out_nc=self.out_nc,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- c2x2=c2x2,
- ),
- )
-
- # Adjust these properties for calculations outside of the model
- if self.shuffle_factor:
- self.in_nc //= self.shuffle_factor**2
- self.scale //= self.shuffle_factor
-
- self.load_state_dict(self.state, strict=False)
-
- def new_to_old_arch(self, state):
- """Convert a new-arch model state dictionary to an old-arch dictionary."""
- if "params_ema" in state:
- state = state["params_ema"]
-
- if "conv_first.weight" not in state:
-            # model is already old arch; this is a loose check, but it should be sufficient
- return state
-
- # add nb to state keys
- for kind in ("weight", "bias"):
- self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
- f"model.1.sub./NB/.{kind}"
- ]
- del self.state_map[f"model.1.sub./NB/.{kind}"]
-
- old_state = OrderedDict()
- for old_key, new_keys in self.state_map.items():
- for new_key in new_keys:
- if r"\1" in old_key:
- for k, v in state.items():
- sub = re.sub(new_key, old_key, k)
- if sub != k:
- old_state[sub] = v
- else:
- if new_key in state:
- old_state[old_key] = state[new_key]
-
- # upconv layers
- max_upconv = 0
- for key in state.keys():
- match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
- if match is not None:
- _, key_num, key_type = match.groups()
- old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
- max_upconv = max(max_upconv, int(key_num) * 3)
-
- # final layers
- for key in state.keys():
- if key in ("HRconv.weight", "conv_hr.weight"):
- old_state[f"model.{max_upconv + 2}.weight"] = state[key]
- elif key in ("HRconv.bias", "conv_hr.bias"):
- old_state[f"model.{max_upconv + 2}.bias"] = state[key]
- elif key in ("conv_last.weight",):
- old_state[f"model.{max_upconv + 4}.weight"] = state[key]
- elif key in ("conv_last.bias",):
- old_state[f"model.{max_upconv + 4}.bias"] = state[key]
-
- # Sort by first numeric value of each layer
- def compare(item1, item2):
- parts1 = item1.split(".")
- parts2 = item2.split(".")
- int1 = int(parts1[1])
- int2 = int(parts2[1])
- return int1 - int2
-
- sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
-
- # Rebuild the output dict in the right order
- out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
-
- return out_dict
-
- def get_scale(self, min_part: int = 6) -> int:
- n = 0
- for part in list(self.state):
- parts = part.split(".")[1:]
- if len(parts) == 2:
- part_num = int(parts[0])
- if part_num > min_part and parts[1] == "weight":
- n += 1
- return 2**n
-
- def get_num_blocks(self) -> int:
- nbs = []
- state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
- r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
- )
- for state_key in state_keys:
- for k in self.state:
- m = re.search(state_key, k)
- if m:
- nbs.append(int(m.group(1)))
- if nbs:
- break
- return max(*nbs) + 1
-
- def forward(self, x):
- if self.shuffle_factor:
- _, _, h, w = x.size()
- mod_pad_h = (
- self.shuffle_factor - h % self.shuffle_factor
- ) % self.shuffle_factor
- mod_pad_w = (
- self.shuffle_factor - w % self.shuffle_factor
- ) % self.shuffle_factor
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
- x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
- x = self.model(x)
- return x[:, :, : h * self.scale, : w * self.scale]
- return self.model(x)
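For reference, a minimal standalone sketch (not part of the deleted file; shapes are illustrative) of the pixel-unshuffle handling in the forward() above: the input is reflect-padded until both spatial dimensions divide the shuffle factor, space-to-depth shrinks the resolution while multiplying the channel count, and the final output is cropped so the padding never leaks into the result.

    import torch
    import torch.nn.functional as F

    shuffle_factor = 2
    x = torch.randn(1, 3, 30, 33)                    # width not divisible by 2
    _, _, h, w = x.shape

    # Reflect-pad so H and W are divisible by the shuffle factor.
    pad_h = (shuffle_factor - h % shuffle_factor) % shuffle_factor
    pad_w = (shuffle_factor - w % shuffle_factor) % shuffle_factor
    x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")    # -> (1, 3, 30, 34)

    # Space-to-depth: 4x the channels, half the resolution.
    x = torch.pixel_unshuffle(x, downscale_factor=shuffle_factor)
    print(x.shape)                                   # torch.Size([1, 12, 15, 17])

    # The RRDB trunk then consumes this tensor; the real forward() finally crops
    # the upscaled result to x[:, :, : h * scale, : w * scale].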
diff --git a/comfy_extras/chainner_models/architecture/SCUNet.py b/comfy_extras/chainner_models/architecture/SCUNet.py
deleted file mode 100644
index b8354a87308..00000000000
--- a/comfy_extras/chainner_models/architecture/SCUNet.py
+++ /dev/null
@@ -1,455 +0,0 @@
-# pylint: skip-file
-# -----------------------------------------------------------------------------------
-# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
-# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
-# -----------------------------------------------------------------------------------
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from einops.layers.torch import Rearrange
-
-from .timm.drop import DropPath
-from .timm.weight_init import trunc_normal_
-
-
-# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
-class WMSA(nn.Module):
- """Self-attention module in Swin Transformer"""
-
- def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.head_dim = head_dim
- self.scale = self.head_dim**-0.5
- self.n_heads = input_dim // head_dim
- self.window_size = window_size
- self.type = type
- self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
-
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
- )
- # TODO recover
- # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
- )
-
- self.linear = nn.Linear(self.input_dim, self.output_dim)
-
- trunc_normal_(self.relative_position_params, std=0.02)
- self.relative_position_params = torch.nn.Parameter(
- self.relative_position_params.view(
- 2 * window_size - 1, 2 * window_size - 1, self.n_heads
- )
- .transpose(1, 2)
- .transpose(0, 1)
- )
-
- def generate_mask(self, h, w, p, shift):
- """Generate the attention mask for SW-MSA.
- Args:
- shift: shift parameters in CyclicShift.
- Returns:
- attn_mask: boolean mask of shape (1, 1, num_windows, window_size**2, window_size**2)
- """
- # supporting square.
- attn_mask = torch.zeros(
- h,
- w,
- p,
- p,
- p,
- p,
- dtype=torch.bool,
- device=self.relative_position_params.device,
- )
- if self.type == "W":
- return attn_mask
-
- s = p - shift
- attn_mask[-1, :, :s, :, s:, :] = True
- attn_mask[-1, :, s:, :, :s, :] = True
- attn_mask[:, -1, :, :s, :, s:] = True
- attn_mask[:, -1, :, s:, :, :s] = True
- attn_mask = rearrange(
- attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
- )
- return attn_mask
-
- def forward(self, x):
- """Forward pass of Window Multi-head Self-attention module.
- Args:
- x: input tensor with shape of [b h w c];
- attn_mask: attention mask, fill -inf where the value is True;
- Returns:
- output: tensor shape [b h w c]
- """
- if self.type != "W":
- x = torch.roll(
- x,
- shifts=(-(self.window_size // 2), -(self.window_size // 2)),
- dims=(1, 2),
- )
-
- x = rearrange(
- x,
- "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
- p1=self.window_size,
- p2=self.window_size,
- )
- h_windows = x.size(1)
- w_windows = x.size(2)
- # square validation
- # assert h_windows == w_windows
-
- x = rearrange(
- x,
- "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
- p1=self.window_size,
- p2=self.window_size,
- )
- qkv = self.embedding_layer(x)
- q, k, v = rearrange(
- qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
- ).chunk(3, dim=0)
- sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
- # Adding learnable relative embedding
- sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
- # Using Attn Mask to distinguish different subwindows.
- if self.type != "W":
- attn_mask = self.generate_mask(
- h_windows, w_windows, self.window_size, shift=self.window_size // 2
- )
- sim = sim.masked_fill_(attn_mask, float("-inf"))
-
- probs = nn.functional.softmax(sim, dim=-1)
- output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
- output = rearrange(output, "h b w p c -> b w p (h c)")
- output = self.linear(output)
- output = rearrange(
- output,
- "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
- w1=h_windows,
- p1=self.window_size,
- )
-
- if self.type != "W":
- output = torch.roll(
- output,
- shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2),
- )
-
- return output
-
- def relative_embedding(self):
- cord = torch.tensor(
- np.array(
- [
- [i, j]
- for i in range(self.window_size)
- for j in range(self.window_size)
- ]
- )
- )
- relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
- # negative is allowed
- return self.relative_position_params[
- :, relation[:, :, 0].long(), relation[:, :, 1].long()
- ]
-
-
-class Block(nn.Module):
- def __init__(
- self,
- input_dim,
- output_dim,
- head_dim,
- window_size,
- drop_path,
- type="W",
- input_resolution=None,
- ):
- """SwinTransformer Block"""
- super(Block, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- assert type in ["W", "SW"]
- self.type = type
- if input_resolution <= window_size:
- self.type = "W"
-
- self.ln1 = nn.LayerNorm(input_dim)
- self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.ln2 = nn.LayerNorm(input_dim)
- self.mlp = nn.Sequential(
- nn.Linear(input_dim, 4 * input_dim),
- nn.GELU(),
- nn.Linear(4 * input_dim, output_dim),
- )
-
- def forward(self, x):
- x = x + self.drop_path(self.msa(self.ln1(x)))
- x = x + self.drop_path(self.mlp(self.ln2(x)))
- return x
-
-
-class ConvTransBlock(nn.Module):
- def __init__(
- self,
- conv_dim,
- trans_dim,
- head_dim,
- window_size,
- drop_path,
- type="W",
- input_resolution=None,
- ):
- """SwinTransformer and Conv Block"""
- super(ConvTransBlock, self).__init__()
- self.conv_dim = conv_dim
- self.trans_dim = trans_dim
- self.head_dim = head_dim
- self.window_size = window_size
- self.drop_path = drop_path
- self.type = type
- self.input_resolution = input_resolution
-
- assert self.type in ["W", "SW"]
- if self.input_resolution <= self.window_size:
- self.type = "W"
-
- self.trans_block = Block(
- self.trans_dim,
- self.trans_dim,
- self.head_dim,
- self.window_size,
- self.drop_path,
- self.type,
- self.input_resolution,
- )
- self.conv1_1 = nn.Conv2d(
- self.conv_dim + self.trans_dim,
- self.conv_dim + self.trans_dim,
- 1,
- 1,
- 0,
- bias=True,
- )
- self.conv1_2 = nn.Conv2d(
- self.conv_dim + self.trans_dim,
- self.conv_dim + self.trans_dim,
- 1,
- 1,
- 0,
- bias=True,
- )
-
- self.conv_block = nn.Sequential(
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- )
-
- def forward(self, x):
- conv_x, trans_x = torch.split(
- self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
- )
- conv_x = self.conv_block(conv_x) + conv_x
- trans_x = Rearrange("b c h w -> b h w c")(trans_x)
- trans_x = self.trans_block(trans_x)
- trans_x = Rearrange("b h w c -> b c h w")(trans_x)
- res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
- x = x + res
-
- return x
-
-
-class SCUNet(nn.Module):
- def __init__(
- self,
- state_dict,
- in_nc=3,
- config=[4, 4, 4, 4, 4, 4, 4],
- dim=64,
- drop_path_rate=0.0,
- input_resolution=256,
- ):
- super(SCUNet, self).__init__()
- self.model_arch = "SCUNet"
- self.sub_type = "SR"
-
- self.num_filters: int = 0
-
- self.state = state_dict
- self.config = config
- self.dim = dim
- self.head_dim = 32
- self.window_size = 8
-
- self.in_nc = in_nc
- self.out_nc = self.in_nc
- self.scale = 1
- self.supports_fp16 = True
-
- # drop path rate for each layer
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
-
- self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
-
- begin = 0
- self.m_down1 = [
- ConvTransBlock(
- dim // 2,
- dim // 2,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution,
- )
- for i in range(config[0])
- ] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
-
- begin += config[0]
- self.m_down2 = [
- ConvTransBlock(
- dim,
- dim,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution // 2,
- )
- for i in range(config[1])
- ] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
-
- begin += config[1]
- self.m_down3 = [
- ConvTransBlock(
- 2 * dim,
- 2 * dim,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution // 4,
- )
- for i in range(config[2])
- ] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
-
- begin += config[2]
- self.m_body = [
- ConvTransBlock(
- 4 * dim,
- 4 * dim,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution // 8,
- )
- for i in range(config[3])
- ]
-
- begin += config[3]
- self.m_up3 = [
- nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
- ] + [
- ConvTransBlock(
- 2 * dim,
- 2 * dim,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution // 4,
- )
- for i in range(config[4])
- ]
-
- begin += config[4]
- self.m_up2 = [
- nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
- ] + [
- ConvTransBlock(
- dim,
- dim,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution // 2,
- )
- for i in range(config[5])
- ]
-
- begin += config[5]
- self.m_up1 = [
- nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
- ] + [
- ConvTransBlock(
- dim // 2,
- dim // 2,
- self.head_dim,
- self.window_size,
- dpr[i + begin],
- "W" if not i % 2 else "SW",
- input_resolution,
- )
- for i in range(config[6])
- ]
-
- self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
-
- self.m_head = nn.Sequential(*self.m_head)
- self.m_down1 = nn.Sequential(*self.m_down1)
- self.m_down2 = nn.Sequential(*self.m_down2)
- self.m_down3 = nn.Sequential(*self.m_down3)
- self.m_body = nn.Sequential(*self.m_body)
- self.m_up3 = nn.Sequential(*self.m_up3)
- self.m_up2 = nn.Sequential(*self.m_up2)
- self.m_up1 = nn.Sequential(*self.m_up1)
- self.m_tail = nn.Sequential(*self.m_tail)
- # self.apply(self._init_weights)
- self.load_state_dict(state_dict, strict=True)
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (64 - h % 64) % 64
- mod_pad_w = (64 - w % 64) % 64
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
- return x
-
- def forward(self, x0):
- h, w = x0.size()[-2:]
- x0 = self.check_image_size(x0)
-
- x1 = self.m_head(x0)
- x2 = self.m_down1(x1)
- x3 = self.m_down2(x2)
- x4 = self.m_down3(x3)
- x = self.m_body(x4)
- x = self.m_up3(x + x4)
- x = self.m_up2(x + x3)
- x = self.m_up1(x + x2)
- x = self.m_tail(x + x1)
-
- x = x[:, :, :h, :w]
- return x
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
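For reference, a minimal standalone sketch (toy tensors, not the deleted module) of why check_image_size() pads to a multiple of 64 before forward() crops back: the UNet halves the resolution three times and the Swin blocks use 8x8 windows, so an input divisible by 64 presumably keeps every stage partitionable regardless of the caller's image size.

    import torch
    import torch.nn.functional as F

    x0 = torch.randn(1, 3, 100, 130)
    h, w = x0.shape[-2:]

    mod_pad_h = (64 - h % 64) % 64                   # 28
    mod_pad_w = (64 - w % 64) % 64                   # 62
    x = F.pad(x0, (0, mod_pad_w, 0, mod_pad_h), "reflect")
    print(x.shape)                                   # torch.Size([1, 3, 128, 192])

    out = x                                          # stand-in for the SCUNet body
    out = out[:, :, :h, :w]                          # crop back to the caller's size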
diff --git a/comfy_extras/chainner_models/architecture/SPSR.py b/comfy_extras/chainner_models/architecture/SPSR.py
deleted file mode 100644
index c3cefff1902..00000000000
--- a/comfy_extras/chainner_models/architecture/SPSR.py
+++ /dev/null
@@ -1,383 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from . import block as B
-
-
-class Get_gradient_nopadding(nn.Module):
- def __init__(self):
- super(Get_gradient_nopadding, self).__init__()
- kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
- kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
- kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
- kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
- self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) # type: ignore
-
- self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) # type: ignore
-
- def forward(self, x):
- x_list = []
- for i in range(x.shape[1]):
- x_i = x[:, i]
- x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
- x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
- x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
- x_list.append(x_i)
-
- x = torch.cat(x_list, dim=1)
-
- return x
-
-
-class SPSRNet(nn.Module):
- def __init__(
- self,
- state_dict,
- norm=None,
- act: str = "leakyrelu",
- upsampler: str = "upconv",
- mode: B.ConvMode = "CNA",
- ):
- super(SPSRNet, self).__init__()
- self.model_arch = "SPSR"
- self.sub_type = "SR"
-
- self.state = state_dict
- self.norm = norm
- self.act = act
- self.upsampler = upsampler
- self.mode = mode
-
- self.num_blocks = self.get_num_blocks()
-
- self.in_nc: int = self.state["model.0.weight"].shape[1]
- self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0]
-
- self.scale = self.get_scale(4)
- self.num_filters: int = self.state["model.0.weight"].shape[0]
-
- self.supports_fp16 = True
- self.supports_bfp16 = True
- self.min_size_restriction = None
-
- n_upscale = int(math.log(self.scale, 2))
- if self.scale == 3:
- n_upscale = 1
-
- fea_conv = B.conv_block(
- self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
- )
- rb_blocks = [
- B.RRDB(
- self.num_filters,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
- for _ in range(self.num_blocks)
- ]
- LR_conv = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=norm,
- act_type=None,
- mode=mode,
- )
-
- if upsampler == "upconv":
- upsample_block = B.upconv_block
- elif upsampler == "pixelshuffle":
- upsample_block = B.pixelshuffle_block
- else:
- raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
- if self.scale == 3:
- a_upsampler = upsample_block(
- self.num_filters, self.num_filters, 3, act_type=act
- )
- else:
- a_upsampler = [
- upsample_block(self.num_filters, self.num_filters, act_type=act)
- for _ in range(n_upscale)
- ]
- self.HR_conv0_new = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=act,
- )
- self.HR_conv1_new = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
-
- self.model = B.sequential(
- fea_conv,
- B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)),
- *a_upsampler,
- self.HR_conv0_new,
- )
-
- self.get_g_nopadding = Get_gradient_nopadding()
-
- self.b_fea_conv = B.conv_block(
- self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
- )
-
- self.b_concat_1 = B.conv_block(
- 2 * self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
- self.b_block_1 = B.RRDB(
- self.num_filters * 2,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
-
- self.b_concat_2 = B.conv_block(
- 2 * self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
- self.b_block_2 = B.RRDB(
- self.num_filters * 2,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
-
- self.b_concat_3 = B.conv_block(
- 2 * self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
- self.b_block_3 = B.RRDB(
- self.num_filters * 2,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
-
- self.b_concat_4 = B.conv_block(
- 2 * self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
- self.b_block_4 = B.RRDB(
- self.num_filters * 2,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
-
- self.b_LR_conv = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=norm,
- act_type=None,
- mode=mode,
- )
-
- if upsampler == "upconv":
- upsample_block = B.upconv_block
- elif upsampler == "pixelshuffle":
- upsample_block = B.pixelshuffle_block
- else:
- raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
- if self.scale == 3:
- b_upsampler = upsample_block(
- self.num_filters, self.num_filters, 3, act_type=act
- )
- else:
- b_upsampler = [
- upsample_block(self.num_filters, self.num_filters, act_type=act)
- for _ in range(n_upscale)
- ]
-
- b_HR_conv0 = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=act,
- )
- b_HR_conv1 = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
-
- self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
-
- self.conv_w = B.conv_block(
- self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None
- )
-
- self.f_concat = B.conv_block(
- self.num_filters * 2,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=None,
- )
-
- self.f_block = B.RRDB(
- self.num_filters * 2,
- kernel_size=3,
- gc=32,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type=norm,
- act_type=act,
- mode="CNA",
- )
-
- self.f_HR_conv0 = B.conv_block(
- self.num_filters,
- self.num_filters,
- kernel_size=3,
- norm_type=None,
- act_type=act,
- )
- self.f_HR_conv1 = B.conv_block(
- self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None
- )
-
- self.load_state_dict(self.state, strict=False)
-
- def get_scale(self, min_part: int = 4) -> int:
- n = 0
- for part in list(self.state):
- parts = part.split(".")
- if len(parts) == 3:
- part_num = int(parts[1])
- if part_num > min_part and parts[0] == "model" and parts[2] == "weight":
- n += 1
- return 2**n
-
- def get_num_blocks(self) -> int:
- nb = 0
- for part in list(self.state):
- parts = part.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- return nb
-
- def forward(self, x):
- x_grad = self.get_g_nopadding(x)
- x = self.model[0](x)
-
- x, block_list = self.model[1](x)
-
- x_ori = x
- for i in range(5):
- x = block_list[i](x)
- x_fea1 = x
-
- for i in range(5):
- x = block_list[i + 5](x)
- x_fea2 = x
-
- for i in range(5):
- x = block_list[i + 10](x)
- x_fea3 = x
-
- for i in range(5):
- x = block_list[i + 15](x)
- x_fea4 = x
-
- x = block_list[20:](x)
- # short cut
- x = x_ori + x
- x = self.model[2:](x)
- x = self.HR_conv1_new(x)
-
- x_b_fea = self.b_fea_conv(x_grad)
- x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
-
- x_cat_1 = self.b_block_1(x_cat_1)
- x_cat_1 = self.b_concat_1(x_cat_1)
-
- x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
-
- x_cat_2 = self.b_block_2(x_cat_2)
- x_cat_2 = self.b_concat_2(x_cat_2)
-
- x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
-
- x_cat_3 = self.b_block_3(x_cat_3)
- x_cat_3 = self.b_concat_3(x_cat_3)
-
- x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
-
- x_cat_4 = self.b_block_4(x_cat_4)
- x_cat_4 = self.b_concat_4(x_cat_4)
-
- x_cat_4 = self.b_LR_conv(x_cat_4)
-
- # short cut
- x_cat_4 = x_cat_4 + x_b_fea
- x_branch = self.b_module(x_cat_4)
-
- # x_out_branch = self.conv_w(x_branch)
- ########
- x_branch_d = x_branch
- x_f_cat = torch.cat([x_branch_d, x], dim=1)
- x_f_cat = self.f_block(x_f_cat)
- x_out = self.f_concat(x_f_cat)
- x_out = self.f_HR_conv0(x_out)
- x_out = self.f_HR_conv1(x_out)
-
- #########
- # return x_out_branch, x_out, x_grad
- return x_out
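For reference, a minimal standalone sketch of the gradient map that Get_gradient_nopadding feeds into SPSR's gradient branch (kernel values copied from the class above; the input is random): each channel is convolved with fixed vertical/horizontal difference kernels and the per-pixel gradient magnitude is kept.

    import torch
    import torch.nn.functional as F

    kernel_v = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]]).view(1, 1, 3, 3)
    kernel_h = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]]).view(1, 1, 3, 3)

    img = torch.rand(1, 3, 64, 64)
    channels = []
    for c in range(img.shape[1]):
        ch = img[:, c : c + 1]                       # keep the channel dimension
        g_v = F.conv2d(ch, kernel_v, padding=1)
        g_h = F.conv2d(ch, kernel_h, padding=1)
        channels.append(torch.sqrt(g_v ** 2 + g_h ** 2 + 1e-6))
    grad_map = torch.cat(channels, dim=1)            # same shape as img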
diff --git a/comfy_extras/chainner_models/architecture/SRVGG.py b/comfy_extras/chainner_models/architecture/SRVGG.py
deleted file mode 100644
index 7a8ec37ae5d..00000000000
--- a/comfy_extras/chainner_models/architecture/SRVGG.py
+++ /dev/null
@@ -1,114 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import math
-
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- It is a compact network structure that performs upsampling only in the last layer, so no convolution
- is conducted on the HR feature space.
- Args:
- num_in_ch (int): Channel number of inputs. Default: 3.
- num_out_ch (int): Channel number of outputs. Default: 3.
- num_feat (int): Channel number of intermediate features. Default: 64.
- num_conv (int): Number of convolution layers in the body network. Default: 16.
- upscale (int): Upsampling factor. Default: 4.
- act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
- """
-
- def __init__(
- self,
- state_dict,
- act_type: str = "prelu",
- ):
- super(SRVGGNetCompact, self).__init__()
- self.model_arch = "SRVGG (RealESRGAN)"
- self.sub_type = "SR"
-
- self.act_type = act_type
-
- self.state = state_dict
-
- if "params" in self.state:
- self.state = self.state["params"]
-
- self.key_arr = list(self.state.keys())
-
- self.in_nc = self.get_in_nc()
- self.num_feat = self.get_num_feats()
- self.num_conv = self.get_num_conv()
- self.out_nc = self.in_nc # :(
- self.pixelshuffle_shape = None # Defined in get_scale()
- self.scale = self.get_scale()
-
- self.supports_fp16 = True
- self.supports_bfp16 = True
- self.min_size_restriction = None
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1))
- # the first activation
- if act_type == "relu":
- activation = nn.ReLU(inplace=True)
- elif act_type == "prelu":
- activation = nn.PReLU(num_parameters=self.num_feat)
- elif act_type == "leakyrelu":
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation) # type: ignore
-
- # the body structure
- for _ in range(self.num_conv):
- self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1))
- # activation
- if act_type == "relu":
- activation = nn.ReLU(inplace=True)
- elif act_type == "prelu":
- activation = nn.PReLU(num_parameters=self.num_feat)
- elif act_type == "leakyrelu":
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation) # type: ignore
-
- # the last conv
- self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) # type: ignore
- # upsample
- self.upsampler = nn.PixelShuffle(self.scale)
-
- self.load_state_dict(self.state, strict=False)
-
- def get_num_conv(self) -> int:
- return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
-
- def get_num_feats(self) -> int:
- return self.state[self.key_arr[0]].shape[0]
-
- def get_in_nc(self) -> int:
- return self.state[self.key_arr[0]].shape[1]
-
- def get_scale(self) -> int:
- self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
- # Assume out_nc is the same as in_nc
- # I can't think of a better way to do that
- self.out_nc = self.in_nc
- scale = math.sqrt(self.pixelshuffle_shape / self.out_nc)
- if scale - int(scale) > 0:
- print(
- "out_nc is probably different than in_nc, scale calculation might be wrong"
- )
- scale = int(scale)
- return scale
-
- def forward(self, x):
- out = x
- for i in range(0, len(self.body)):
- out = self.body[i](out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.scale, mode="nearest")
- out += base
- return out
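For reference, a minimal sketch with hypothetical numbers of how get_scale() above recovers the upscale factor from a checkpoint: the last conv emits out_nc * scale**2 channels so that PixelShuffle(scale) can rearrange them, and the nearest-upsampled input is added as the residual base.

    import math
    import torch
    import torch.nn.functional as F

    out_nc = 3
    pixelshuffle_shape = 48                          # shape[0] of the last conv weight
    scale = int(math.sqrt(pixelshuffle_shape / out_nc))   # 4

    feat = torch.randn(1, pixelshuffle_shape, 16, 16)     # stand-in body output
    x = torch.randn(1, out_nc, 16, 16)                    # stand-in network input

    out = torch.nn.PixelShuffle(scale)(feat)              # (1, 3, 64, 64)
    out = out + F.interpolate(x, scale_factor=scale, mode="nearest")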
diff --git a/comfy_extras/chainner_models/architecture/SwiftSRGAN.py b/comfy_extras/chainner_models/architecture/SwiftSRGAN.py
deleted file mode 100644
index dbb7725b08d..00000000000
--- a/comfy_extras/chainner_models/architecture/SwiftSRGAN.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
-
-import torch
-from torch import nn
-
-
-class SeperableConv2d(nn.Module):
- def __init__(
- self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
- ):
- super(SeperableConv2d, self).__init__()
- self.depthwise = nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=kernel_size,
- stride=stride,
- groups=in_channels,
- bias=bias,
- padding=padding,
- )
- self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
-
- def forward(self, x):
- return self.pointwise(self.depthwise(x))
-
-
-class ConvBlock(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- use_act=True,
- use_bn=True,
- discriminator=False,
- **kwargs,
- ):
- super(ConvBlock, self).__init__()
-
- self.use_act = use_act
- self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
- self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
- self.act = (
- nn.LeakyReLU(0.2, inplace=True)
- if discriminator
- else nn.PReLU(num_parameters=out_channels)
- )
-
- def forward(self, x):
- return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
-
-
-class UpsampleBlock(nn.Module):
- def __init__(self, in_channels, scale_factor):
- super(UpsampleBlock, self).__init__()
-
- self.conv = SeperableConv2d(
- in_channels,
- in_channels * scale_factor**2,
- kernel_size=3,
- stride=1,
- padding=1,
- )
- self.ps = nn.PixelShuffle(
- scale_factor
- ) # (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
- self.act = nn.PReLU(num_parameters=in_channels)
-
- def forward(self, x):
- return self.act(self.ps(self.conv(x)))
-
-
-class ResidualBlock(nn.Module):
- def __init__(self, in_channels):
- super(ResidualBlock, self).__init__()
-
- self.block1 = ConvBlock(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
- )
- self.block2 = ConvBlock(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False
- )
-
- def forward(self, x):
- out = self.block1(x)
- out = self.block2(out)
- return out + x
-
-
-class Generator(nn.Module):
- """Swift-SRGAN Generator
- Args:
- in_channels (int): number of input image channels.
- num_channels (int): number of hidden channels.
- num_blocks (int): number of residual blocks.
- upscale_factor (int): factor to upscale the image [2x, 4x, 8x].
- Returns:
- torch.Tensor: super resolution image
- """
-
- def __init__(
- self,
- state_dict,
- ):
- super(Generator, self).__init__()
- self.model_arch = "Swift-SRGAN"
- self.sub_type = "SR"
- self.state = state_dict
- if "model" in self.state:
- self.state = self.state["model"]
-
- self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
- self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
- self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
- self.num_blocks = len(
- set([x.split(".")[1] for x in self.state.keys() if "residual" in x])
- )
- self.scale: int = 2 ** len(
- set([x.split(".")[1] for x in self.state.keys() if "upsampler" in x])
- )
-
- in_channels = self.in_nc
- num_channels = self.num_filters
- num_blocks = self.num_blocks
- upscale_factor = self.scale
-
- self.supports_fp16 = True
- self.supports_bfp16 = True
- self.min_size_restriction = None
-
- self.initial = ConvBlock(
- in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
- )
- self.residual = nn.Sequential(
- *[ResidualBlock(num_channels) for _ in range(num_blocks)]
- )
- self.convblock = ConvBlock(
- num_channels,
- num_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- use_act=False,
- )
- self.upsampler = nn.Sequential(
- *[
- UpsampleBlock(num_channels, scale_factor=2)
- for _ in range(upscale_factor // 2)
- ]
- )
- self.final_conv = SeperableConv2d(
- num_channels, in_channels, kernel_size=9, stride=1, padding=4
- )
-
- self.load_state_dict(self.state, strict=False)
-
- def forward(self, x):
- initial = self.initial(x)
- x = self.residual(initial)
- x = self.convblock(x) + initial
- x = self.upsampler(x)
- return (torch.tanh(self.final_conv(x)) + 1) / 2
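For reference, a minimal standalone sketch of the depthwise-separable convolution SwiftSRGAN is built from: a per-channel (grouped) 3x3 conv followed by a 1x1 pointwise conv stands in for a dense 3x3 conv at a fraction of the parameter count.

    import torch
    import torch.nn as nn

    cin, cout = 64, 64
    depthwise = nn.Conv2d(cin, cin, kernel_size=3, padding=1, groups=cin)
    pointwise = nn.Conv2d(cin, cout, kernel_size=1)
    dense = nn.Conv2d(cin, cout, kernel_size=3, padding=1)

    def n_params(*mods):
        return sum(p.numel() for m in mods for p in m.parameters())

    print(n_params(depthwise, pointwise))            # 4800
    print(n_params(dense))                           # 36928

    y = pointwise(depthwise(torch.randn(1, cin, 32, 32)))   # same output shape as dense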
diff --git a/comfy_extras/chainner_models/architecture/Swin2SR.py b/comfy_extras/chainner_models/architecture/Swin2SR.py
deleted file mode 100644
index cb57ecfc4ad..00000000000
--- a/comfy_extras/chainner_models/architecture/Swin2SR.py
+++ /dev/null
@@ -1,1377 +0,0 @@
-# pylint: skip-file
-# -----------------------------------------------------------------------------------
-# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
-# Written by Conde and Choi et al.
-# From: https://raw.githubusercontent.com/mv-lab/swin2sr/main/models/network_swin2sr.py
-# -----------------------------------------------------------------------------------
-
-import math
-import re
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-
-# Originally from the timm package
-from .timm.drop import DropPath
-from .timm.helpers import to_2tuple
-from .timm.weight_init import trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.0,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = (
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- )
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(
- B, H // window_size, W // window_size, window_size, window_size, -1
- )
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
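For reference, a minimal standalone check (shapes are illustrative) that window_partition and window_reverse above are exact inverses, which is what lets attention run independently per window and be stitched back afterwards.

    import torch

    B, H, W, C, ws = 2, 16, 16, 32, 8
    x = torch.randn(B, H, W, C)

    win = x.view(B, H // ws, ws, W // ws, ws, C)
    win = win.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)
    print(win.shape)                                 # torch.Size([8, 8, 8, 32])

    back = win.view(B, H // ws, W // ws, ws, ws, C)
    back = back.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
    assert torch.equal(back, x)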
-
-class WindowAttention(nn.Module):
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both shifted and non-shifted windows.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(
- self,
- dim,
- window_size,
- num_heads,
- qkv_bias=True,
- attn_drop=0.0,
- proj_drop=0.0,
- pretrained_window_size=[0, 0],
- ):
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # type: ignore
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(
- nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False),
- )
-
- # get relative_coords_table
- relative_coords_h = torch.arange(
- -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
- )
- relative_coords_w = torch.arange(
- -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
- )
- relative_coords_table = (
- torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
- .permute(1, 2, 0)
- .contiguous()
- .unsqueeze(0)
- ) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
- relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
- else:
- relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
- relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = (
- torch.sign(relative_coords_table)
- * torch.log2(torch.abs(relative_coords_table) + 1.0)
- / np.log2(8)
- )
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = (
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
- ) # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(
- 1, 2, 0
- ).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
- self.v_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # type: ignore
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = (
- qkv[0],
- qkv[1],
- qkv[2],
- ) # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
- logit_scale = torch.clamp(
- self.logit_scale,
- max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device),
- ).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
- -1, self.num_heads
- )
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1,
- ) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
- 1
- ).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
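For reference, a minimal standalone sketch of the SwinV2-style cosine attention used in forward() above (head and token counts are arbitrary): queries and keys are L2-normalized and the similarity is rescaled by a learned, clamped temperature instead of the usual 1/sqrt(d).

    import math
    import torch
    import torch.nn.functional as F

    heads, tokens, dim = 6, 64, 32
    q = torch.randn(1, heads, tokens, dim)
    k = torch.randn(1, heads, tokens, dim)
    logit_scale = torch.log(10 * torch.ones(heads, 1, 1))    # learnable in the model

    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    attn = attn * torch.clamp(logit_scale, max=math.log(100.0)).exp()
    attn = attn.softmax(dim=-1)                              # (1, heads, tokens, tokens)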
- def extra_repr(self) -> str:
- return (
- f"dim={self.dim}, window_size={self.window_size}, "
- f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
- )
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r"""Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- num_heads,
- window_size=7,
- shift_size=0,
- mlp_ratio=4.0,
- qkv_bias=True,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- pretrained_window_size=0,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert (
- 0 <= self.shift_size < self.window_size
- ), "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim,
- window_size=to_2tuple(self.window_size),
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- attn_drop=attn_drop,
- proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size),
- )
-
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop,
- )
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- w_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(
- img_mask, self.window_size
- ) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
- attn_mask == 0, float(0.0)
- )
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
- )
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(
- shifted_x, self.window_size
- ) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(
- -1, self.window_size * self.window_size, C
- ) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(
- x_windows, mask=self.attn_mask
- ) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(
- x_windows, mask=self.calculate_mask(x_size).to(x.device)
- )
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
- )
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return (
- f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
- )
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r"""Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-
-class BasicLayer(nn.Module):
- """A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.0,
- qkv_bias=True,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- pretrained_window_size=0,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList(
- [
- SwinTransformerBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i]
- if isinstance(drop_path, list)
- else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size,
- )
- for i in range(depth)
- ]
- )
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(
- input_resolution, dim=dim, norm_layer=norm_layer
- )
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops() # type: ignore
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0) # type: ignore
- nn.init.constant_(blk.norm1.weight, 0) # type: ignore
- nn.init.constant_(blk.norm2.bias, 0) # type: ignore
- nn.init.constant_(blk.norm2.weight, 0) # type: ignore
-
-
-class PatchEmbed(nn.Module):
- r"""Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size # type: ignore
- )
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # type: ignore
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.0,
- qkv_bias=True,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- img_size=224,
- patch_size=4,
- resi_connection="1conv",
- ):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(
- dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint,
- )
-
- if resi_connection == "1conv":
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == "3conv":
- # to save parameters and memory
- self.conv = nn.Sequential(
- nn.Conv2d(dim, dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1),
- )
-
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=dim,
- embed_dim=dim,
- norm_layer=None,
- )
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=dim,
- embed_dim=dim,
- norm_layer=None,
- )
-
- def forward(self, x, x_size):
- return (
- self.patch_embed(
- self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
- )
- + x
- )
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r"""Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(
- f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
- )
- super(Upsample, self).__init__(*m)
-
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(
- f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
- )
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference from Upsample is that it always has only one conv + one pixelshuffle).
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution # type: ignore
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
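For reference, a minimal standalone comparison of the two reconstruction heads above (hypothetical sizes): Upsample stacks conv + PixelShuffle(2) log2(scale) times and keeps num_feat channels, while UpsampleOneStep emits scale**2 * num_out_ch channels from a single conv and shuffles straight to the output channel count.

    import math
    import torch
    import torch.nn as nn

    scale, num_feat, num_out_ch = 4, 64, 3

    steps = []
    for _ in range(int(math.log(scale, 2))):
        steps += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
    multi_step = nn.Sequential(*steps)

    one_step = nn.Sequential(
        nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1),
        nn.PixelShuffle(scale),
    )

    x = torch.randn(1, num_feat, 24, 24)
    print(multi_step(x).shape)                       # torch.Size([1, 64, 96, 96])
    print(one_step(x).shape)                         # torch.Size([1, 3, 96, 96])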
-
-class Swin2SR(nn.Module):
- r"""Swin2SR
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(
- self,
- state_dict,
- **kwargs,
- ):
- super(Swin2SR, self).__init__()
-
- # Defaults
- img_size = 128
- patch_size = 1
- in_chans = 3
- embed_dim = 96
- depths = [6, 6, 6, 6]
- num_heads = [6, 6, 6, 6]
- window_size = 7
- mlp_ratio = 4.0
- qkv_bias = True
- drop_rate = 0.0
- attn_drop_rate = 0.0
- drop_path_rate = 0.1
- norm_layer = nn.LayerNorm
- ape = False
- patch_norm = True
- use_checkpoint = False
- upscale = 2
- img_range = 1.0
- upsampler = ""
- resi_connection = "1conv"
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
-
- self.model_arch = "Swin2SR"
- self.sub_type = "SR"
- self.state = state_dict
- if "params_ema" in self.state:
- self.state = self.state["params_ema"]
- elif "params" in self.state:
- self.state = self.state["params"]
-
- state_keys = self.state.keys()
-
- if "conv_before_upsample.0.weight" in state_keys:
- if "conv_aux.weight" in state_keys:
- upsampler = "pixelshuffle_aux"
- elif "conv_up1.weight" in state_keys:
- upsampler = "nearest+conv"
- else:
- upsampler = "pixelshuffle"
- supports_fp16 = False
- elif "upsample.0.weight" in state_keys:
- upsampler = "pixelshuffledirect"
- else:
- upsampler = ""
-
- num_feat = (
- self.state["conv_before_upsample.0.weight"].shape[1]
- if "conv_before_upsample.0.weight" in state_keys
- else 64
- )
-
- num_in_ch = self.state["conv_first.weight"].shape[1]
- in_chans = num_in_ch
- if "conv_last.weight" in state_keys:
- num_out_ch = self.state["conv_last.weight"].shape[0]
- else:
- num_out_ch = num_in_ch
-
- upscale = 1
- if upsampler == "nearest+conv":
- upsample_keys = [
- x for x in state_keys if "conv_up" in x and "bias" not in x
- ]
-
- for upsample_key in upsample_keys:
- upscale *= 2
- elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux":
- upsample_keys = [
- x
- for x in state_keys
- if "upsample" in x and "conv" not in x and "bias" not in x
- ]
- for upsample_key in upsample_keys:
- shape = self.state[upsample_key].shape[0]
- upscale *= math.sqrt(shape // num_feat)
- upscale = int(upscale)
- elif upsampler == "pixelshuffledirect":
- upscale = int(
- math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
- )
-
- max_layer_num = 0
- max_block_num = 0
- for key in state_keys:
- result = re.match(
- r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
- )
- if result:
- layer_num, block_num = result.groups()
- max_layer_num = max(max_layer_num, int(layer_num))
- max_block_num = max(max_block_num, int(block_num))
-
- depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
- if (
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- in state_keys
- ):
- num_heads_num = self.state[
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- ].shape[-1]
- num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
- else:
- num_heads = depths
-
- embed_dim = self.state["conv_first.weight"].shape[0]
-
- mlp_ratio = float(
- self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
- / embed_dim
- )
-
- # TODO: could actually count the layers, but this should do
- if "layers.0.conv.4.weight" in state_keys:
- resi_connection = "3conv"
- else:
- resi_connection = "1conv"
-
- window_size = int(
- math.sqrt(
- self.state[
- "layers.0.residual_group.blocks.0.attn.relative_position_index"
- ].shape[0]
- )
- )
-
- if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
- img_size = int(
- math.sqrt(
- self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
- )
- * window_size
- )
-
- # The JPEG models are the only ones with window-size 7, and they also use this range
- img_range = 255.0 if window_size == 7 else 1.0
-
- self.in_nc = num_in_ch
- self.out_nc = num_out_ch
- self.num_feat = num_feat
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.depths = depths
- self.window_size = window_size
- self.mlp_ratio = mlp_ratio
- self.scale = upscale
- self.upsampler = upsampler
- self.img_size = img_size
- self.img_range = img_range
- self.resi_connection = resi_connection
-
- self.supports_fp16 = False # Too much weirdness to support this at the moment
- self.supports_bfp16 = True
- self.min_size_restriction = 16
-
- ## END AUTO DETECTION
-
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # type: ignore
- trunc_normal_(self.absolute_pos_embed, std=0.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
- ] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(
- dim=embed_dim,
- input_resolution=(patches_resolution[0], patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection,
- )
- self.layers.append(layer)
-
- if self.upsampler == "pixelshuffle_hf":
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(
- dim=embed_dim,
- input_resolution=(patches_resolution[0], patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection,
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == "1conv":
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == "3conv":
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(
- nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
- )
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == "pixelshuffle":
- # for classical SR
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == "pixelshuffle_aux":
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == "pixelshuffle_hf":
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(
- nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(
- upscale,
- embed_dim,
- num_out_ch,
- (patches_resolution[0], patches_resolution[1]),
- )
- elif self.upsampler == "nearest+conv":
- # for real-world SR (fewer artifacts)
- assert self.upscale == 4, "only support x4 now."
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- self.load_state_dict(state_dict)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay(self):
- return {"absolute_pos_embed"}
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay_keywords(self):
- return {"relative_position_bias_table"}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == "pixelshuffle":
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == "pixelshuffle_aux":
- bicubic = F.interpolate(
- x,
- size=(H * self.upscale, W * self.upscale),
- mode="bicubic",
- align_corners=False,
- )
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = (
- self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale]
- + bicubic[:, :, : H * self.upscale, : W * self.upscale]
- )
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == "pixelshuffle_hf":
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == "nearest+conv":
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(
- self.conv_up1(
- torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
- )
- )
- x = self.lrelu(
- self.conv_up2(
- torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
- )
- )
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- # NOTE: I removed an "aux" output here. not sure what that was for
- return x[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean # type: ignore
- return (
- x_out[:, :, : H * self.upscale, : W * self.upscale],
- x[:, :, : H * self.upscale, : W * self.upscale],
- x_hf[:, :, : H * self.upscale, : W * self.upscale],
- ) # type: ignore
-
- else:
- return x[:, :, : H * self.upscale, : W * self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops() # type: ignore
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops() # type: ignore
- return flops
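A minimal, illustrative sketch of the scale auto-detection performed above for the "pixelshuffle"-style heads, assuming only a plain dict of weight tensors keyed like the checkpoint: each "upsample.N.weight" conv expands num_feat channels to r^2 * num_feat ahead of a PixelShuffle(r), so multiplying the per-stage ratios recovers the total upscale factor.

import math

def infer_pixelshuffle_scale(state, num_feat):
    # convs feeding the PixelShuffle stages: "upsample.0.weight", "upsample.2.weight", ...
    keys = [k for k in state if "upsample" in k and "conv" not in k and "bias" not in k]
    upscale = 1.0
    for k in keys:
        # each stage maps num_feat -> r^2 * num_feat output channels
        upscale *= math.sqrt(state[k].shape[0] // num_feat)
    return int(upscale)

# e.g. two stages with out_channels == 4 * num_feat give sqrt(4) * sqrt(4) = 4x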
diff --git a/comfy_extras/chainner_models/architecture/SwinIR.py b/comfy_extras/chainner_models/architecture/SwinIR.py
deleted file mode 100644
index 439dcbcb2b1..00000000000
--- a/comfy_extras/chainner_models/architecture/SwinIR.py
+++ /dev/null
@@ -1,1224 +0,0 @@
-# pylint: skip-file
-# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import re
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-
-# Originally from the timm package
-from .timm.drop import DropPath
-from .timm.helpers import to_2tuple
-from .timm.weight_init import trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.0,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = (
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- )
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(
- B, H // window_size, W // window_size, window_size, window_size, -1
- )
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
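# Hedged round-trip sketch, assuming the two helpers above: window_partition and
# window_reverse are shape-level inverses whenever H and W are multiples of the
# window size (which the reflect padding in check_image_size guarantees), e.g.:
#   x = torch.randn(2, 16, 16, 96)                  # B, H, W, C
#   windows = window_partition(x, 8)                # -> (2 * 2 * 2, 8, 8, 96)
#   restored = window_reverse(windows, 8, 16, 16)   # -> (2, 16, 16, 96)
#   assert torch.equal(x, restored)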
-
-
-class WindowAttention(nn.Module):
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(
- self,
- dim,
- window_size,
- num_heads,
- qkv_bias=True,
- qk_scale=None,
- attn_drop=0.0,
- proj_drop=0.0,
- ):
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter( # type: ignore
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
- ) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = (
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
- ) # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(
- 1, 2, 0
- ).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=0.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = (
- self.qkv(x)
- .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
- .permute(2, 0, 3, 1, 4)
- )
- q, k, v = (
- qkv[0],
- qkv[1],
- qkv[2],
- ) # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
-
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1) # type: ignore
- ].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1,
- ) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
- 1
- ).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
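# Hedged arithmetic check, following the bias-table construction above, for an
# 8x8 window:
#   bias_table_rows = (2 * 8 - 1) * (2 * 8 - 1)   # 225 entries per head
#   index_shape = (8 * 8, 8 * 8)                  # (Wh*Ww, Wh*Ww) == (64, 64)
# which is why the auto-detection in this file recovers window_size as
# int(math.sqrt(relative_position_index.shape[0])).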
-
-
-class SwinTransformerBlock(nn.Module):
- r"""Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- num_heads,
- window_size=7,
- shift_size=0,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert (
- 0 <= self.shift_size < self.window_size
- ), "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim,
- window_size=to_2tuple(self.window_size),
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- )
-
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop,
- )
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- w_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(
- img_mask, self.window_size
- ) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
- attn_mask == 0, float(0.0)
- )
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
- )
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(
- shifted_x, self.window_size
- ) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(
- -1, self.window_size * self.window_size, C
- ) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(
- x_windows, mask=self.attn_mask
- ) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(
- x_windows, mask=self.calculate_mask(x_size).to(x.device)
- )
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
- )
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return (
- f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
- )
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r"""Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList(
- [
- SwinTransformerBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i]
- if isinstance(drop_path, list)
- else drop_path,
- norm_layer=norm_layer,
- )
- for i in range(depth)
- ]
- )
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(
- input_resolution, dim=dim, norm_layer=norm_layer
- )
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops() # type: ignore
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- img_size=224,
- patch_size=4,
- resi_connection="1conv",
- ):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(
- dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint,
- )
-
- if resi_connection == "1conv":
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == "3conv":
- # to save parameters and memory
- self.conv = nn.Sequential(
- nn.Conv2d(dim, dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1),
- )
-
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=0,
- embed_dim=dim,
- norm_layer=None,
- )
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=0,
- embed_dim=dim,
- norm_layer=None,
- )
-
- def forward(self, x, x_size):
- return (
- self.patch_embed(
- self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
- )
- + x
- )
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r"""Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [
- img_size[0] // patch_size[0], # type: ignore
- img_size[1] // patch_size[1], # type: ignore
- ]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim # type: ignore
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r"""Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [
- img_size[0] // patch_size[0], # type: ignore
- img_size[1] // patch_size[1], # type: ignore
- ]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(
- f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
- )
- super(Upsample, self).__init__(*m)
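# Hedged note on the loop above: (scale & (scale - 1)) == 0 tests for a power of
# two, so a 4x Upsample stacks two (Conv2d(nf, 4*nf, 3) + PixelShuffle(2)) stages,
# while 3x uses a single Conv2d(nf, 9*nf, 3) + PixelShuffle(3).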
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution # type: ignore
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r"""SwinIR
- A PyTorch impl of: `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(
- self,
- state_dict,
- **kwargs,
- ):
- super(SwinIR, self).__init__()
-
- # Defaults
- img_size = 64
- patch_size = 1
- in_chans = 3
- embed_dim = 96
- depths = [6, 6, 6, 6]
- num_heads = [6, 6, 6, 6]
- window_size = 7
- mlp_ratio = 4.0
- qkv_bias = True
- qk_scale = None
- drop_rate = 0.0
- attn_drop_rate = 0.0
- drop_path_rate = 0.1
- norm_layer = nn.LayerNorm
- ape = False
- patch_norm = True
- use_checkpoint = False
- upscale = 2
- img_range = 1.0
- upsampler = ""
- resi_connection = "1conv"
- num_feat = 64
- num_in_ch = in_chans
- num_out_ch = in_chans
- supports_fp16 = True
- self.start_unshuffle = 1
-
- self.model_arch = "SwinIR"
- self.sub_type = "SR"
- self.state = state_dict
- if "params_ema" in self.state:
- self.state = self.state["params_ema"]
- elif "params" in self.state:
- self.state = self.state["params"]
-
- state_keys = self.state.keys()
-
- if "conv_before_upsample.0.weight" in state_keys:
- if "conv_up1.weight" in state_keys:
- upsampler = "nearest+conv"
- else:
- upsampler = "pixelshuffle"
- supports_fp16 = False
- elif "upsample.0.weight" in state_keys:
- upsampler = "pixelshuffledirect"
- else:
- upsampler = ""
-
- num_feat = (
- self.state["conv_before_upsample.0.weight"].shape[1]
- if "conv_before_upsample.0.weight" in state_keys
- else 64
- )
-
- if "conv_first.1.weight" in self.state:
- self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight")
- self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias")
- self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3))
-
- num_in_ch = self.state["conv_first.weight"].shape[1]
- in_chans = num_in_ch
- if "conv_last.weight" in state_keys:
- num_out_ch = self.state["conv_last.weight"].shape[0]
- else:
- num_out_ch = num_in_ch
-
- upscale = 1
- if upsampler == "nearest+conv":
- upsample_keys = [
- x for x in state_keys if "conv_up" in x and "bias" not in x
- ]
-
- for upsample_key in upsample_keys:
- upscale *= 2
- elif upsampler == "pixelshuffle":
- upsample_keys = [
- x
- for x in state_keys
- if "upsample" in x and "conv" not in x and "bias" not in x
- ]
- for upsample_key in upsample_keys:
- shape = self.state[upsample_key].shape[0]
- upscale *= math.sqrt(shape // num_feat)
- upscale = int(upscale)
- elif upsampler == "pixelshuffledirect":
- upscale = int(
- math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
- )
-
- max_layer_num = 0
- max_block_num = 0
- for key in state_keys:
- result = re.match(
- r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
- )
- if result:
- layer_num, block_num = result.groups()
- max_layer_num = max(max_layer_num, int(layer_num))
- max_block_num = max(max_block_num, int(block_num))
-
- depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
- if (
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- in state_keys
- ):
- num_heads_num = self.state[
- "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
- ].shape[-1]
- num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
- else:
- num_heads = depths
-
- embed_dim = self.state["conv_first.weight"].shape[0]
-
- mlp_ratio = float(
- self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
- / embed_dim
- )
-
- # TODO: could actually count the layers, but this should do
- if "layers.0.conv.4.weight" in state_keys:
- resi_connection = "3conv"
- else:
- resi_connection = "1conv"
-
- window_size = int(
- math.sqrt(
- self.state[
- "layers.0.residual_group.blocks.0.attn.relative_position_index"
- ].shape[0]
- )
- )
-
- if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
- img_size = int(
- math.sqrt(
- self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
- )
- * window_size
- )
-
- # The JPEG models are the only ones with window-size 7, and they also use this range
- img_range = 255.0 if window_size == 7 else 1.0
-
- self.in_nc = num_in_ch
- self.out_nc = num_out_ch
- self.num_feat = num_feat
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.depths = depths
- self.window_size = window_size
- self.mlp_ratio = mlp_ratio
- self.scale = upscale / self.start_unshuffle
- self.upsampler = upsampler
- self.img_size = img_size
- self.img_range = img_range
- self.resi_connection = resi_connection
-
- self.supports_fp16 = False # Too much weirdness to support this at the moment
- self.supports_bfp16 = True
- self.min_size_restriction = 16
-
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=embed_dim,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None,
- )
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter( # type: ignore
- torch.zeros(1, num_patches, embed_dim)
- )
- trunc_normal_(self.absolute_pos_embed, std=0.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
- ] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(
- dim=embed_dim,
- input_resolution=(patches_resolution[0], patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[
- sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
- ], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection,
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == "1conv":
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == "3conv":
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(
- nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
- )
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == "pixelshuffle":
- # for classical SR
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(
- upscale,
- embed_dim,
- num_out_ch,
- (patches_resolution[0], patches_resolution[1]),
- )
- elif self.upsampler == "nearest+conv":
- # for real-world SR (fewer artifacts)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
- )
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- elif self.upscale == 8:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
- self.load_state_dict(self.state, strict=False)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay(self):
- return {"absolute_pos_embed"}
-
- @torch.jit.ignore # type: ignore
- def no_weight_decay_keywords(self):
- return {"relative_position_bias_table"}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.start_unshuffle > 1:
- x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle)
-
- if self.upsampler == "pixelshuffle":
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == "pixelshuffledirect":
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == "nearest+conv":
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(
- self.conv_up1(
- torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") # type: ignore
- )
- )
- if self.upscale == 4:
- x = self.lrelu(
- self.conv_up2(
- torch.nn.functional.interpolate( # type: ignore
- x, scale_factor=2, mode="nearest"
- )
- )
- )
- elif self.upscale == 8:
- x = self.lrelu(
- self.conv_up2(
- torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
- )
- )
- x = self.lrelu(
- self.conv_up3(
- torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
- )
- )
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, : H * self.upscale, : W * self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops() # type: ignore
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops() # type: ignore
- return flops
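A minimal sketch of the padding rule used by check_image_size above (illustrative, assuming only the window size): inputs are reflect-padded on the right and bottom until both spatial dims are multiples of the window size, and forward() crops the output back to H * upscale by W * upscale.

def mod_pad_amounts(h, w, window_size):
    # right/bottom padding that rounds each dim up to the next multiple of window_size
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    return pad_h, pad_w

# e.g. a 70 x 94 input with window_size 8 is padded by (2, 2) up to 72 x 96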
diff --git a/comfy_extras/chainner_models/architecture/__init__.py b/comfy_extras/chainner_models/architecture/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py
deleted file mode 100644
index d7bc5d22700..00000000000
--- a/comfy_extras/chainner_models/architecture/block.py
+++ /dev/null
@@ -1,546 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-from __future__ import annotations
-
-from collections import OrderedDict
-try:
- from typing import Literal
-except ImportError:
- from typing_extensions import Literal
-
-import torch
-import torch.nn as nn
-
-####################
-# Basic blocks
-####################
-
-
-def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
- # helper selecting activation
- # neg_slope: for leakyrelu and init of prelu
- # n_prelu: for p_relu num_parameters
- act_type = act_type.lower()
- if act_type == "relu":
- layer = nn.ReLU(inplace)
- elif act_type == "leakyrelu":
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == "prelu":
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- else:
- raise NotImplementedError(
- "activation layer [{:s}] is not found".format(act_type)
- )
- return layer
-
-
-def norm(norm_type: str, nc: int):
- # helper selecting normalization layer
- norm_type = norm_type.lower()
- if norm_type == "batch":
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == "instance":
- layer = nn.InstanceNorm2d(nc, affine=False)
- else:
- raise NotImplementedError(
- "normalization layer [{:s}] is not found".format(norm_type)
- )
- return layer
-
-
-def pad(pad_type: str, padding):
- # helper selecting padding layer
- # if padding is 'zero', do by conv layers
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == "reflect":
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == "replicate":
- layer = nn.ReplicationPad2d(padding)
- else:
- raise NotImplementedError(
- "padding layer [{:s}] is not implemented".format(pad_type)
- )
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ConcatBlock(nn.Module):
- # Concat the output of a submodule to its input
- def __init__(self, submodule):
- super(ConcatBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = torch.cat((x, self.sub(x)), dim=1)
- return output
-
- def __repr__(self):
- tmpstr = "Identity .. \n|"
- modstr = self.sub.__repr__().replace("\n", "\n|")
- tmpstr = tmpstr + modstr
- return tmpstr
-
-
-class ShortcutBlock(nn.Module):
- # Elementwise sum the output of a submodule to its input
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- tmpstr = "Identity + \n|"
- modstr = self.sub.__repr__().replace("\n", "\n|")
- tmpstr = tmpstr + modstr
- return tmpstr
-
-
-class ShortcutBlockSPSR(nn.Module):
- # Elementwise sum the output of a submodule to its input
- def __init__(self, submodule):
- super(ShortcutBlockSPSR, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- return x, self.sub
-
- def __repr__(self):
- tmpstr = "Identity + \n|"
- modstr = self.sub.__repr__().replace("\n", "\n|")
- tmpstr = tmpstr + modstr
- return tmpstr
-
-
-def sequential(*args):
- # Flatten Sequential. It unwraps nn.Sequential.
- if len(args) == 1:
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError("sequential does not support OrderedDict input.")
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-ConvMode = Literal["CNA", "NAC", "CNAC"]
-
-
-# 2x2x2 Conv Block
-def conv_block_2c2(
- in_nc,
- out_nc,
- act_type="relu",
-):
- return sequential(
- nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
- nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
- act(act_type) if act_type else None,
- )
-
-
-def conv_block(
- in_nc: int,
- out_nc: int,
- kernel_size,
- stride=1,
- dilation=1,
- groups=1,
- bias=True,
- pad_type="zero",
- norm_type: str | None = None,
- act_type: str | None = "relu",
- mode: ConvMode = "CNA",
- c2x2=False,
-):
- """
- Conv layer with padding, normalization, activation
- mode: CNA --> Conv -> Norm -> Act
- NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
- """
-
- if c2x2:
- return conv_block_2c2(in_nc, out_nc, act_type=act_type)
-
- assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
- padding = padding if pad_type == "zero" else 0
-
- c = nn.Conv2d(
- in_nc,
- out_nc,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- bias=bias,
- groups=groups,
- )
- a = act(act_type) if act_type else None
- if mode in ("CNA", "CNAC"):
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == "NAC":
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- # Important!
- # input----ReLU(inplace)----Conv--+----output
- # |________________________|
- # inplace ReLU will modify the input, therefore wrong output
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
- else:
- assert False, f"Invalid conv mode {mode}"
-
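To make the two mode orderings concrete, here is a hedged, self-contained sketch of the layer order `conv_block` assembles for `CNA` and `NAC`, assuming `BatchNorm2d` as the norm and `ReLU` as the activation and omitting the explicit padding layer; it illustrates the documented ordering, not the deleted implementation itself:

```python
# Sketch of the CNA vs. NAC orderings described in the docstring above.
import torch.nn as nn

def cna_block(in_nc, out_nc, k=3):
    # CNA: Conv -> Norm -> Act
    return nn.Sequential(
        nn.Conv2d(in_nc, out_nc, k, padding=k // 2),
        nn.BatchNorm2d(out_nc),
        nn.ReLU(inplace=True),
    )

def nac_block(in_nc, out_nc, k=3):
    # NAC: Norm -> Act -> Conv. inplace=False so the activation does not
    # overwrite the block input (the caveat noted in the comment above).
    return nn.Sequential(
        nn.BatchNorm2d(in_nc),
        nn.ReLU(inplace=False),
        nn.Conv2d(in_nc, out_nc, k, padding=k // 2),
    )
```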
-
-####################
-# Useful blocks
-####################
-
-
-class ResNetBlock(nn.Module):
- """
- ResNet Block, 3-3 style
- with extra residual scaling used in EDSR
- (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
- """
-
- def __init__(
- self,
- in_nc,
- mid_nc,
- out_nc,
- kernel_size=3,
- stride=1,
- dilation=1,
- groups=1,
- bias=True,
- pad_type="zero",
- norm_type=None,
- act_type="relu",
- mode: ConvMode = "CNA",
- res_scale=1,
- ):
- super(ResNetBlock, self).__init__()
- conv0 = conv_block(
- in_nc,
- mid_nc,
- kernel_size,
- stride,
- dilation,
- groups,
- bias,
- pad_type,
- norm_type,
- act_type,
- mode,
- )
- if mode == "CNA":
- act_type = None
- if mode == "CNAC": # Residual path: |-CNAC-|
- act_type = None
- norm_type = None
- conv1 = conv_block(
- mid_nc,
- out_nc,
- kernel_size,
- stride,
- dilation,
- groups,
- bias,
- pad_type,
- norm_type,
- act_type,
- mode,
- )
- # if in_nc != out_nc:
- # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
- # None, None)
- # print('Need a projecter in ResNetBlock.')
- # else:
- # self.project = lambda x:x
- self.res = sequential(conv0, conv1)
- self.res_scale = res_scale
-
- def forward(self, x):
- res = self.res(x).mul(self.res_scale)
- return x + res
-
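The forward pass above is the EDSR-style pattern `x + res_scale * F(x)`: scaling the residual branch by a small constant stabilises training of deep, norm-free stacks. A minimal standalone sketch using a plain two-conv branch (the class name `TinyResBlock` is illustrative, not from the deleted file):

```python
# Sketch of residual scaling: the branch output is scaled before the skip add.
import torch.nn as nn

class TinyResBlock(nn.Module):
    def __init__(self, nc, res_scale=0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(nc, nc, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(nc, nc, 3, padding=1),
        )
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x) * self.res_scale
```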
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(
- self,
- nf,
- kernel_size=3,
- gc=32,
- stride=1,
- bias: bool = True,
- pad_type="zero",
- norm_type=None,
- act_type="leakyrelu",
- mode: ConvMode = "CNA",
- _convtype="Conv2D",
- _spectral_norm=False,
- plus=False,
- c2x2=False,
- ):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(
- nf,
- kernel_size,
- gc,
- stride,
- bias,
- pad_type,
- norm_type,
- act_type,
- mode,
- plus=plus,
- c2x2=c2x2,
- )
- self.RDB2 = ResidualDenseBlock_5C(
- nf,
- kernel_size,
- gc,
- stride,
- bias,
- pad_type,
- norm_type,
- act_type,
- mode,
- plus=plus,
- c2x2=c2x2,
- )
- self.RDB3 = ResidualDenseBlock_5C(
- nf,
- kernel_size,
- gc,
- stride,
- bias,
- pad_type,
- norm_type,
- act_type,
- mode,
- plus=plus,
- c2x2=c2x2,
- )
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
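The forward method above is ESRGAN's residual-in-residual pattern: three dense blocks are chained, and the whole chain is scaled by 0.2 before being added back to the block input. A compact sketch of just that wiring (the `make_rdb` factory is a hypothetical stand-in for `ResidualDenseBlock_5C`):

```python
# Sketch of residual-in-residual: out = x + 0.2 * RDB3(RDB2(RDB1(x))).
import torch.nn as nn

class TinyRRDB(nn.Module):
    def __init__(self, nf, make_rdb):
        super().__init__()
        self.rdb1 = make_rdb(nf)
        self.rdb2 = make_rdb(nf)
        self.rdb3 = make_rdb(nf)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x
```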
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- style: 5 convs
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
-
- Args:
- nf (int): Channel number of intermediate features (num_feat).
- gc (int): Channels for each growth (num_grow_ch: growth channel,
- i.e. intermediate channels).
- convtype (str): the type of convolution to use. Default: 'Conv2D'
- gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
- trainable parameters)
- plus (bool): enable the additional residual paths from ESRGAN+
- (adds trainable parameters)
- """
-
- def __init__(
- self,
- nf=64,
- kernel_size=3,
- gc=32,
- stride=1,
- bias: bool = True,
- pad_type="zero",
- norm_type=None,
- act_type="leakyrelu",
- mode: ConvMode = "CNA",
- plus=False,
- c2x2=False,
- ):
- super(ResidualDenseBlock_5C, self).__init__()
-
- ## +
- self.conv1x1 = conv1x1(nf, gc) if plus else None
- ## +
-
- self.conv1 = conv_block(
- nf,
- gc,
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=act_type,
- mode=mode,
- c2x2=c2x2,
- )
- self.conv2 = conv_block(
- nf + gc,
- gc,
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=act_type,
- mode=mode,
- c2x2=c2x2,
- )
- self.conv3 = conv_block(
- nf + 2 * gc,
- gc,
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=act_type,
- mode=mode,
- c2x2=c2x2,
- )
- self.conv4 = conv_block(
- nf + 3 * gc,
- gc,
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=act_type,
- mode=mode,
- c2x2=c2x2,
- )
- if mode == "CNA":
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(
- nf + 4 * gc,
- nf,
- 3,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=last_act,
- mode=mode,
- c2x2=c2x2,
- )
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- # pylint: disable=not-callable
- x2 = x2 + self.conv1x1(x) # +
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2 # +
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
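The dense connectivity is why the input widths grow as `nf`, `nf + gc`, `nf + 2*gc`, ...: each conv consumes the block input concatenated with every earlier growth-channel output. A self-contained sketch of that pattern, without the ESRGAN+ `plus` path (the class name `TinyDenseBlock` is illustrative):

```python
# Sketch of the 5-conv dense block: concatenate-and-grow, then a final conv
# back to nf channels, scaled by 0.2 and added to the input.
import torch
import torch.nn as nn

class TinyDenseBlock(nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(nf + i * gc, gc, 3, padding=1) for i in range(4)]
        )
        self.conv_last = nn.Conv2d(nf + 4 * gc, nf, 3, padding=1)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        feats = [x]
        for conv in self.convs:
            feats.append(self.act(conv(torch.cat(feats, dim=1))))
        return self.conv_last(torch.cat(feats, dim=1)) * 0.2 + x
```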
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# Upsampler
-####################
-
-
-def pixelshuffle_block(
- in_nc: int,
- out_nc: int,
- upscale_factor=2,
- kernel_size=3,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type: str | None = None,
- act_type="relu",
-):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(
- in_nc,
- out_nc * (upscale_factor**2),
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=None,
- act_type=None,
- )
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
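Sub-pixel upsampling works by having the conv produce `out_nc * upscale_factor**2` channels and letting `nn.PixelShuffle` rearrange those channels into an `upscale_factor`-times larger spatial grid. A quick shape check (channel counts chosen arbitrarily for illustration):

```python
# Shape check: channels expand by r**2, PixelShuffle trades them for resolution.
import torch
import torch.nn as nn

r, in_nc, out_nc = 2, 64, 3
block = nn.Sequential(
    nn.Conv2d(in_nc, out_nc * r**2, kernel_size=3, padding=1),
    nn.PixelShuffle(r),
)
x = torch.randn(1, in_nc, 16, 16)
print(block(x).shape)  # torch.Size([1, 3, 32, 32])
```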
-
-def upconv_block(
- in_nc: int,
- out_nc: int,
- upscale_factor=2,
- kernel_size=3,
- stride=1,
- bias=True,
- pad_type="zero",
- norm_type: str | None = None,
- act_type="relu",
- mode="nearest",
- c2x2=False,
-):
- # Up conv
- # described in https://distill.pub/2016/deconv-checkerboard/
- upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(
- in_nc,
- out_nc,
- kernel_size,
- stride,
- bias=bias,
- pad_type=pad_type,
- norm_type=norm_type,
- act_type=act_type,
- c2x2=c2x2,
- )
- return sequential(upsample, conv)
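The upsample-then-conv pattern above is the alternative to transposed convolution recommended in the linked distill.pub article: nearest-neighbour upsampling followed by a regular conv avoids the uneven kernel overlap that causes checkerboard artifacts. A minimal sketch comparing the two (shapes only; the layer sizes are illustrative assumptions):

```python
# Sketch: upsample-then-conv (as in upconv_block) vs. a transposed conv.
import torch
import torch.nn as nn

upconv = nn.Sequential(                      # checkerboard-resistant variant
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
)
deconv = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

x = torch.randn(1, 64, 16, 16)
print(upconv(x).shape, deconv(x).shape)  # both torch.Size([1, 32, 32, 32])
```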
diff --git a/comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN b/comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN
deleted file mode 100644
index 5ac273fd509..00000000000
--- a/comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN
+++ /dev/null
@@ -1,351 +0,0 @@
-Tencent is pleased to support the open source community by making GFPGAN available.
-
-Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
-
-GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
-
-
-Terms of the Apache License Version 2.0:
----------------------------------------------
-Apache License
-
-Version 2.0, January 2004
-
-http://www.apache.org/licenses/
-
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-1. Definitions.
-
-“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
-
-“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
-
-“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
-
-“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
-
-“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
-
-“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
-
-“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
-
-“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
-
-“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
-
-“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
-
-You must give any other recipients of the Work or Derivative Works a copy of this License; and
-
-You must cause any modified files to carry prominent notices stating that You changed the files; and
-
-You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
-
-If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
-
-You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-
-
-Other dependencies and licenses:
-
-
-Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. basicsr
-Copyright 2018-2020 BasicSR Authors
-
-
-This BasicSR project is released under the Apache 2.0 license.
-
-A copy of Apache 2.0 is included in this file.
-
-StyleGAN2
-The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
-The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
-DFDNet
-The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
-
-Terms of the Nvidia License:
----------------------------------------------
-
-1. Definitions
-
-"Licensor" means any person or entity that distributes its Work.
-
-"Software" means the original work of authorship made available under
-this License.
-
-"Work" means the Software and any additions to or derivative works of
-the Software that are made available under this License.
-
-"Nvidia Processors" means any central processing unit (CPU), graphics
-processing unit (GPU), field-programmable gate array (FPGA),
-application-specific integrated circuit (ASIC) or any combination
-thereof designed, made, sold, or provided by Nvidia or its affiliates.
-
-The terms "reproduce," "reproduction," "derivative works," and
-"distribution" have the meaning as provided under U.S. copyright law;
-provided, however, that for the purposes of this License, derivative
-works shall not include works that remain separable from, or merely
-link (or bind by name) to the interfaces of, the Work.
-
-Works, including the Software, are "made available" under this License
-by including in or with the Work either (a) a copyright notice
-referencing the applicability of this License to the Work, or (b) a
-copy of this License.
-
-2. License Grants
-
- 2.1 Copyright Grant. Subject to the terms and conditions of this
- License, each Licensor grants to you a perpetual, worldwide,
- non-exclusive, royalty-free, copyright license to reproduce,
- prepare derivative works of, publicly display, publicly perform,
- sublicense and distribute its Work and any resulting derivative
- works in any form.
-
-3. Limitations
-
- 3.1 Redistribution. You may reproduce or distribute the Work only
- if (a) you do so under this License, (b) you include a complete
- copy of this License with your distribution, and (c) you retain
- without modification any copyright, patent, trademark, or
- attribution notices that are present in the Work.
-
- 3.2 Derivative Works. You may specify that additional or different
- terms apply to the use, reproduction, and distribution of your
- derivative works of the Work ("Your Terms") only if (a) Your Terms
- provide that the use limitation in Section 3.3 applies to your
- derivative works, and (b) you identify the specific derivative
- works that are subject to Your Terms. Notwithstanding Your Terms,
- this License (including the redistribution requirements in Section
- 3.1) will continue to apply to the Work itself.
-
- 3.3 Use Limitation. The Work and any derivative works thereof only
- may be used or intended for use non-commercially. The Work or
- derivative works thereof may be used or intended for use by Nvidia
- or its affiliates commercially or non-commercially. As used herein,
- "non-commercially" means for research or evaluation purposes only.
-
- 3.4 Patent Claims. If you bring or threaten to bring a patent claim
- against any Licensor (including any claim, cross-claim or
- counterclaim in a lawsuit) to enforce any patents that you allege
- are infringed by any Work, then your rights under this License from
- such Licensor (including the grants in Sections 2.1 and 2.2) will
- terminate immediately.
-
- 3.5 Trademarks. This License does not grant any rights to use any
- Licensor's or its affiliates' names, logos, or trademarks, except
- as necessary to reproduce the notices described in this License.
-
- 3.6 Termination. If you violate any term of this License, then your
- rights under this License (including the grants in Sections 2.1 and
- 2.2) will terminate immediately.
-
-4. Disclaimer of Warranty.
-
-THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
-NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
-THIS LICENSE.
-
-5. Limitation of Liability.
-
-EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
-THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
-SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
-INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
-OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
-(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
-LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
-COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
-THE POSSIBILITY OF SUCH DAMAGES.
-
-MIT License
-
-Copyright (c) 2019 Kim Seonghyeon
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Open Source Software licensed under the BSD 3-Clause license:
----------------------------------------------
-1. torchvision
-Copyright (c) Soumith Chintala 2016,
-All rights reserved.
-
-2. torch
-Copyright (c) 2016- Facebook, Inc (Adam Paszke)
-Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
-Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
-Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
-Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
-Copyright (c) 2011-2013 NYU (Clement Farabet)
-Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
-Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
-Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
-
-
-Terms of the BSD 3-Clause License:
----------------------------------------------
-Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-
-Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. numpy
-Copyright (c) 2005-2020, NumPy Developers.
-All rights reserved.
-
-A copy of BSD 3-Clause License is included in this file.
-
-The NumPy repository and source distributions bundle several libraries that are
-compatibly licensed. We list these here.
-
-Name: Numpydoc
-Files: doc/sphinxext/numpydoc/*
-License: BSD-2-Clause
- For details, see doc/sphinxext/LICENSE.txt
-
-Name: scipy-sphinx-theme
-Files: doc/scipy-sphinx-theme/*
-License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
- For details, see doc/scipy-sphinx-theme/LICENSE.txt
-
-Name: lapack-lite
-Files: numpy/linalg/lapack_lite/*
-License: BSD-3-Clause
- For details, see numpy/linalg/lapack_lite/LICENSE.txt
-
-Name: tempita
-Files: tools/npy_tempita/*
-License: MIT
- For details, see tools/npy_tempita/license.txt
-
-Name: dragon4
-Files: numpy/core/src/multiarray/dragon4.c
-License: MIT
- For license text, see numpy/core/src/multiarray/dragon4.c
-
-
-
-Open Source Software licensed under the MIT license:
----------------------------------------------
-1. facexlib
-Copyright (c) 2020 Xintao Wang
-
-2. opencv-python
-Copyright (c) Olli-Pekka Heinisuo
-Please note that only files in cv2 package are used.
-
-
-Terms of the MIT License:
----------------------------------------------
-Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-
-
-Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. tqdm
-Copyright (c) 2013 noamraph
-
-`tqdm` is a product of collaborative work.
-Unless otherwise stated, all authors (see commit logs) retain copyright
-for their respective work, and release the work under the MIT licence
-(text below).
-
-Exceptions or notable authors are listed below
-in reverse chronological order:
-
-* files: *
- MPLv2.0 2015-2020 (c) Casper da Costa-Luis
- [casperdcl](https://github.com/casperdcl).
-* files: tqdm/_tqdm.py
- MIT 2016 (c) [PR #96] on behalf of Google Inc.
-* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
- MIT 2013 (c) Noam Yorav-Raphael, original author.
-
-[PR #96]: https://github.com/tqdm/tqdm/pull/96
-
-
-Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
------------------------------------------------
-
-This Source Code Form is subject to the terms of the
-Mozilla Public License, v. 2.0.
-If a copy of the MPL was not distributed with this file,
-You can obtain one at https://mozilla.org/MPL/2.0/.
-
-
-MIT License (MIT)
------------------
-
-Copyright (c) 2013 noamraph
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
-the Software, and to permit persons to whom the Software is furnished to do so,
-subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
-FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
-COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
-IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer b/comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer
deleted file mode 100644
index 5ac273fd509..00000000000
--- a/comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer
+++ /dev/null
@@ -1,351 +0,0 @@
-Tencent is pleased to support the open source community by making GFPGAN available.
-
-Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
-
-GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
-
-
-Terms of the Apache License Version 2.0:
----------------------------------------------
-Apache License
-
-Version 2.0, January 2004
-
-http://www.apache.org/licenses/
-
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-1. Definitions.
-
-“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
-
-“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
-
-“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
-
-“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
-
-“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
-
-“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
-
-“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
-
-“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
-
-“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
-
-“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
-
-You must give any other recipients of the Work or Derivative Works a copy of this License; and
-
-You must cause any modified files to carry prominent notices stating that You changed the files; and
-
-You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
-
-If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
-
-You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-
-
-Other dependencies and licenses:
-
-
-Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. basicsr
-Copyright 2018-2020 BasicSR Authors
-
-
-This BasicSR project is released under the Apache 2.0 license.
-
-A copy of Apache 2.0 is included in this file.
-
-StyleGAN2
-The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
-The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
-DFDNet
-The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
-
-Terms of the Nvidia License:
----------------------------------------------
-
-1. Definitions
-
-"Licensor" means any person or entity that distributes its Work.
-
-"Software" means the original work of authorship made available under
-this License.
-
-"Work" means the Software and any additions to or derivative works of
-the Software that are made available under this License.
-
-"Nvidia Processors" means any central processing unit (CPU), graphics
-processing unit (GPU), field-programmable gate array (FPGA),
-application-specific integrated circuit (ASIC) or any combination
-thereof designed, made, sold, or provided by Nvidia or its affiliates.
-
-The terms "reproduce," "reproduction," "derivative works," and
-"distribution" have the meaning as provided under U.S. copyright law;
-provided, however, that for the purposes of this License, derivative
-works shall not include works that remain separable from, or merely
-link (or bind by name) to the interfaces of, the Work.
-
-Works, including the Software, are "made available" under this License
-by including in or with the Work either (a) a copyright notice
-referencing the applicability of this License to the Work, or (b) a
-copy of this License.
-
-2. License Grants
-
- 2.1 Copyright Grant. Subject to the terms and conditions of this
- License, each Licensor grants to you a perpetual, worldwide,
- non-exclusive, royalty-free, copyright license to reproduce,
- prepare derivative works of, publicly display, publicly perform,
- sublicense and distribute its Work and any resulting derivative
- works in any form.
-
-3. Limitations
-
- 3.1 Redistribution. You may reproduce or distribute the Work only
- if (a) you do so under this License, (b) you include a complete
- copy of this License with your distribution, and (c) you retain
- without modification any copyright, patent, trademark, or
- attribution notices that are present in the Work.
-
- 3.2 Derivative Works. You may specify that additional or different
- terms apply to the use, reproduction, and distribution of your
- derivative works of the Work ("Your Terms") only if (a) Your Terms
- provide that the use limitation in Section 3.3 applies to your
- derivative works, and (b) you identify the specific derivative
- works that are subject to Your Terms. Notwithstanding Your Terms,
- this License (including the redistribution requirements in Section
- 3.1) will continue to apply to the Work itself.
-
- 3.3 Use Limitation. The Work and any derivative works thereof only
- may be used or intended for use non-commercially. The Work or
- derivative works thereof may be used or intended for use by Nvidia
- or its affiliates commercially or non-commercially. As used herein,
- "non-commercially" means for research or evaluation purposes only.
-
- 3.4 Patent Claims. If you bring or threaten to bring a patent claim
- against any Licensor (including any claim, cross-claim or
- counterclaim in a lawsuit) to enforce any patents that you allege
- are infringed by any Work, then your rights under this License from
- such Licensor (including the grants in Sections 2.1 and 2.2) will
- terminate immediately.
-
- 3.5 Trademarks. This License does not grant any rights to use any
- Licensor's or its affiliates' names, logos, or trademarks, except
- as necessary to reproduce the notices described in this License.
-
- 3.6 Termination. If you violate any term of this License, then your
- rights under this License (including the grants in Sections 2.1 and
- 2.2) will terminate immediately.
-
-4. Disclaimer of Warranty.
-
-THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
-NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
-THIS LICENSE.
-
-5. Limitation of Liability.
-
-EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
-THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
-SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
-INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
-OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
-(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
-LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
-COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
-THE POSSIBILITY OF SUCH DAMAGES.
-
-MIT License
-
-Copyright (c) 2019 Kim Seonghyeon
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Open Source Software licensed under the BSD 3-Clause license:
----------------------------------------------
-1. torchvision
-Copyright (c) Soumith Chintala 2016,
-All rights reserved.
-
-2. torch
-Copyright (c) 2016- Facebook, Inc (Adam Paszke)
-Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
-Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
-Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
-Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
-Copyright (c) 2011-2013 NYU (Clement Farabet)
-Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
-Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
-Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
-
-
-Terms of the BSD 3-Clause License:
----------------------------------------------
-Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-
-Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. numpy
-Copyright (c) 2005-2020, NumPy Developers.
-All rights reserved.
-
-A copy of BSD 3-Clause License is included in this file.
-
-The NumPy repository and source distributions bundle several libraries that are
-compatibly licensed. We list these here.
-
-Name: Numpydoc
-Files: doc/sphinxext/numpydoc/*
-License: BSD-2-Clause
- For details, see doc/sphinxext/LICENSE.txt
-
-Name: scipy-sphinx-theme
-Files: doc/scipy-sphinx-theme/*
-License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
- For details, see doc/scipy-sphinx-theme/LICENSE.txt
-
-Name: lapack-lite
-Files: numpy/linalg/lapack_lite/*
-License: BSD-3-Clause
- For details, see numpy/linalg/lapack_lite/LICENSE.txt
-
-Name: tempita
-Files: tools/npy_tempita/*
-License: MIT
- For details, see tools/npy_tempita/license.txt
-
-Name: dragon4
-Files: numpy/core/src/multiarray/dragon4.c
-License: MIT
- For license text, see numpy/core/src/multiarray/dragon4.c
-
-
-
-Open Source Software licensed under the MIT license:
----------------------------------------------
-1. facexlib
-Copyright (c) 2020 Xintao Wang
-
-2. opencv-python
-Copyright (c) Olli-Pekka Heinisuo
-Please note that only files in cv2 package are used.
-
-
-Terms of the MIT License:
----------------------------------------------
-Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-
-
-Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
----------------------------------------------
-1. tqdm
-Copyright (c) 2013 noamraph
-
-`tqdm` is a product of collaborative work.
-Unless otherwise stated, all authors (see commit logs) retain copyright
-for their respective work, and release the work under the MIT licence
-(text below).
-
-Exceptions or notable authors are listed below
-in reverse chronological order:
-
-* files: *
- MPLv2.0 2015-2020 (c) Casper da Costa-Luis
- [casperdcl](https://github.com/casperdcl).
-* files: tqdm/_tqdm.py
- MIT 2016 (c) [PR #96] on behalf of Google Inc.
-* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
- MIT 2013 (c) Noam Yorav-Raphael, original author.
-
-[PR #96]: https://github.com/tqdm/tqdm/pull/96
-
-
-Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
------------------------------------------------
-
-This Source Code Form is subject to the terms of the
-Mozilla Public License, v. 2.0.
-If a copy of the MPL was not distributed with this file,
-You can obtain one at https://mozilla.org/MPL/2.0/.
-
-
-MIT License (MIT)
------------------
-
-Copyright (c) 2013 noamraph
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
-the Software, and to permit persons to whom the Software is furnished to do so,
-subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
-FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
-COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
-IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/comfy_extras/chainner_models/architecture/face/LICENSE-codeformer b/comfy_extras/chainner_models/architecture/face/LICENSE-codeformer
deleted file mode 100644
index be6c4ed8048..00000000000
--- a/comfy_extras/chainner_models/architecture/face/LICENSE-codeformer
+++ /dev/null
@@ -1,35 +0,0 @@
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
diff --git a/comfy_extras/chainner_models/architecture/face/arcface_arch.py b/comfy_extras/chainner_models/architecture/face/arcface_arch.py
deleted file mode 100644
index b548af059a7..00000000000
--- a/comfy_extras/chainner_models/architecture/face/arcface_arch.py
+++ /dev/null
@@ -1,265 +0,0 @@
-import torch.nn as nn
-
-
-def conv3x3(inplanes, outplanes, stride=1):
- """A simple wrapper for 3x3 convolution with padding.
-
- Args:
- inplanes (int): Channel number of inputs.
- outplanes (int): Channel number of outputs.
- stride (int): Stride in convolution. Default: 1.
- """
- return nn.Conv2d(
- inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
- )
-
-
-class BasicBlock(nn.Module):
- """Basic residual block used in the ResNetArcFace architecture.
-
- Args:
- inplanes (int): Channel number of inputs.
- planes (int): Channel number of outputs.
- stride (int): Stride in convolution. Default: 1.
- downsample (nn.Module): The downsample module. Default: None.
- """
-
- expansion = 1 # output channel expansion ratio
-
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(BasicBlock, self).__init__()
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = nn.BatchNorm2d(planes)
- self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = nn.BatchNorm2d(planes)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class IRBlock(nn.Module):
- """Improved residual block (IR Block) used in the ResNetArcFace architecture.
-
- Args:
- inplanes (int): Channel number of inputs.
- planes (int): Channel number of outputs.
- stride (int): Stride in convolution. Default: 1.
- downsample (nn.Module): The downsample module. Default: None.
-        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
- """
-
- expansion = 1 # output channel expansion ratio
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
- super(IRBlock, self).__init__()
- self.bn0 = nn.BatchNorm2d(inplanes)
- self.conv1 = conv3x3(inplanes, inplanes)
- self.bn1 = nn.BatchNorm2d(inplanes)
- self.prelu = nn.PReLU()
- self.conv2 = conv3x3(inplanes, planes, stride)
- self.bn2 = nn.BatchNorm2d(planes)
- self.downsample = downsample
- self.stride = stride
- self.use_se = use_se
- if self.use_se:
- self.se = SEBlock(planes)
-
- def forward(self, x):
- residual = x
- out = self.bn0(x)
- out = self.conv1(out)
- out = self.bn1(out)
- out = self.prelu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- if self.use_se:
- out = self.se(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.prelu(out)
-
- return out
-
-
-class Bottleneck(nn.Module):
- """Bottleneck block used in the ResNetArcFace architecture.
-
- Args:
- inplanes (int): Channel number of inputs.
- planes (int): Channel number of outputs.
- stride (int): Stride in convolution. Default: 1.
- downsample (nn.Module): The downsample module. Default: None.
- """
-
- expansion = 4 # output channel expansion ratio
-
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(Bottleneck, self).__init__()
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
- self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = nn.Conv2d(
- planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
- )
- self.bn2 = nn.BatchNorm2d(planes)
- self.conv3 = nn.Conv2d(
- planes, planes * self.expansion, kernel_size=1, bias=False
- )
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class SEBlock(nn.Module):
- """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
-
- Args:
- channel (int): Channel number of inputs.
-        reduction (int): Channel reduction ratio. Default: 16.
- """
-
- def __init__(self, channel, reduction=16):
- super(SEBlock, self).__init__()
- self.avg_pool = nn.AdaptiveAvgPool2d(
- 1
- ) # pool to 1x1 without spatial information
- self.fc = nn.Sequential(
- nn.Linear(channel, channel // reduction),
- nn.PReLU(),
- nn.Linear(channel // reduction, channel),
- nn.Sigmoid(),
- )
-
- def forward(self, x):
- b, c, _, _ = x.size()
- y = self.avg_pool(x).view(b, c)
- y = self.fc(y).view(b, c, 1, 1)
- return x * y
-
-
-class ResNetArcFace(nn.Module):
- """ArcFace with ResNet architectures.
-
- Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
-
- Args:
- block (str): Block used in the ArcFace architecture.
- layers (tuple(int)): Block numbers in each layer.
-        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
- """
-
- def __init__(self, block, layers, use_se=True):
- if block == "IRBlock":
- block = IRBlock
- self.inplanes = 64
- self.use_se = use_se
- super(ResNetArcFace, self).__init__()
-
- self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(64)
- self.prelu = nn.PReLU()
- self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
- self.layer1 = self._make_layer(block, 64, layers[0])
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
- self.bn4 = nn.BatchNorm2d(512)
- self.dropout = nn.Dropout()
- self.fc5 = nn.Linear(512 * 8 * 8, 512)
- self.bn5 = nn.BatchNorm1d(512)
-
- # initialization
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.Linear):
- nn.init.xavier_normal_(m.weight)
- nn.init.constant_(m.bias, 0)
-
- def _make_layer(self, block, planes, num_blocks, stride=1):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- nn.Conv2d(
- self.inplanes,
- planes * block.expansion,
- kernel_size=1,
- stride=stride,
- bias=False,
- ),
- nn.BatchNorm2d(planes * block.expansion),
- )
- layers = []
- layers.append(
- block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
- )
- self.inplanes = planes
- for _ in range(1, num_blocks):
- layers.append(block(self.inplanes, planes, use_se=self.use_se))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.prelu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
- x = self.bn4(x)
- x = self.dropout(x)
- x = x.view(x.size(0), -1)
- x = self.fc5(x)
- x = self.bn5(x)
-
- return x
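For reference, the channel-attention idea implemented by SEBlock above (global average pooling, a bottleneck MLP, a sigmoid gate, then channel-wise rescaling) can be sketched on its own. The snippet below is a minimal standalone illustration; the name TinySEBlock is hypothetical and not part of the removed file.

import torch
import torch.nn as nn

class TinySEBlock(nn.Module):
    # Squeeze-and-excitation gate: NxCxHxW -> per-channel weights in (0, 1).
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.PReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        gate = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * gate  # rescale each channel by its learned importance

feat = torch.randn(2, 64, 16, 16)
print(TinySEBlock(64)(feat).shape)  # torch.Size([2, 64, 16, 16])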
diff --git a/comfy_extras/chainner_models/architecture/face/codeformer.py b/comfy_extras/chainner_models/architecture/face/codeformer.py
deleted file mode 100644
index 06614007864..00000000000
--- a/comfy_extras/chainner_models/architecture/face/codeformer.py
+++ /dev/null
@@ -1,790 +0,0 @@
-"""
-Modified from https://github.com/sczhou/CodeFormer
-VQGAN code, adapted from the original created by the Unleashing Transformers authors:
-https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
-This version of the arch was specifically gathered from an old version of GFPGAN. If this is a problem, please contact me.
-"""
-import math
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import logging as logger
-from torch import Tensor
-
-
-class VectorQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
- self.embedding.weight.data.uniform_(
- -1.0 / self.codebook_size, 1.0 / self.codebook_size
- )
-
- def forward(self, z):
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.emb_dim)
-
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
- d = (
- (z_flattened**2).sum(dim=1, keepdim=True)
- + (self.embedding.weight**2).sum(1)
- - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
- )
-
- mean_distance = torch.mean(d)
- # find closest encodings
- # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
- min_encoding_scores, min_encoding_indices = torch.topk(
- d, 1, dim=1, largest=False
- )
- # [0-1], higher score, higher confidence
- min_encoding_scores = torch.exp(-min_encoding_scores / 10)
-
- min_encodings = torch.zeros(
- min_encoding_indices.shape[0], self.codebook_size
- ).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # compute loss for embedding
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
- (z_q - z.detach()) ** 2
- )
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return (
- z_q,
- loss,
- {
- "perplexity": perplexity,
- "min_encodings": min_encodings,
- "min_encoding_indices": min_encoding_indices,
- "min_encoding_scores": min_encoding_scores,
- "mean_distance": mean_distance,
- },
- )
-
- def get_codebook_feat(self, indices, shape):
- # input indices: batch*token_num -> (batch*token_num)*1
- # shape: batch, height, width, channel
- indices = indices.view(-1, 1)
- min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
- min_encodings.scatter_(1, indices, 1)
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None: # reshape back to match original input shape
- z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
-
-class GumbelQuantizer(nn.Module):
- def __init__(
- self,
- codebook_size,
- emb_dim,
- num_hiddens,
- straight_through=False,
- kl_weight=5e-4,
- temp_init=1.0,
- ):
- super().__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.straight_through = straight_through
- self.temperature = temp_init
- self.kl_weight = kl_weight
- self.proj = nn.Conv2d(
- num_hiddens, codebook_size, 1
- ) # projects last encoder layer to quantized logits
- self.embed = nn.Embedding(codebook_size, emb_dim)
-
- def forward(self, z):
- hard = self.straight_through if self.training else True
-
- logits = self.proj(z)
-
- soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
-
- z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
-
- # + kl divergence to the prior loss
- qy = F.softmax(logits, dim=1)
- diff = (
- self.kl_weight
- * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
- )
- min_encoding_indices = soft_one_hot.argmax(dim=1)
-
- return z_q, diff, {"min_encoding_indices": min_encoding_indices}
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
- )
-
- def forward(self, x):
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
- x = self.conv(x)
-
- return x
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h * w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h * w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c) ** (-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h * w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class Encoder(nn.Module):
- def __init__(
- self,
- in_channels,
- nf,
- out_channels,
- ch_mult,
- num_res_blocks,
- resolution,
- attn_resolutions,
- ):
- super().__init__()
- self.nf = nf
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.attn_resolutions = attn_resolutions
-
- curr_res = self.resolution
- in_ch_mult = (1,) + tuple(ch_mult)
-
- blocks = []
-        # initial convolution
- blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
-
- # residual and downsampling blocks, with attention on smaller res (16x16)
- for i in range(self.num_resolutions):
- block_in_ch = nf * in_ch_mult[i]
- block_out_ch = nf * ch_mult[i]
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
- if curr_res in attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != self.num_resolutions - 1:
- blocks.append(Downsample(block_in_ch))
- curr_res = curr_res // 2
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
- blocks.append(AttnBlock(block_in_ch)) # type: ignore
- blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
-
- # normalise and convert to latent size
- blocks.append(normalize(block_in_ch)) # type: ignore
- blocks.append(
- nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1) # type: ignore
- )
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-class Generator(nn.Module):
- def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
- super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
- self.num_resolutions = len(self.ch_mult)
- self.num_res_blocks = res_blocks
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.in_channels = emb_dim
- self.out_channels = 3
- block_in_ch = self.nf * self.ch_mult[-1]
- curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
-
- blocks = []
- # initial conv
- blocks.append(
- nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
- )
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- for i in reversed(range(self.num_resolutions)):
- block_out_ch = self.nf * self.ch_mult[i]
-
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
-
- if curr_res in self.attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != 0:
- blocks.append(Upsample(block_in_ch))
- curr_res = curr_res * 2
-
- blocks.append(normalize(block_in_ch))
- blocks.append(
- nn.Conv2d(
- block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
- )
- )
-
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-class VQAutoEncoder(nn.Module):
- def __init__(
- self,
- img_size,
- nf,
- ch_mult,
- quantizer="nearest",
- res_blocks=2,
- attn_resolutions=[16],
- codebook_size=1024,
- emb_dim=256,
- beta=0.25,
- gumbel_straight_through=False,
- gumbel_kl_weight=1e-8,
- model_path=None,
- ):
- super().__init__()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
- self.codebook_size = codebook_size
- self.embed_dim = emb_dim
- self.ch_mult = ch_mult
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.quantizer_type = quantizer
- self.encoder = Encoder(
- self.in_channels,
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions,
- )
- if self.quantizer_type == "nearest":
- self.beta = beta # 0.25
- self.quantize = VectorQuantizer(
- self.codebook_size, self.embed_dim, self.beta
- )
- elif self.quantizer_type == "gumbel":
- self.gumbel_num_hiddens = emb_dim
- self.straight_through = gumbel_straight_through
- self.kl_weight = gumbel_kl_weight
- self.quantize = GumbelQuantizer(
- self.codebook_size,
- self.embed_dim,
- self.gumbel_num_hiddens,
- self.straight_through,
- self.kl_weight,
- )
- self.generator = Generator(
- nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim
- )
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location="cpu")
- if "params_ema" in chkpt:
- self.load_state_dict(
- torch.load(model_path, map_location="cpu")["params_ema"]
- )
- logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
- elif "params" in chkpt:
- self.load_state_dict(
- torch.load(model_path, map_location="cpu")["params"]
- )
- logger.info(f"vqgan is loaded from: {model_path} [params]")
- else:
- raise ValueError("Wrong params!")
-
- def forward(self, x):
- x = self.encoder(x)
- quant, codebook_loss, quant_stats = self.quantize(x)
- x = self.generator(quant)
- return x, codebook_loss, quant_stats
-
-
-def calc_mean_std(feat, eps=1e-5):
- """Calculate mean and std for adaptive_instance_normalization.
- Args:
- feat (Tensor): 4D tensor.
- eps (float): A small value added to the variance to avoid
- divide-by-zero. Default: 1e-5.
- """
- size = feat.size()
- assert len(size) == 4, "The input feature should be 4D tensor."
- b, c = size[:2]
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
- return feat_mean, feat_std
-
-
-def adaptive_instance_normalization(content_feat, style_feat):
- """Adaptive instance normalization.
-    Adjust the reference features to have color and illumination similar to
-    those of the degraded features.
- Args:
- content_feat (Tensor): The reference feature.
-        style_feat (Tensor): The degraded features.
- """
- size = content_feat.size()
- style_mean, style_std = calc_mean_std(style_feat)
- content_mean, content_std = calc_mean_std(content_feat)
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
- size
- )
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used by the Attention is all you need paper, generalized to work on images.
- """
-
- def __init__(
- self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
- ):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, x, mask=None):
- if mask is None:
- mask = torch.zeros(
- (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
- )
- not_mask = ~mask # pylint: disable=invalid-unary-operand-type
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack(
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos_y = torch.stack(
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
-    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
-
-
-class TransformerSALayer(nn.Module):
- def __init__(
- self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
- ):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
- # Implementation of Feedforward model - MLP
- self.linear1 = nn.Linear(embed_dim, dim_mlp)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_mlp, embed_dim)
-
- self.norm1 = nn.LayerNorm(embed_dim)
- self.norm2 = nn.LayerNorm(embed_dim)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward(
- self,
- tgt,
- tgt_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None,
- ):
- # self attention
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(
- q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
- )[0]
- tgt = tgt + self.dropout1(tgt2)
-
- # ffn
- tgt2 = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout2(tgt2)
- return tgt
-
-
-def normalize(in_channels):
- return torch.nn.GroupNorm(
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-@torch.jit.script # type: ignore
-def swish(x):
- return x * torch.sigmoid(x)
-
-
-class ResBlock(nn.Module):
- def __init__(self, in_channels, out_channels=None):
- super(ResBlock, self).__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.norm1 = normalize(in_channels)
- self.conv1 = nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
- )
- self.norm2 = normalize(out_channels)
- self.conv2 = nn.Conv2d(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
- )
- if self.in_channels != self.out_channels:
- self.conv_out = nn.Conv2d(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0 # type: ignore
- )
-
- def forward(self, x_in):
- x = x_in
- x = self.norm1(x)
- x = swish(x)
- x = self.conv1(x)
- x = self.norm2(x)
- x = swish(x)
- x = self.conv2(x)
- if self.in_channels != self.out_channels:
- x_in = self.conv_out(x_in)
-
- return x + x_in
-
-
-class Fuse_sft_block(nn.Module):
- def __init__(self, in_ch, out_ch):
- super().__init__()
- self.encode_enc = ResBlock(2 * in_ch, out_ch)
-
- self.scale = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
- )
-
- self.shift = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
- )
-
- def forward(self, enc_feat, dec_feat, w=1):
- enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
- scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
- return out
-
-
-class CodeFormer(VQAutoEncoder):
- def __init__(self, state_dict):
- dim_embd = 512
- n_head = 8
- n_layers = 9
- codebook_size = 1024
- latent_size = 256
- connect_list = ["32", "64", "128", "256"]
- fix_modules = ["quantize", "generator"]
-
- # This is just a guess as I only have one model to look at
- position_emb = state_dict["position_emb"]
- dim_embd = position_emb.shape[1]
- latent_size = position_emb.shape[0]
-
- try:
- n_layers = len(
- set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
- )
- except:
- pass
-
- codebook_size = state_dict["quantize.embedding.weight"].shape[0]
-
- # This is also just another guess
- n_head_exp = (
- state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
- )
- n_head = 2**n_head_exp
-
- in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
-
- self.model_arch = "CodeFormer"
- self.sub_type = "Face SR"
- self.scale = 8
- self.in_nc = in_nc
- self.out_nc = in_nc
-
- self.state = state_dict
-
- self.supports_fp16 = False
- self.supports_bf16 = True
- self.min_size_restriction = 16
-
- super(CodeFormer, self).__init__(
- 512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
- )
-
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
-
- self.connect_list = connect_list
- self.n_layers = n_layers
- self.dim_embd = dim_embd
- self.dim_mlp = dim_embd * 2
-
- self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) # type: ignore
- self.feat_emb = nn.Linear(256, self.dim_embd)
-
- # transformer
- self.ft_layers = nn.Sequential(
- *[
- TransformerSALayer(
- embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
- )
- for _ in range(self.n_layers)
- ]
- )
-
- # logits_predict head
- self.idx_pred_layer = nn.Sequential(
- nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
- )
-
- self.channels = {
- "16": 512,
- "32": 256,
- "64": 256,
- "128": 128,
- "256": 128,
- "512": 64,
- }
-
- # after second residual block for > 16, before attn layer for ==16
- self.fuse_encoder_block = {
- "512": 2,
- "256": 5,
- "128": 8,
- "64": 11,
- "32": 14,
- "16": 18,
- }
- # after first residual block for > 16, before attn layer for ==16
- self.fuse_generator_block = {
- "16": 6,
- "32": 9,
- "64": 12,
- "128": 15,
- "256": 18,
- "512": 21,
- }
-
- # fuse_convs_dict
- self.fuse_convs_dict = nn.ModuleDict()
- for f_size in self.connect_list:
- in_ch = self.channels[f_size]
- self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
-
- self.load_state_dict(state_dict)
-
- def _init_weights(self, module):
- if isinstance(module, (nn.Linear, nn.Embedding)):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
-
- def forward(self, x, weight=0.5, **kwargs):
- detach_16 = True
- code_only = False
- adain = True
- # ################### Encoder #####################
- enc_feat_dict = {}
- out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
- for i, block in enumerate(self.encoder.blocks):
- x = block(x)
- if i in out_list:
- enc_feat_dict[str(x.shape[-1])] = x.clone()
-
- lq_feat = x
- # ################# Transformer ###################
- # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
- pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
- # BCHW -> BC(HW) -> (HW)BC
- feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
- query_emb = feat_emb
- # Transformer encoder
- for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
-
- # output logits
- logits = self.idx_pred_layer(query_emb) # (hw)bn
- logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
-
- if code_only: # for training stage II
- # logits doesn't need softmax before cross_entropy loss
- return logits, lq_feat
-
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
- soft_one_hot = F.softmax(logits, dim=2)
- _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
- quant_feat = self.quantize.get_codebook_feat(
- top_idx, shape=[x.shape[0], 16, 16, 256] # type: ignore
- )
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
-
- if detach_16:
- quant_feat = quant_feat.detach() # for training stage III
- if adain:
- quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
-
- # ################## Generator ####################
- x = quant_feat
- fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-
- for i, block in enumerate(self.generator.blocks):
- x = block(x)
- if i in fuse_list: # fuse after i-th block
- f_size = str(x.shape[-1])
- if weight > 0:
- x = self.fuse_convs_dict[f_size](
- enc_feat_dict[f_size].detach(), x, weight
- )
- out = x
- # logits doesn't need softmax before cross_entropy loss
- # return out, logits, lq_feat
- return out, logits
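The core of the VectorQuantizer above is a nearest-neighbour codebook lookup: squared distances from each latent to every codebook entry, a hard assignment, and the straight-through trick so gradients flow back to the encoder. A minimal standalone sketch is shown below; the helper name quantize is hypothetical, while K=1024 and D=256 follow the CodeFormer defaults visible above.

import torch

def quantize(z_flat, codebook):
    # z_flat: (N, D) latents, codebook: (K, D) embeddings
    d = (
        (z_flat ** 2).sum(dim=1, keepdim=True)
        + (codebook ** 2).sum(dim=1)
        - 2 * z_flat @ codebook.t()
    )  # (N, K) squared distances
    idx = d.argmin(dim=1)                      # index of the closest code
    z_q = codebook[idx]                        # quantized latents
    return z_flat + (z_q - z_flat).detach()    # straight-through estimator

codebook = torch.randn(1024, 256)
print(quantize(torch.randn(8, 256), codebook).shape)  # torch.Size([8, 256])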
diff --git a/comfy_extras/chainner_models/architecture/face/fused_act.py b/comfy_extras/chainner_models/architecture/face/fused_act.py
deleted file mode 100644
index 7ed526547b4..00000000000
--- a/comfy_extras/chainner_models/architecture/face/fused_act.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# pylint: skip-file
-# type: ignore
-# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
-
-import torch
-from torch import nn
-from torch.autograd import Function
-
-fused_act_ext = None
-
-
-class FusedLeakyReLUFunctionBackward(Function):
- @staticmethod
- def forward(ctx, grad_output, out, negative_slope, scale):
- ctx.save_for_backward(out)
- ctx.negative_slope = negative_slope
- ctx.scale = scale
-
- empty = grad_output.new_empty(0)
-
- grad_input = fused_act_ext.fused_bias_act(
- grad_output, empty, out, 3, 1, negative_slope, scale
- )
-
- dim = [0]
-
- if grad_input.ndim > 2:
- dim += list(range(2, grad_input.ndim))
-
- grad_bias = grad_input.sum(dim).detach()
-
- return grad_input, grad_bias
-
- @staticmethod
- def backward(ctx, gradgrad_input, gradgrad_bias):
- (out,) = ctx.saved_tensors
- gradgrad_out = fused_act_ext.fused_bias_act(
- gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
- )
-
- return gradgrad_out, None, None, None
-
-
-class FusedLeakyReLUFunction(Function):
- @staticmethod
- def forward(ctx, input, bias, negative_slope, scale):
- empty = input.new_empty(0)
- out = fused_act_ext.fused_bias_act(
- input, bias, empty, 3, 0, negative_slope, scale
- )
- ctx.save_for_backward(out)
- ctx.negative_slope = negative_slope
- ctx.scale = scale
-
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- (out,) = ctx.saved_tensors
-
- grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
- grad_output, out, ctx.negative_slope, ctx.scale
- )
-
- return grad_input, grad_bias, None, None
-
-
-class FusedLeakyReLU(nn.Module):
- def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
- super().__init__()
-
- self.bias = nn.Parameter(torch.zeros(channel))
- self.negative_slope = negative_slope
- self.scale = scale
-
- def forward(self, input):
- return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-
-
-def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
- return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
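Note that fused_act_ext above is left as None: it stands in for the compiled StyleGAN2 CUDA op, which this copy never loads. As a rough plain-PyTorch reading of what fused_leaky_relu computes (bias add, leaky ReLU, then scaling by sqrt(2)), a sketch might look like the following; this is an assumption based on the reference StyleGAN2 fallback, not the removed extension itself.

import torch
import torch.nn.functional as F

def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Broadcast the per-channel bias over batch and spatial dims, then gate and rescale.
    bias = bias.view(1, -1, *([1] * (x.ndim - 2)))
    return F.leaky_relu(x + bias, negative_slope) * scale

x = torch.randn(2, 64, 8, 8)
print(fused_leaky_relu_ref(x, torch.zeros(64)).shape)  # torch.Size([2, 64, 8, 8])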
diff --git a/comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py b/comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py
deleted file mode 100644
index b6e820e006f..00000000000
--- a/comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py
+++ /dev/null
@@ -1,389 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-
-from .gfpganv1_arch import ResUpBlock
-from .stylegan2_bilinear_arch import (
- ConvLayer,
- EqualConv2d,
- EqualLinear,
- ResBlock,
- ScaledLeakyReLU,
- StyleGAN2GeneratorBilinear,
-)
-
-
-class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
- """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
-    This is the bilinear version; it does not use the complicated UpFirDnSmooth function, which is not
-    deployment-friendly. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- lr_mlp=0.01,
- narrow=1,
- sft_half=False,
- ):
- super(StyleGAN2GeneratorBilinearSFT, self).__init__(
- out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- lr_mlp=lr_mlp,
- narrow=narrow,
- )
- self.sft_half = sft_half
-
- def forward(
- self,
- styles,
- conditions,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2GeneratorBilinearSFT.
- Args:
- styles (list[Tensor]): Sample codes of styles.
- conditions (list[Tensor]): SFT conditions to generators.
- input_is_latent (bool): Whether input is latent style. Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
- inject_index (int | None): The injection index for mixing noise. Default: None.
- return_latents (bool): Whether to return style latents. Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latents with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
-
- # the conditions may have fewer levels
- if i < len(conditions):
- # SFT part to combine the conditions
- if self.sft_half: # only apply SFT to half of the channels
- out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
- out_sft = out_sft * conditions[i - 1] + conditions[i]
- out = torch.cat([out_same, out_sft], dim=1)
- else: # apply SFT to all the channels
- out = out * conditions[i - 1] + conditions[i]
-
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
-
-
-class GFPGANBilinear(nn.Module):
- """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
-    This is the bilinear version; it does not use the complicated UpFirDnSmooth function, which is not
-    deployment-friendly. It can be easily converted to the clean version: GFPGANv1Clean.
- Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
- fix_decoder (bool): Whether to fix the decoder. Default: True.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- input_is_latent (bool): Whether input is latent style. Default: False.
- different_w (bool): Whether to use different latent w for different layers. Default: False.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- channel_multiplier=1,
- decoder_load_path=None,
- fix_decoder=True,
- # for stylegan decoder
- num_mlp=8,
- lr_mlp=0.01,
- input_is_latent=False,
- different_w=False,
- narrow=1,
- sft_half=False,
- ):
- super(GFPGANBilinear, self).__init__()
- self.input_is_latent = input_is_latent
- self.different_w = different_w
- self.num_style_feat = num_style_feat
- self.min_size_restriction = 512
-
-        unet_narrow = narrow * 0.5  # by default, use half of the input channels
- channels = {
- "4": int(512 * unet_narrow),
- "8": int(512 * unet_narrow),
- "16": int(512 * unet_narrow),
- "32": int(512 * unet_narrow),
- "64": int(256 * channel_multiplier * unet_narrow),
- "128": int(128 * channel_multiplier * unet_narrow),
- "256": int(64 * channel_multiplier * unet_narrow),
- "512": int(32 * channel_multiplier * unet_narrow),
- "1024": int(16 * channel_multiplier * unet_narrow),
- }
-
- self.log_size = int(math.log(out_size, 2))
- first_out_size = 2 ** (int(math.log(out_size, 2)))
-
- self.conv_body_first = ConvLayer(
- 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
- )
-
- # downsample
- in_channels = channels[f"{first_out_size}"]
- self.conv_body_down = nn.ModuleList()
- for i in range(self.log_size, 2, -1):
- out_channels = channels[f"{2**(i - 1)}"]
- self.conv_body_down.append(ResBlock(in_channels, out_channels))
- in_channels = out_channels
-
- self.final_conv = ConvLayer(
- in_channels, channels["4"], 3, bias=True, activate=True
- )
-
- # upsample
- in_channels = channels["4"]
- self.conv_body_up = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
- in_channels = out_channels
-
- # to RGB
- self.toRGB = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- self.toRGB.append(
- EqualConv2d(
- channels[f"{2**i}"],
- 3,
- 1,
- stride=1,
- padding=0,
- bias=True,
- bias_init_val=0,
- )
- )
-
- if different_w:
- linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
- else:
- linear_out_channel = num_style_feat
-
- self.final_linear = EqualLinear(
- channels["4"] * 4 * 4,
- linear_out_channel,
- bias=True,
- bias_init_val=0,
- lr_mul=1,
- activation=None,
- )
-
- # the decoder: stylegan2 generator with SFT modulations
- self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
- out_size=out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- lr_mlp=lr_mlp,
- narrow=narrow,
- sft_half=sft_half,
- )
-
- # load pre-trained stylegan2 model if necessary
- if decoder_load_path:
- self.stylegan_decoder.load_state_dict(
- torch.load(
- decoder_load_path, map_location=lambda storage, loc: storage
- )["params_ema"]
- )
- # fix decoder without updating params
- if fix_decoder:
- for _, param in self.stylegan_decoder.named_parameters():
- param.requires_grad = False
-
- # for SFT modulations (scale and shift)
- self.condition_scale = nn.ModuleList()
- self.condition_shift = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- if sft_half:
- sft_out_channels = out_channels
- else:
- sft_out_channels = out_channels * 2
- self.condition_scale.append(
- nn.Sequential(
- EqualConv2d(
- out_channels,
- out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- ScaledLeakyReLU(0.2),
- EqualConv2d(
- out_channels,
- sft_out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=1,
- ),
- )
- )
- self.condition_shift.append(
- nn.Sequential(
- EqualConv2d(
- out_channels,
- out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- ScaledLeakyReLU(0.2),
- EqualConv2d(
- out_channels,
- sft_out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- )
- )
-
- def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
- """Forward function for GFPGANBilinear.
- Args:
- x (Tensor): Input images.
- return_latents (bool): Whether to return style latents. Default: False.
-            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
-            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- """
- conditions = []
- unet_skips = []
- out_rgbs = []
-
- # encoder
- feat = self.conv_body_first(x)
- for i in range(self.log_size - 2):
- feat = self.conv_body_down[i](feat)
- unet_skips.insert(0, feat)
-
- feat = self.final_conv(feat)
-
- # style code
- style_code = self.final_linear(feat.view(feat.size(0), -1))
- if self.different_w:
- style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
-
- # decode
- for i in range(self.log_size - 2):
- # add unet skip
- feat = feat + unet_skips[i]
- # ResUpLayer
- feat = self.conv_body_up[i](feat)
- # generate scale and shift for SFT layers
- scale = self.condition_scale[i](feat)
- conditions.append(scale.clone())
- shift = self.condition_shift[i](feat)
- conditions.append(shift.clone())
- # generate rgb images
- if return_rgb:
- out_rgbs.append(self.toRGB[i](feat))
-
- # decoder
- image, _ = self.stylegan_decoder(
- [style_code],
- conditions,
- return_latents=return_latents,
- input_is_latent=self.input_is_latent,
- randomize_noise=randomize_noise,
- )
-
- return image, out_rgbs
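The SFT modulation that the decoders above rely on is a small tensor operation: a per-pixel scale and shift predicted from encoder features, applied either to all channels or, when sft_half is set, to only half of them. apply_sft below is a hypothetical standalone sketch of that step, not code from the removed file.

import torch

def apply_sft(feat, scale, shift, sft_half=False):
    if sft_half:
        # Modulate only the second half of the channels; keep the first half untouched.
        same, mod = torch.split(feat, feat.size(1) // 2, dim=1)
        return torch.cat([same, mod * scale + shift], dim=1)
    return feat * scale + shift

feat = torch.randn(1, 8, 4, 4)
scale = torch.ones(1, 4, 4, 4)
shift = torch.zeros(1, 4, 4, 4)
print(apply_sft(feat, scale, shift, sft_half=True).shape)  # torch.Size([1, 8, 4, 4])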
diff --git a/comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py b/comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py
deleted file mode 100644
index 72d72fc865e..00000000000
--- a/comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py
+++ /dev/null
@@ -1,566 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .fused_act import FusedLeakyReLU
-from .stylegan2_arch import (
- ConvLayer,
- EqualConv2d,
- EqualLinear,
- ResBlock,
- ScaledLeakyReLU,
- StyleGAN2Generator,
-)
-
-
-class StyleGAN2GeneratorSFT(StyleGAN2Generator):
- """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
-        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
-            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- resample_kernel=(1, 3, 3, 1),
- lr_mlp=0.01,
- narrow=1,
- sft_half=False,
- ):
- super(StyleGAN2GeneratorSFT, self).__init__(
- out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- resample_kernel=resample_kernel,
- lr_mlp=lr_mlp,
- narrow=narrow,
- )
- self.sft_half = sft_half
-
- def forward(
- self,
- styles,
- conditions,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2GeneratorSFT.
- Args:
- styles (list[Tensor]): Sample codes of styles.
- conditions (list[Tensor]): SFT conditions to generators.
- input_is_latent (bool): Whether input is latent style. Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
- inject_index (int | None): The injection index for mixing noise. Default: None.
- return_latents (bool): Whether to return style latents. Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latents with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
-
- # the conditions may have fewer levels
- if i < len(conditions):
- # SFT part to combine the conditions
- if self.sft_half: # only apply SFT to half of the channels
- out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
- out_sft = out_sft * conditions[i - 1] + conditions[i]
- out = torch.cat([out_same, out_sft], dim=1)
- else: # apply SFT to all the channels
- out = out * conditions[i - 1] + conditions[i]
-
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
-
-
-class ConvUpLayer(nn.Module):
- """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- stride (int): Stride of the convolution. Default: 1
- padding (int): Zero-padding added to both sides of the input. Default: 0.
- bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
- bias_init_val (float): Bias initialized value. Default: 0.
-        activate (bool): Whether to use activation. Default: True.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- bias=True,
- bias_init_val=0,
- activate=True,
- ):
- super(ConvUpLayer, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.stride = stride
- self.padding = padding
- # self.scale is used to scale the convolution weights, which is related to the common initializations.
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
-
- self.weight = nn.Parameter(
- torch.randn(out_channels, in_channels, kernel_size, kernel_size)
- )
-
- if bias and not activate:
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
- else:
- self.register_parameter("bias", None)
-
- # activation
- if activate:
- if bias:
- self.activation = FusedLeakyReLU(out_channels)
- else:
- self.activation = ScaledLeakyReLU(0.2)
- else:
- self.activation = None
-
- def forward(self, x):
- # bilinear upsample
- out = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
- # conv
- out = F.conv2d(
- out,
- self.weight * self.scale,
- bias=self.bias,
- stride=self.stride,
- padding=self.padding,
- )
- # activation
- if self.activation is not None:
- out = self.activation(out)
- return out
-
-
-class ResUpBlock(nn.Module):
- """Residual block with upsampling.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- """
-
- def __init__(self, in_channels, out_channels):
- super(ResUpBlock, self).__init__()
-
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
- self.conv2 = ConvUpLayer(
- in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True
- )
- self.skip = ConvUpLayer(
- in_channels, out_channels, 1, bias=False, activate=False
- )
-
- def forward(self, x):
- out = self.conv1(x)
- out = self.conv2(out)
- skip = self.skip(x)
- out = (out + skip) / math.sqrt(2)
- return out
-
-
-class GFPGANv1(nn.Module):
- """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
- Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
-        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
-            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
- decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
- fix_decoder (bool): Whether to fix the decoder. Default: True.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- input_is_latent (bool): Whether input is latent style. Default: False.
- different_w (bool): Whether to use different latent w for different layers. Default: False.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- channel_multiplier=1,
- resample_kernel=(1, 3, 3, 1),
- decoder_load_path=None,
- fix_decoder=True,
- # for stylegan decoder
- num_mlp=8,
- lr_mlp=0.01,
- input_is_latent=False,
- different_w=False,
- narrow=1,
- sft_half=False,
- ):
- super(GFPGANv1, self).__init__()
- self.input_is_latent = input_is_latent
- self.different_w = different_w
- self.num_style_feat = num_style_feat
-
-        unet_narrow = narrow * 0.5  # by default, use half of the input channels
- channels = {
- "4": int(512 * unet_narrow),
- "8": int(512 * unet_narrow),
- "16": int(512 * unet_narrow),
- "32": int(512 * unet_narrow),
- "64": int(256 * channel_multiplier * unet_narrow),
- "128": int(128 * channel_multiplier * unet_narrow),
- "256": int(64 * channel_multiplier * unet_narrow),
- "512": int(32 * channel_multiplier * unet_narrow),
- "1024": int(16 * channel_multiplier * unet_narrow),
- }
-
- self.log_size = int(math.log(out_size, 2))
- first_out_size = 2 ** (int(math.log(out_size, 2)))
-
- self.conv_body_first = ConvLayer(
- 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
- )
-
- # downsample
- in_channels = channels[f"{first_out_size}"]
- self.conv_body_down = nn.ModuleList()
- for i in range(self.log_size, 2, -1):
- out_channels = channels[f"{2**(i - 1)}"]
- self.conv_body_down.append(
- ResBlock(in_channels, out_channels, resample_kernel)
- )
- in_channels = out_channels
-
- self.final_conv = ConvLayer(
- in_channels, channels["4"], 3, bias=True, activate=True
- )
-
- # upsample
- in_channels = channels["4"]
- self.conv_body_up = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
- in_channels = out_channels
-
- # to RGB
- self.toRGB = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- self.toRGB.append(
- EqualConv2d(
- channels[f"{2**i}"],
- 3,
- 1,
- stride=1,
- padding=0,
- bias=True,
- bias_init_val=0,
- )
- )
-
- if different_w:
- linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
- else:
- linear_out_channel = num_style_feat
-
- self.final_linear = EqualLinear(
- channels["4"] * 4 * 4,
- linear_out_channel,
- bias=True,
- bias_init_val=0,
- lr_mul=1,
- activation=None,
- )
-
- # the decoder: stylegan2 generator with SFT modulations
- self.stylegan_decoder = StyleGAN2GeneratorSFT(
- out_size=out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- resample_kernel=resample_kernel,
- lr_mlp=lr_mlp,
- narrow=narrow,
- sft_half=sft_half,
- )
-
- # load pre-trained stylegan2 model if necessary
- if decoder_load_path:
- self.stylegan_decoder.load_state_dict(
- torch.load(
- decoder_load_path, map_location=lambda storage, loc: storage
- )["params_ema"]
- )
- # fix decoder without updating params
- if fix_decoder:
- for _, param in self.stylegan_decoder.named_parameters():
- param.requires_grad = False
-
- # for SFT modulations (scale and shift)
- self.condition_scale = nn.ModuleList()
- self.condition_shift = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- if sft_half:
- sft_out_channels = out_channels
- else:
- sft_out_channels = out_channels * 2
- self.condition_scale.append(
- nn.Sequential(
- EqualConv2d(
- out_channels,
- out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- ScaledLeakyReLU(0.2),
- EqualConv2d(
- out_channels,
- sft_out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=1,
- ),
- )
- )
- self.condition_shift.append(
- nn.Sequential(
- EqualConv2d(
- out_channels,
- out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- ScaledLeakyReLU(0.2),
- EqualConv2d(
- out_channels,
- sft_out_channels,
- 3,
- stride=1,
- padding=1,
- bias=True,
- bias_init_val=0,
- ),
- )
- )
-
- def forward(
- self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
- ):
- """Forward function for GFPGANv1.
- Args:
- x (Tensor): Input images.
- return_latents (bool): Whether to return style latents. Default: False.
-            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
-            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- """
- conditions = []
- unet_skips = []
- out_rgbs = []
-
- # encoder
- feat = self.conv_body_first(x)
- for i in range(self.log_size - 2):
- feat = self.conv_body_down[i](feat)
- unet_skips.insert(0, feat)
-
- feat = self.final_conv(feat)
-
- # style code
- style_code = self.final_linear(feat.view(feat.size(0), -1))
- if self.different_w:
- style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
-
- # decode
- for i in range(self.log_size - 2):
- # add unet skip
- feat = feat + unet_skips[i]
- # ResUpLayer
- feat = self.conv_body_up[i](feat)
- # generate scale and shift for SFT layers
- scale = self.condition_scale[i](feat)
- conditions.append(scale.clone())
- shift = self.condition_shift[i](feat)
- conditions.append(shift.clone())
- # generate rgb images
- if return_rgb:
- out_rgbs.append(self.toRGB[i](feat))
-
- # decoder
- image, _ = self.stylegan_decoder(
- [style_code],
- conditions,
- return_latents=return_latents,
- input_is_latent=self.input_is_latent,
- randomize_noise=randomize_noise,
- )
-
- return image, out_rgbs
-
-
-class FacialComponentDiscriminator(nn.Module):
- """Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""
-
- def __init__(self):
- super(FacialComponentDiscriminator, self).__init__()
- # It now uses a VGG-style architecture with a fixed model size
- self.conv1 = ConvLayer(
- 3,
- 64,
- 3,
- downsample=False,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- )
- self.conv2 = ConvLayer(
- 64,
- 128,
- 3,
- downsample=True,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- )
- self.conv3 = ConvLayer(
- 128,
- 128,
- 3,
- downsample=False,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- )
- self.conv4 = ConvLayer(
- 128,
- 256,
- 3,
- downsample=True,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- )
- self.conv5 = ConvLayer(
- 256,
- 256,
- 3,
- downsample=False,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- )
- self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
-
- def forward(self, x, return_feats=False, **kwargs):
- """Forward function for FacialComponentDiscriminator.
- Args:
- x (Tensor): Input images.
- return_feats (bool): Whether to return intermediate features. Default: False.
- """
- feat = self.conv1(x)
- feat = self.conv3(self.conv2(feat))
- rlt_feats = []
- if return_feats:
- rlt_feats.append(feat.clone())
- feat = self.conv5(self.conv4(feat))
- if return_feats:
- rlt_feats.append(feat.clone())
- out = self.final_conv(feat)
-
- if return_feats:
- return out, rlt_feats
- else:
- return out, None
diff --git a/comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py b/comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py
deleted file mode 100644
index 16470d6345f..00000000000
--- a/comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py
+++ /dev/null
@@ -1,370 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .stylegan2_clean_arch import StyleGAN2GeneratorClean
-
-
-class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
- """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
- It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- narrow=1,
- sft_half=False,
- ):
- super(StyleGAN2GeneratorCSFT, self).__init__(
- out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- narrow=narrow,
- )
- self.sft_half = sft_half
-
- def forward(
- self,
- styles,
- conditions,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2GeneratorCSFT.
- Args:
- styles (list[Tensor]): Sample codes of styles.
- conditions (list[Tensor]): SFT conditions to generators.
- input_is_latent (bool): Whether input is latent style. Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
- randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
- inject_index (int | None): The injection index for mixing noise. Default: None.
- return_latents (bool): Whether to return style latents. Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latents with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
-
- # the conditions may have fewer levels
- if i < len(conditions):
- # SFT part to combine the conditions
- if self.sft_half: # only apply SFT to half of the channels
- out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
- out_sft = out_sft * conditions[i - 1] + conditions[i]
- out = torch.cat([out_same, out_sft], dim=1)
- else: # apply SFT to all the channels
- out = out * conditions[i - 1] + conditions[i]
-
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
-
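As a reading aid for the SFT wiring above: the encoder pushes alternating (scale, shift) pairs into `conditions`, and with `sft_half=True` the generator modulates only the second half of each feature map's channels. A minimal, self-contained sketch of just that step; the `apply_sft` helper and all shapes below are illustrative, not part of the codebase.

```python
# Minimal sketch of the SFT step used by the generator above: `conditions`
# interleaves (scale, shift) pairs, and with sft_half=True only the second
# half of the channels is modulated. Shapes are illustrative only.
import torch

def apply_sft(out, scale, shift, sft_half=True):
    # out: (b, c, h, w); scale/shift: (b, c//2, h, w) when sft_half, else (b, c, h, w)
    if sft_half:
        out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
        out_sft = out_sft * scale + shift
        return torch.cat([out_same, out_sft], dim=1)
    return out * scale + shift

feat = torch.randn(1, 64, 16, 16)
scale = torch.randn(1, 32, 16, 16)
shift = torch.randn(1, 32, 16, 16)
print(apply_sft(feat, scale, shift).shape)  # torch.Size([1, 64, 16, 16])
```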
-
-class ResBlock(nn.Module):
- """Residual block with bilinear upsampling/downsampling.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
- """
-
- def __init__(self, in_channels, out_channels, mode="down"):
- super(ResBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
- self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
- self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
- if mode == "down":
- self.scale_factor = 0.5
- elif mode == "up":
- self.scale_factor = 2
-
- def forward(self, x):
- out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
- # upsample/downsample
- out = F.interpolate(
- out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
- )
- out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
- # skip
- x = F.interpolate(
- x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
- )
- skip = self.skip(x)
- out = out + skip
- return out
-
-
-class GFPGANv1Clean(nn.Module):
- """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
- It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
- Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
- fix_decoder (bool): Whether to fix the decoder. Default: True.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- input_is_latent (bool): Whether input is latent style. Default: False.
- different_w (bool): Whether to use different latent w for different layers. Default: False.
- narrow (float): The narrow ratio for channels. Default: 1.
- sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
- """
-
- def __init__(
- self,
- state_dict,
- ):
- super(GFPGANv1Clean, self).__init__()
-
- out_size = 512
- num_style_feat = 512
- channel_multiplier = 2
- decoder_load_path = None
- fix_decoder = False
- num_mlp = 8
- input_is_latent = True
- different_w = True
- narrow = 1
- sft_half = True
-
- self.model_arch = "GFPGAN"
- self.sub_type = "Face SR"
- self.scale = 8
- self.in_nc = 3
- self.out_nc = 3
- self.state = state_dict
-
- self.supports_fp16 = False
- self.supports_bf16 = True
- self.min_size_restriction = 512
-
- self.input_is_latent = input_is_latent
- self.different_w = different_w
- self.num_style_feat = num_style_feat
-
- unet_narrow = narrow * 0.5 # by default, use a half of input channels
- channels = {
- "4": int(512 * unet_narrow),
- "8": int(512 * unet_narrow),
- "16": int(512 * unet_narrow),
- "32": int(512 * unet_narrow),
- "64": int(256 * channel_multiplier * unet_narrow),
- "128": int(128 * channel_multiplier * unet_narrow),
- "256": int(64 * channel_multiplier * unet_narrow),
- "512": int(32 * channel_multiplier * unet_narrow),
- "1024": int(16 * channel_multiplier * unet_narrow),
- }
-
- self.log_size = int(math.log(out_size, 2))
- first_out_size = 2 ** (int(math.log(out_size, 2)))
-
- self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
-
- # downsample
- in_channels = channels[f"{first_out_size}"]
- self.conv_body_down = nn.ModuleList()
- for i in range(self.log_size, 2, -1):
- out_channels = channels[f"{2**(i - 1)}"]
- self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
- in_channels = out_channels
-
- self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
-
- # upsample
- in_channels = channels["4"]
- self.conv_body_up = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
- in_channels = out_channels
-
- # to RGB
- self.toRGB = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
-
- if different_w:
- linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
- else:
- linear_out_channel = num_style_feat
-
- self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
-
- # the decoder: stylegan2 generator with SFT modulations
- self.stylegan_decoder = StyleGAN2GeneratorCSFT(
- out_size=out_size,
- num_style_feat=num_style_feat,
- num_mlp=num_mlp,
- channel_multiplier=channel_multiplier,
- narrow=narrow,
- sft_half=sft_half,
- )
-
- # load pre-trained stylegan2 model if necessary
- if decoder_load_path:
- self.stylegan_decoder.load_state_dict(
- torch.load(
- decoder_load_path, map_location=lambda storage, loc: storage
- )["params_ema"]
- )
- # fix decoder without updating params
- if fix_decoder:
- for _, param in self.stylegan_decoder.named_parameters():
- param.requires_grad = False
-
- # for SFT modulations (scale and shift)
- self.condition_scale = nn.ModuleList()
- self.condition_shift = nn.ModuleList()
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- if sft_half:
- sft_out_channels = out_channels
- else:
- sft_out_channels = out_channels * 2
- self.condition_scale.append(
- nn.Sequential(
- nn.Conv2d(out_channels, out_channels, 3, 1, 1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
- )
- )
- self.condition_shift.append(
- nn.Sequential(
- nn.Conv2d(out_channels, out_channels, 3, 1, 1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
- )
- )
- self.load_state_dict(state_dict)
-
- def forward(
- self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
- ):
- """Forward function for GFPGANv1Clean.
- Args:
- x (Tensor): Input images.
- return_latents (bool): Whether to return style latents. Default: False.
- return_rgb (bool): Whether to return intermediate rgb images. Default: True.
- randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- """
- conditions = []
- unet_skips = []
- out_rgbs = []
-
- # encoder
- feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
- for i in range(self.log_size - 2):
- feat = self.conv_body_down[i](feat)
- unet_skips.insert(0, feat)
- feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
-
- # style code
- style_code = self.final_linear(feat.view(feat.size(0), -1))
- if self.different_w:
- style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
-
- # decode
- for i in range(self.log_size - 2):
- # add unet skip
- feat = feat + unet_skips[i]
- # ResUpLayer
- feat = self.conv_body_up[i](feat)
- # generate scale and shift for SFT layers
- scale = self.condition_scale[i](feat)
- conditions.append(scale.clone())
- shift = self.condition_shift[i](feat)
- conditions.append(shift.clone())
- # generate rgb images
- if return_rgb:
- out_rgbs.append(self.toRGB[i](feat))
-
- # decoder
- image, _ = self.stylegan_decoder(
- [style_code],
- conditions,
- return_latents=return_latents,
- input_is_latent=self.input_is_latent,
- randomize_noise=randomize_noise,
- )
-
- return image, out_rgbs
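The `different_w=True` path above maps the encoder output to one w vector per decoder latent. For the hard-coded `out_size=512`, that works out to 16 latents and an 8192-wide final linear layer, reshaped to (batch, 16, 512). A quick check of that arithmetic, using the constants fixed in `__init__` above:

```python
# Sketch of the W+ latent bookkeeping for out_size=512: the encoder's final
# linear layer emits one 512-d w per decoder latent, and different_w=True
# reshapes it to (batch, num_latent, 512).
import math
import torch

out_size, num_style_feat = 512, 512
log_size = int(math.log(out_size, 2))             # 9
num_latent = log_size * 2 - 2                     # 16 latents for a 512px generator
linear_out_channel = num_latent * num_style_feat  # 8192

style_code = torch.randn(2, linear_out_channel)   # (b, 8192) from final_linear
style_code = style_code.view(2, -1, num_style_feat)
print(num_latent, style_code.shape)               # 16 torch.Size([2, 16, 512])
```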
diff --git a/comfy_extras/chainner_models/architecture/face/restoreformer_arch.py b/comfy_extras/chainner_models/architecture/face/restoreformer_arch.py
deleted file mode 100644
index 4492260291d..00000000000
--- a/comfy_extras/chainner_models/architecture/face/restoreformer_arch.py
+++ /dev/null
@@ -1,776 +0,0 @@
-# pylint: skip-file
-# type: ignore
-"""Modified from https://github.com/wzhouxiff/RestoreFormer
-"""
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class VectorQuantizer(nn.Module):
- """
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
- ____________________________________________
- Discretization bottleneck part of the VQ-VAE.
- Inputs:
- - n_e : number of embeddings
- - e_dim : dimension of embedding
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- _____________________________________________
- """
-
- def __init__(self, n_e, e_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.n_e = n_e
- self.e_dim = e_dim
- self.beta = beta
-
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
-
- def forward(self, z):
- """
- Inputs the output of the encoder network z and maps it to a discrete
- one-hot vector that is the index of the closest embedding vector e_j
- z (continuous) -> z_q (discrete)
- z.shape = (batch, channel, height, width)
- quantization pipeline:
- 1. get encoder input (B,C,H,W)
- 2. flatten input to (B*H*W,C)
- """
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.e_dim)
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
-
- d = (
- torch.sum(z_flattened**2, dim=1, keepdim=True)
- + torch.sum(self.embedding.weight**2, dim=1)
- - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
- )
-
- # could possibly replace this here
- # #\start...
- # find closest encodings
-
- min_value, min_encoding_indices = torch.min(d, dim=1)
-
- min_encoding_indices = min_encoding_indices.unsqueeze(1)
-
- min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # dtype min encodings: torch.float32
- # min_encodings shape: torch.Size([2048, 512])
- # min_encoding_indices.shape: torch.Size([2048, 1])
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # .........\end
-
- # with:
- # .........\start
- # min_encoding_indices = torch.argmin(d, dim=1)
- # z_q = self.embedding(min_encoding_indices)
- # ......\end......... (TODO)
-
- # compute loss for embedding
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
- (z_q - z.detach()) ** 2
- )
-
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
-
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
-
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
-
- def get_codebook_entry(self, indices, shape):
- # shape specifying (batch, height, width, channel)
- # TODO: check for more easy handling with nn.Embedding
- min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
- min_encodings.scatter_(1, indices[:, None], 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None:
- z_q = z_q.view(shape)
-
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
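The TODO comment inside `VectorQuantizer.forward` suggests replacing the one-hot scatter/matmul lookup with a plain argmin plus embedding lookup. A hedged sketch of that replacement, using a stand-in `codebook` and illustrative sizes, showing the two formulations pick the same codes:

```python
# Sketch of the replacement suggested in the TODO comment above: nearest-code
# lookup via argmin + nn.Embedding instead of one-hot scatter + matmul.
# `codebook` is a stand-in for self.embedding; sizes are illustrative.
import torch
import torch.nn as nn

torch.manual_seed(0)
codebook = nn.Embedding(16, 4)                     # n_e=16, e_dim=4
z_flat = torch.randn(8, 4)                         # flattened (b*h*w, e_dim)

d = (z_flat.pow(2).sum(1, keepdim=True)
     + codebook.weight.pow(2).sum(1)
     - 2 * z_flat @ codebook.weight.t())           # squared distances

indices = torch.argmin(d, dim=1)
z_q = codebook(indices)                            # direct lookup

# equivalent to the one-hot formulation used in the class above
one_hot = torch.zeros(8, 16).scatter_(1, indices[:, None], 1)
assert torch.allclose(z_q, one_hot @ codebook.weight)
```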
-
-# pytorch_diffusion + derived encoder decoder
-def nonlinearity(x):
- # swish
- return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels):
- return torch.nn.GroupNorm(
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
- )
-
- def forward(self, x):
- if self.with_conv:
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class ResnetBlock(nn.Module):
- def __init__(
- self,
- *,
- in_channels,
- out_channels=None,
- conv_shortcut=False,
- dropout,
- temb_channels=512
- ):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- else:
- self.nin_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
-
- if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
-
- return x + h
-
-
-class MultiHeadAttnBlock(nn.Module):
- def __init__(self, in_channels, head_size=1):
- super().__init__()
- self.in_channels = in_channels
- self.head_size = head_size
- self.att_size = in_channels // head_size
- assert (
- in_channels % head_size == 0
- ), "The number of channels must be divisible by the head size."
-
- self.norm1 = Normalize(in_channels)
- self.norm2 = Normalize(in_channels)
-
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.num = 0
-
- def forward(self, x, y=None):
- h_ = x
- h_ = self.norm1(h_)
- if y is None:
- y = h_
- else:
- y = self.norm2(y)
-
- q = self.q(y)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, self.head_size, self.att_size, h * w)
- q = q.permute(0, 3, 1, 2) # b, hw, head, att
-
- k = k.reshape(b, self.head_size, self.att_size, h * w)
- k = k.permute(0, 3, 1, 2)
-
- v = v.reshape(b, self.head_size, self.att_size, h * w)
- v = v.permute(0, 3, 1, 2)
-
- q = q.transpose(1, 2)
- v = v.transpose(1, 2)
- k = k.transpose(1, 2).transpose(2, 3)
-
- scale = int(self.att_size) ** (-0.5)
- q.mul_(scale)
- w_ = torch.matmul(q, k)
- w_ = F.softmax(w_, dim=3)
-
- w_ = w_.matmul(v)
-
- w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
- w_ = w_.view(b, h, w, -1)
- w_ = w_.permute(0, 3, 1, 2)
-
- w_ = self.proj_out(w_)
-
- return x + w_
-
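`MultiHeadAttnBlock` is ordinary scaled dot-product attention expressed with 1x1-conv projections and reshapes. A minimal sketch of just the reshape and attention math; the `to_heads` helper and all sizes below are illustrative:

```python
# Minimal sketch of the attention math used above: projections are reshaped
# from (b, c, h, w) to (b, heads, h*w, att) and run through standard
# softmax(QK^T / sqrt(att)) V, then folded back into a feature map.
import torch
import torch.nn.functional as F

b, c, h, w, heads = 1, 64, 8, 8, 8
att = c // heads

def to_heads(t):  # (b, c, h, w) -> (b, heads, h*w, att)
    return t.reshape(b, heads, att, h * w).permute(0, 1, 3, 2)

q, k, v = (to_heads(torch.randn(b, c, h, w)) for _ in range(3))
attn = F.softmax(q @ k.transpose(-2, -1) * att ** -0.5, dim=-1)
out = attn @ v                                      # (b, heads, h*w, att)
out = out.permute(0, 1, 3, 2).reshape(b, c, h, w)   # back to a feature map
print(out.shape)                                    # torch.Size([1, 64, 8, 8])
```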
-
-class MultiHeadEncoder(nn.Module):
- def __init__(
- self,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks=2,
- attn_resolutions=(16,),
- dropout=0.0,
- resamp_with_conv=True,
- in_channels=3,
- resolution=512,
- z_channels=256,
- double_z=True,
- enable_mid=True,
- head_size=1,
- **ignore_kwargs
- ):
- super().__init__()
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.enable_mid = enable_mid
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
- )
-
- curr_res = resolution
- in_ch_mult = (1,) + tuple(ch_mult)
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch * in_ch_mult[i_level]
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(MultiHeadAttnBlock(block_in, head_size))
- down = nn.Module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions - 1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- if self.enable_mid:
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in,
- 2 * z_channels if double_z else z_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- )
-
- def forward(self, x):
- hs = {}
- # timestep embedding
- temb = None
-
- # downsampling
- h = self.conv_in(x)
- hs["in"] = h
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
-
- if i_level != self.num_resolutions - 1:
- # hs.append(h)
- hs["block_" + str(i_level)] = h
- h = self.down[i_level].downsample(h)
-
- # middle
- # h = hs[-1]
- if self.enable_mid:
- h = self.mid.block_1(h, temb)
- hs["block_" + str(i_level) + "_atten"] = h
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
- hs["mid_atten"] = h
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- # hs.append(h)
- hs["out"] = h
-
- return hs
-
-
-class MultiHeadDecoder(nn.Module):
- def __init__(
- self,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks=2,
- attn_resolutions=(16,),
- dropout=0.0,
- resamp_with_conv=True,
- in_channels=3,
- resolution=512,
- z_channels=256,
- give_pre_end=False,
- enable_mid=True,
- head_size=1,
- **ignorekwargs
- ):
- super().__init__()
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.enable_mid = enable_mid
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- block_in = ch * ch_mult[self.num_resolutions - 1]
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.z_shape = (1, z_channels, curr_res, curr_res)
- print(
- "Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)
- )
- )
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(
- z_channels, block_in, kernel_size=3, stride=1, padding=1
- )
-
- # middle
- if self.enable_mid:
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(MultiHeadAttnBlock(block_in, head_size))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, z):
- # assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- if self.enable_mid:
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class MultiHeadDecoderTransformer(nn.Module):
- def __init__(
- self,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks=2,
- attn_resolutions=(16,),
- dropout=0.0,
- resamp_with_conv=True,
- in_channels=3,
- resolution=512,
- z_channels=256,
- give_pre_end=False,
- enable_mid=True,
- head_size=1,
- **ignorekwargs
- ):
- super().__init__()
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.enable_mid = enable_mid
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- block_in = ch * ch_mult[self.num_resolutions - 1]
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.z_shape = (1, z_channels, curr_res, curr_res)
- print(
- "Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)
- )
- )
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(
- z_channels, block_in, kernel_size=3, stride=1, padding=1
- )
-
- # middle
- if self.enable_mid:
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(MultiHeadAttnBlock(block_in, head_size))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, z, hs):
- # assert z.shape[1:] == self.z_shape[1:]
- # self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- if self.enable_mid:
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h, hs["mid_atten"])
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](
- h, hs["block_" + str(i_level) + "_atten"]
- )
- # hfeature = h.clone()
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class RestoreFormer(nn.Module):
- def __init__(
- self,
- state_dict,
- ):
- super(RestoreFormer, self).__init__()
-
- n_embed = 1024
- embed_dim = 256
- ch = 64
- out_ch = 3
- ch_mult = (1, 2, 2, 4, 4, 8)
- num_res_blocks = 2
- attn_resolutions = (16,)
- dropout = 0.0
- in_channels = 3
- resolution = 512
- z_channels = 256
- double_z = False
- enable_mid = True
- fix_decoder = False
- fix_codebook = True
- fix_encoder = False
- head_size = 8
-
- self.model_arch = "RestoreFormer"
- self.sub_type = "Face SR"
- self.scale = 8
- self.in_nc = 3
- self.out_nc = out_ch
- self.state = state_dict
-
- self.supports_fp16 = False
- self.supports_bf16 = True
- self.min_size_restriction = 16
-
- self.encoder = MultiHeadEncoder(
- ch=ch,
- out_ch=out_ch,
- ch_mult=ch_mult,
- num_res_blocks=num_res_blocks,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- in_channels=in_channels,
- resolution=resolution,
- z_channels=z_channels,
- double_z=double_z,
- enable_mid=enable_mid,
- head_size=head_size,
- )
- self.decoder = MultiHeadDecoderTransformer(
- ch=ch,
- out_ch=out_ch,
- ch_mult=ch_mult,
- num_res_blocks=num_res_blocks,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- in_channels=in_channels,
- resolution=resolution,
- z_channels=z_channels,
- enable_mid=enable_mid,
- head_size=head_size,
- )
-
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
-
- self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
-
- if fix_decoder:
- for _, param in self.decoder.named_parameters():
- param.requires_grad = False
- for _, param in self.post_quant_conv.named_parameters():
- param.requires_grad = False
- for _, param in self.quantize.named_parameters():
- param.requires_grad = False
- elif fix_codebook:
- for _, param in self.quantize.named_parameters():
- param.requires_grad = False
-
- if fix_encoder:
- for _, param in self.encoder.named_parameters():
- param.requires_grad = False
-
- self.load_state_dict(state_dict)
-
- def encode(self, x):
- hs = self.encoder(x)
- h = self.quant_conv(hs["out"])
- quant, emb_loss, info = self.quantize(h)
- return quant, emb_loss, info, hs
-
- def decode(self, quant, hs):
- quant = self.post_quant_conv(quant)
- dec = self.decoder(quant, hs)
-
- return dec
-
- def forward(self, input, **kwargs):
- quant, diff, info, hs = self.encode(input)
- dec = self.decode(quant, hs)
-
- return dec, None
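`MultiHeadDecoderTransformer` consumes the encoder's `hs` dictionary as cross-attention memories, keyed `block_{i}_atten` and `mid_atten`. A small bookkeeping sketch, assuming the hard-coded RestoreFormer configuration above, of which keys exist and where the attention fires:

```python
# Illustrative bookkeeping for the encoder/decoder hand-off above: with
# resolution=512, ch_mult=(1, 2, 2, 4, 4, 8) and attn_resolutions=(16,),
# the encoder stores features under these keys and the decoder reads the
# *_atten entries back as cross-attention memories.
resolution, ch_mult, attn_resolutions = 512, (1, 2, 2, 4, 4, 8), (16,)

curr_res, keys = resolution, ["in"]
for i_level in range(len(ch_mult)):
    if i_level != len(ch_mult) - 1:
        keys.append(f"block_{i_level}")
        curr_res //= 2
keys += [f"block_{len(ch_mult) - 1}_atten", "mid_atten", "out"]

print(curr_res, curr_res in attn_resolutions)  # 16 True -> attention only at 16x16
print(keys)
```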
diff --git a/comfy_extras/chainner_models/architecture/face/stylegan2_arch.py b/comfy_extras/chainner_models/architecture/face/stylegan2_arch.py
deleted file mode 100644
index 1eb0e9f15f7..00000000000
--- a/comfy_extras/chainner_models/architecture/face/stylegan2_arch.py
+++ /dev/null
@@ -1,865 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .fused_act import FusedLeakyReLU, fused_leaky_relu
-from .upfirdn2d import upfirdn2d
-
-
-class NormStyleCode(nn.Module):
- def forward(self, x):
- """Normalize the style codes.
-
- Args:
- x (Tensor): Style codes with shape (b, c).
-
- Returns:
- Tensor: Normalized tensor.
- """
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
-
-
-def make_resample_kernel(k):
- """Make resampling kernel for UpFirDn.
-
- Args:
- k (list[int]): A list indicating the 1D resample kernel magnitude.
-
- Returns:
- Tensor: 2D resampled kernel.
- """
- k = torch.tensor(k, dtype=torch.float32)
- if k.ndim == 1:
- k = k[None, :] * k[:, None] # to 2D kernel, outer product
- # normalize
- k /= k.sum()
- return k
-
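For the default `(1, 3, 3, 1)` kernel, `make_resample_kernel` produces a normalized 4x4 separable FIR filter. A quick numeric check:

```python
# Quick check of make_resample_kernel for the default (1, 3, 3, 1) kernel:
# the outer product gives a 4x4 separable FIR filter, normalized to sum to 1.
import torch

k = torch.tensor([1, 3, 3, 1], dtype=torch.float32)
k2d = k[None, :] * k[:, None]   # outer product -> 4x4
k2d /= k2d.sum()                # normalize (sum of 64 -> 1.0)
print(k2d)
# tensor([[0.0156, 0.0469, 0.0469, 0.0156],
#         [0.0469, 0.1406, 0.1406, 0.0469],
#         [0.0469, 0.1406, 0.1406, 0.0469],
#         [0.0156, 0.0469, 0.0469, 0.0156]])
```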
-
-class UpFirDnUpsample(nn.Module):
- """Upsample, FIR filter, and downsample (upsample version).
-
- References:
- 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
- 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
-
- Args:
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude.
- factor (int): Upsampling scale factor. Default: 2.
- """
-
- def __init__(self, resample_kernel, factor=2):
- super(UpFirDnUpsample, self).__init__()
- self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
- self.factor = factor
-
- pad = self.kernel.shape[0] - factor
- self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
-
- def forward(self, x):
- out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
- return out
-
- def __repr__(self):
- return f"{self.__class__.__name__}(factor={self.factor})"
-
-
-class UpFirDnDownsample(nn.Module):
- """Upsample, FIR filter, and downsample (downsample version).
-
- Args:
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude.
- factor (int): Downsampling scale factor. Default: 2.
- """
-
- def __init__(self, resample_kernel, factor=2):
- super(UpFirDnDownsample, self).__init__()
- self.kernel = make_resample_kernel(resample_kernel)
- self.factor = factor
-
- pad = self.kernel.shape[0] - factor
- self.pad = ((pad + 1) // 2, pad // 2)
-
- def forward(self, x):
- out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
- return out
-
- def __repr__(self):
- return f"{self.__class__.__name__}(factor={self.factor})"
-
-
-class UpFirDnSmooth(nn.Module):
- """Upsample, FIR filter, and downsample (smooth version).
-
- Args:
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude.
- upsample_factor (int): Upsampling scale factor. Default: 1.
- downsample_factor (int): Downsampling scale factor. Default: 1.
- kernel_size (int): Kernel size. Default: 1.
- """
-
- def __init__(
- self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1
- ):
- super(UpFirDnSmooth, self).__init__()
- self.upsample_factor = upsample_factor
- self.downsample_factor = downsample_factor
- self.kernel = make_resample_kernel(resample_kernel)
- if upsample_factor > 1:
- self.kernel = self.kernel * (upsample_factor**2)
-
- if upsample_factor > 1:
- pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
- self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
- elif downsample_factor > 1:
- pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
- self.pad = ((pad + 1) // 2, pad // 2)
- else:
- raise NotImplementedError
-
- def forward(self, x):
- out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(upsample_factor={self.upsample_factor}"
- f", downsample_factor={self.downsample_factor})"
- )
-
-
-class EqualLinear(nn.Module):
- """Equalized Linear as StyleGAN2.
-
- Args:
- in_channels (int): Size of each sample.
- out_channels (int): Size of each output sample.
- bias (bool): If set to ``False``, the layer will not learn an additive
- bias. Default: ``True``.
- bias_init_val (float): Bias initialized value. Default: 0.
- lr_mul (float): Learning rate multiplier. Default: 1.
- activation (None | str): The activation after ``linear`` operation.
- Supported: 'fused_lrelu', None. Default: None.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- bias=True,
- bias_init_val=0,
- lr_mul=1,
- activation=None,
- ):
- super(EqualLinear, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.lr_mul = lr_mul
- self.activation = activation
- if self.activation not in ["fused_lrelu", None]:
- raise ValueError(
- f"Wrong activation value in EqualLinear: {activation}"
- ". Supported ones are: ['fused_lrelu', None]."
- )
- self.scale = (1 / math.sqrt(in_channels)) * lr_mul
-
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
- if bias:
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
- else:
- self.register_parameter("bias", None)
-
- def forward(self, x):
- if self.bias is None:
- bias = None
- else:
- bias = self.bias * self.lr_mul
- if self.activation == "fused_lrelu":
- out = F.linear(x, self.weight * self.scale)
- out = fused_leaky_relu(out, bias)
- else:
- out = F.linear(x, self.weight * self.scale, bias=bias)
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, bias={self.bias is not None})"
- )
-
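`EqualLinear` implements StyleGAN2's equalized learning rate: parameters are stored as N(0, 1)/lr_mul and rescaled by `lr_mul / sqrt(in_channels)` at forward time, so the effective weight scale stays fixed while gradients scale with `lr_mul`. A short sketch of that scaling, with illustrative sizes:

```python
# Sketch of the runtime weight scaling EqualLinear performs: the stored
# parameter is N(0,1)/lr_mul and is multiplied by scale = lr_mul/sqrt(in_ch)
# in forward(), so the effective weight std is ~1/sqrt(in_ch) regardless of
# lr_mul, while the effective learning rate scales with lr_mul.
import math
import torch
import torch.nn.functional as F

in_ch, out_ch, lr_mul = 512, 512, 0.01
weight = torch.randn(out_ch, in_ch) / lr_mul   # stored parameter
scale = (1 / math.sqrt(in_ch)) * lr_mul        # applied at every forward

x = torch.randn(4, in_ch)
y = F.linear(x, weight * scale)
print((weight * scale).std())  # ~1/sqrt(512) ~= 0.0442, independent of lr_mul
```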
-
-class ModulatedConv2d(nn.Module):
- """Modulated Conv2d used in StyleGAN2.
-
- There is no bias in ModulatedConv2d.
-
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer.
- Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
- Default: None.
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude. Default: (1, 3, 3, 1).
- eps (float): A value added to the denominator for numerical stability.
- Default: 1e-8.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=(1, 3, 3, 1),
- eps=1e-8,
- ):
- super(ModulatedConv2d, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.demodulate = demodulate
- self.sample_mode = sample_mode
- self.eps = eps
-
- if self.sample_mode == "upsample":
- self.smooth = UpFirDnSmooth(
- resample_kernel,
- upsample_factor=2,
- downsample_factor=1,
- kernel_size=kernel_size,
- )
- elif self.sample_mode == "downsample":
- self.smooth = UpFirDnSmooth(
- resample_kernel,
- upsample_factor=1,
- downsample_factor=2,
- kernel_size=kernel_size,
- )
- elif self.sample_mode is None:
- pass
- else:
- raise ValueError(
- f"Wrong sample mode {self.sample_mode}, "
- "supported ones are ['upsample', 'downsample', None]."
- )
-
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
- # modulation inside each modulated conv
- self.modulation = EqualLinear(
- num_style_feat,
- in_channels,
- bias=True,
- bias_init_val=1,
- lr_mul=1,
- activation=None,
- )
-
- self.weight = nn.Parameter(
- torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
- )
- self.padding = kernel_size // 2
-
- def forward(self, x, style):
- """Forward function.
-
- Args:
- x (Tensor): Tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
-
- Returns:
- Tensor: Modulated tensor after convolution.
- """
- b, c, h, w = x.shape # c = c_in
- # weight modulation
- style = self.modulation(style).view(b, 1, c, 1, 1)
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
- weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
-
- if self.demodulate:
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
-
- weight = weight.view(
- b * self.out_channels, c, self.kernel_size, self.kernel_size
- )
-
- if self.sample_mode == "upsample":
- x = x.view(1, b * c, h, w)
- weight = weight.view(
- b, self.out_channels, c, self.kernel_size, self.kernel_size
- )
- weight = weight.transpose(1, 2).reshape(
- b * c, self.out_channels, self.kernel_size, self.kernel_size
- )
- out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
- out = out.view(b, self.out_channels, *out.shape[2:4])
- out = self.smooth(out)
- elif self.sample_mode == "downsample":
- x = self.smooth(x)
- x = x.view(1, b * c, *x.shape[2:4])
- out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
- out = out.view(b, self.out_channels, *out.shape[2:4])
- else:
- x = x.view(1, b * c, h, w)
- # weight: (b*c_out, c_in, k, k), groups=b
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
- out = out.view(b, self.out_channels, *out.shape[2:4])
-
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, "
- f"kernel_size={self.kernel_size}, "
- f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
- )
-
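The modulate/demodulate step in `ModulatedConv2d.forward` can be checked in isolation: scale the weight per input channel by the style vector, then renormalize each output filter to (approximately) unit L2 norm. A standalone sketch with illustrative sizes, mirroring the math above:

```python
# Standalone sketch of the modulation/demodulation math in ModulatedConv2d:
# weights are scaled per input channel by the style, then each output filter
# is renormalized so its L2 norm is ~1.
import math
import torch

b, c_in, c_out, k = 2, 8, 16, 3
weight = torch.randn(1, c_out, c_in, k, k)
style = torch.randn(b, 1, c_in, 1, 1)              # from the modulation EqualLinear
scale = 1 / math.sqrt(c_in * k ** 2)

w = scale * weight * style                         # modulate: (b, c_out, c_in, k, k)
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)
w = w * demod.view(b, c_out, 1, 1, 1)              # demodulate

print(w.pow(2).sum([2, 3, 4]).sqrt()[0, :4])       # ~1.0 for every output filter
```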
-
-class StyleConv(nn.Module):
- """Style conv.
-
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer. Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
- Default: None.
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude. Default: (1, 3, 3, 1).
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=(1, 3, 3, 1),
- ):
- super(StyleConv, self).__init__()
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=demodulate,
- sample_mode=sample_mode,
- resample_kernel=resample_kernel,
- )
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
- self.activate = FusedLeakyReLU(out_channels)
-
- def forward(self, x, style, noise=None):
- # modulate
- out = self.modulated_conv(x, style)
- # noise injection
- if noise is None:
- b, _, h, w = out.shape
- noise = out.new_empty(b, 1, h, w).normal_()
- out = out + self.weight * noise
- # activation (with bias)
- out = self.activate(out)
- return out
-
-
-class ToRGB(nn.Module):
- """To RGB from features.
-
- Args:
- in_channels (int): Channel number of input.
- num_style_feat (int): Channel number of style features.
- upsample (bool): Whether to upsample. Default: True.
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude. Default: (1, 3, 3, 1).
- """
-
- def __init__(
- self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)
- ):
- super(ToRGB, self).__init__()
- if upsample:
- self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
- else:
- self.upsample = None
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- 3,
- kernel_size=1,
- num_style_feat=num_style_feat,
- demodulate=False,
- sample_mode=None,
- )
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
-
- def forward(self, x, style, skip=None):
- """Forward function.
-
- Args:
- x (Tensor): Feature tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
- skip (Tensor): Base/skip tensor. Default: None.
-
- Returns:
- Tensor: RGB images.
- """
- out = self.modulated_conv(x, style)
- out = out + self.bias
- if skip is not None:
- if self.upsample:
- skip = self.upsample(skip)
- out = out + skip
- return out
-
-
-class ConstantInput(nn.Module):
- """Constant input.
-
- Args:
- num_channel (int): Channel number of constant input.
- size (int): Spatial size of constant input.
- """
-
- def __init__(self, num_channel, size):
- super(ConstantInput, self).__init__()
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
-
- def forward(self, batch):
- out = self.weight.repeat(batch, 1, 1, 1)
- return out
-
-
-class StyleGAN2Generator(nn.Module):
- """StyleGAN2 Generator.
-
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of
- StyleGAN2. Default: 2.
- resample_kernel (list[int]): A list indicating the 1D resample kernel
- magnitude. An outer product will be applied to extend the 1D resample
- kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- narrow (float): Narrow ratio for channels. Default: 1.0.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- resample_kernel=(1, 3, 3, 1),
- lr_mlp=0.01,
- narrow=1,
- ):
- super(StyleGAN2Generator, self).__init__()
- # Style MLP layers
- self.num_style_feat = num_style_feat
- style_mlp_layers = [NormStyleCode()]
- for i in range(num_mlp):
- style_mlp_layers.append(
- EqualLinear(
- num_style_feat,
- num_style_feat,
- bias=True,
- bias_init_val=0,
- lr_mul=lr_mlp,
- activation="fused_lrelu",
- )
- )
- self.style_mlp = nn.Sequential(*style_mlp_layers)
-
- channels = {
- "4": int(512 * narrow),
- "8": int(512 * narrow),
- "16": int(512 * narrow),
- "32": int(512 * narrow),
- "64": int(256 * channel_multiplier * narrow),
- "128": int(128 * channel_multiplier * narrow),
- "256": int(64 * channel_multiplier * narrow),
- "512": int(32 * channel_multiplier * narrow),
- "1024": int(16 * channel_multiplier * narrow),
- }
- self.channels = channels
-
- self.constant_input = ConstantInput(channels["4"], size=4)
- self.style_conv1 = StyleConv(
- channels["4"],
- channels["4"],
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=resample_kernel,
- )
- self.to_rgb1 = ToRGB(
- channels["4"],
- num_style_feat,
- upsample=False,
- resample_kernel=resample_kernel,
- )
-
- self.log_size = int(math.log(out_size, 2))
- self.num_layers = (self.log_size - 2) * 2 + 1
- self.num_latent = self.log_size * 2 - 2
-
- self.style_convs = nn.ModuleList()
- self.to_rgbs = nn.ModuleList()
- self.noises = nn.Module()
-
- in_channels = channels["4"]
- # noise
- for layer_idx in range(self.num_layers):
- resolution = 2 ** ((layer_idx + 5) // 2)
- shape = [1, 1, resolution, resolution]
- self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
- # style convs and to_rgbs
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.style_convs.append(
- StyleConv(
- in_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode="upsample",
- resample_kernel=resample_kernel,
- )
- )
- self.style_convs.append(
- StyleConv(
- out_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=resample_kernel,
- )
- )
- self.to_rgbs.append(
- ToRGB(
- out_channels,
- num_style_feat,
- upsample=True,
- resample_kernel=resample_kernel,
- )
- )
- in_channels = out_channels
-
- def make_noise(self):
- """Make noise for noise injection."""
- device = self.constant_input.weight.device
- noises = [torch.randn(1, 1, 4, 4, device=device)]
-
- for i in range(3, self.log_size + 1):
- for _ in range(2):
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
-
- return noises
-
- def get_latent(self, x):
- return self.style_mlp(x)
-
- def mean_latent(self, num_latent):
- latent_in = torch.randn(
- num_latent, self.num_style_feat, device=self.constant_input.weight.device
- )
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
- return latent
-
- def forward(
- self,
- styles,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2Generator.
-
- Args:
- styles (list[Tensor]): Sample codes of styles.
- input_is_latent (bool): Whether input is latent style.
- Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
- randomize_noise (bool): Randomize noise, used when 'noise' is
- None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
- inject_index (int | None): The injection index for mixing noise.
- Default: None.
- return_latents (bool): Whether to return style latents.
- Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latent with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip)
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
-
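For the typical `out_size=512`, the bookkeeping above yields 15 noise-injection layers (one per style conv) and 16 W+ latents, with noise buffers from 4x4 up to 512x512. A quick check of that arithmetic:

```python
# Quick check of the layer/latent bookkeeping above for out_size=512:
# 15 noise inputs and 16 W latents, with noise buffers at resolutions
# 4, 8, 8, 16, 16, ..., 512, 512.
import math

out_size = 512
log_size = int(math.log(out_size, 2))        # 9
num_layers = (log_size - 2) * 2 + 1          # 15 style convs
num_latent = log_size * 2 - 2                # 16 latents in W+

noise_resolutions = [2 ** ((idx + 5) // 2) for idx in range(num_layers)]
print(num_layers, num_latent)                # 15 16
print(noise_resolutions)                     # [4, 8, 8, 16, 16, ..., 512, 512]
```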
-
-class ScaledLeakyReLU(nn.Module):
- """Scaled LeakyReLU.
-
- Args:
- negative_slope (float): Negative slope. Default: 0.2.
- """
-
- def __init__(self, negative_slope=0.2):
- super(ScaledLeakyReLU, self).__init__()
- self.negative_slope = negative_slope
-
- def forward(self, x):
- out = F.leaky_relu(x, negative_slope=self.negative_slope)
- return out * math.sqrt(2)
-
-
-class EqualConv2d(nn.Module):
- """Equalized Conv2d as in StyleGAN2.
-
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- stride (int): Stride of the convolution. Default: 1
- padding (int): Zero-padding added to both sides of the input.
- Default: 0.
- bias (bool): If ``True``, adds a learnable bias to the output.
- Default: ``True``.
- bias_init_val (float): Bias initialized value. Default: 0.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- bias=True,
- bias_init_val=0,
- ):
- super(EqualConv2d, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.stride = stride
- self.padding = padding
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
-
- self.weight = nn.Parameter(
- torch.randn(out_channels, in_channels, kernel_size, kernel_size)
- )
- if bias:
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
- else:
- self.register_parameter("bias", None)
-
- def forward(self, x):
- out = F.conv2d(
- x,
- self.weight * self.scale,
- bias=self.bias,
- stride=self.stride,
- padding=self.padding,
- )
-
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, "
- f"kernel_size={self.kernel_size},"
- f" stride={self.stride}, padding={self.padding}, "
- f"bias={self.bias is not None})"
- )
-
-
-class ConvLayer(nn.Sequential):
- """Conv Layer used in StyleGAN2 Discriminator.
-
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Kernel size.
- downsample (bool): Whether to downsample by a factor of 2.
- Default: False.
- resample_kernel (list[int]): A list indicating the 1D resample
- kernel magnitude. An outer product will be applied to
- extend the 1D resample kernel to a 2D resample kernel.
- Default: (1, 3, 3, 1).
- bias (bool): Whether with bias. Default: True.
- activate (bool): Whether to use activation. Default: True.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- downsample=False,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True,
- ):
- layers = []
- # downsample
- if downsample:
- layers.append(
- UpFirDnSmooth(
- resample_kernel,
- upsample_factor=1,
- downsample_factor=2,
- kernel_size=kernel_size,
- )
- )
- stride = 2
- self.padding = 0
- else:
- stride = 1
- self.padding = kernel_size // 2
- # conv
- layers.append(
- EqualConv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=self.padding,
- bias=bias and not activate,
- )
- )
- # activation
- if activate:
- if bias:
- layers.append(FusedLeakyReLU(out_channels))
- else:
- layers.append(ScaledLeakyReLU(0.2))
-
- super(ConvLayer, self).__init__(*layers)
-
-
-class ResBlock(nn.Module):
- """Residual block used in StyleGAN2 Discriminator.
-
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- resample_kernel (list[int]): A list indicating the 1D resample
- kernel magnitude. An outer product will be applied to
- extend the 1D resample kernel to a 2D resample kernel.
- Default: (1, 3, 3, 1).
- """
-
- def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
- super(ResBlock, self).__init__()
-
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
- self.conv2 = ConvLayer(
- in_channels,
- out_channels,
- 3,
- downsample=True,
- resample_kernel=resample_kernel,
- bias=True,
- activate=True,
- )
- self.skip = ConvLayer(
- in_channels,
- out_channels,
- 1,
- downsample=True,
- resample_kernel=resample_kernel,
- bias=False,
- activate=False,
- )
-
- def forward(self, x):
- out = self.conv1(x)
- out = self.conv2(out)
- skip = self.skip(x)
- out = (out + skip) / math.sqrt(2)
- return out
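The `1/sqrt(2)` in the residual sum above keeps the output variance close to the input variance when the two branches are roughly independent with equal variance. A brief numeric illustration:

```python
# Numeric illustration of the 1/sqrt(2) residual scaling above: summing two
# independent, equal-variance branches doubles the variance, and dividing by
# sqrt(2) restores it (values are approximate).
import math
import torch

torch.manual_seed(0)
a, b = torch.randn(100000), torch.randn(100000)
print((a + b).var())                    # ~2.0
print(((a + b) / math.sqrt(2)).var())   # ~1.0
```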
diff --git a/comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py b/comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py
deleted file mode 100644
index 601f8cc4b33..00000000000
--- a/comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py
+++ /dev/null
@@ -1,709 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .fused_act import FusedLeakyReLU, fused_leaky_relu
-
-
-class NormStyleCode(nn.Module):
- def forward(self, x):
- """Normalize the style codes.
- Args:
- x (Tensor): Style codes with shape (b, c).
- Returns:
- Tensor: Normalized tensor.
- """
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
-
-
-class EqualLinear(nn.Module):
- """Equalized Linear as StyleGAN2.
- Args:
- in_channels (int): Size of each sample.
- out_channels (int): Size of each output sample.
- bias (bool): If set to ``False``, the layer will not learn an additive
- bias. Default: ``True``.
- bias_init_val (float): Bias initialized value. Default: 0.
- lr_mul (float): Learning rate multiplier. Default: 1.
- activation (None | str): The activation after ``linear`` operation.
- Supported: 'fused_lrelu', None. Default: None.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- bias=True,
- bias_init_val=0,
- lr_mul=1,
- activation=None,
- ):
- super(EqualLinear, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.lr_mul = lr_mul
- self.activation = activation
- if self.activation not in ["fused_lrelu", None]:
- raise ValueError(
- f"Wrong activation value in EqualLinear: {activation}"
- "Supported ones are: ['fused_lrelu', None]."
- )
- self.scale = (1 / math.sqrt(in_channels)) * lr_mul
-
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
- if bias:
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
- else:
- self.register_parameter("bias", None)
-
- def forward(self, x):
- if self.bias is None:
- bias = None
- else:
- bias = self.bias * self.lr_mul
- if self.activation == "fused_lrelu":
- out = F.linear(x, self.weight * self.scale)
- out = fused_leaky_relu(out, bias)
- else:
- out = F.linear(x, self.weight * self.scale, bias=bias)
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, bias={self.bias is not None})"
- )
-
-
-class ModulatedConv2d(nn.Module):
- """Modulated Conv2d used in StyleGAN2.
- There is no bias in ModulatedConv2d.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer.
- Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
- Default: None.
- eps (float): A value added to the denominator for numerical stability.
- Default: 1e-8.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- eps=1e-8,
- interpolation_mode="bilinear",
- ):
- super(ModulatedConv2d, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.demodulate = demodulate
- self.sample_mode = sample_mode
- self.eps = eps
- self.interpolation_mode = interpolation_mode
- if self.interpolation_mode == "nearest":
- self.align_corners = None
- else:
- self.align_corners = False
-
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
- # modulation inside each modulated conv
- self.modulation = EqualLinear(
- num_style_feat,
- in_channels,
- bias=True,
- bias_init_val=1,
- lr_mul=1,
- activation=None,
- )
-
- self.weight = nn.Parameter(
- torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
- )
- self.padding = kernel_size // 2
-
- def forward(self, x, style):
- """Forward function.
- Args:
- x (Tensor): Tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
- Returns:
- Tensor: Modulated tensor after convolution.
- """
- b, c, h, w = x.shape # c = c_in
- # weight modulation
- style = self.modulation(style).view(b, 1, c, 1, 1)
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
- weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
-
- if self.demodulate:
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
-
- weight = weight.view(
- b * self.out_channels, c, self.kernel_size, self.kernel_size
- )
-
- if self.sample_mode == "upsample":
- x = F.interpolate(
- x,
- scale_factor=2,
- mode=self.interpolation_mode,
- align_corners=self.align_corners,
- )
- elif self.sample_mode == "downsample":
- x = F.interpolate(
- x,
- scale_factor=0.5,
- mode=self.interpolation_mode,
- align_corners=self.align_corners,
- )
-
- b, c, h, w = x.shape
- x = x.view(1, b * c, h, w)
- # weight: (b*c_out, c_in, k, k), groups=b
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
- out = out.view(b, self.out_channels, *out.shape[2:4])
-
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, "
- f"kernel_size={self.kernel_size}, "
- f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
- )
-
-
-class StyleConv(nn.Module):
- """Style conv.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer. Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
- Default: None.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- interpolation_mode="bilinear",
- ):
- super(StyleConv, self).__init__()
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=demodulate,
- sample_mode=sample_mode,
- interpolation_mode=interpolation_mode,
- )
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
- self.activate = FusedLeakyReLU(out_channels)
-
- def forward(self, x, style, noise=None):
- # modulate
- out = self.modulated_conv(x, style)
- # noise injection
- if noise is None:
- b, _, h, w = out.shape
- noise = out.new_empty(b, 1, h, w).normal_()
- out = out + self.weight * noise
- # activation (with bias)
- out = self.activate(out)
- return out
-
-
-class ToRGB(nn.Module):
- """To RGB from features.
- Args:
- in_channels (int): Channel number of input.
- num_style_feat (int): Channel number of style features.
- upsample (bool): Whether to upsample. Default: True.
- """
-
- def __init__(
- self, in_channels, num_style_feat, upsample=True, interpolation_mode="bilinear"
- ):
- super(ToRGB, self).__init__()
- self.upsample = upsample
- self.interpolation_mode = interpolation_mode
- if self.interpolation_mode == "nearest":
- self.align_corners = None
- else:
- self.align_corners = False
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- 3,
- kernel_size=1,
- num_style_feat=num_style_feat,
- demodulate=False,
- sample_mode=None,
- interpolation_mode=interpolation_mode,
- )
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
-
- def forward(self, x, style, skip=None):
- """Forward function.
- Args:
- x (Tensor): Feature tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
- skip (Tensor): Base/skip tensor. Default: None.
- Returns:
- Tensor: RGB images.
- """
- out = self.modulated_conv(x, style)
- out = out + self.bias
- if skip is not None:
- if self.upsample:
- skip = F.interpolate(
- skip,
- scale_factor=2,
- mode=self.interpolation_mode,
- align_corners=self.align_corners,
- )
- out = out + skip
- return out
-
-
-class ConstantInput(nn.Module):
- """Constant input.
- Args:
- num_channel (int): Channel number of constant input.
- size (int): Spatial size of constant input.
- """
-
- def __init__(self, num_channel, size):
- super(ConstantInput, self).__init__()
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
-
- def forward(self, batch):
- out = self.weight.repeat(batch, 1, 1, 1)
- return out
-
-
-class StyleGAN2GeneratorBilinear(nn.Module):
- """StyleGAN2 Generator.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of
- StyleGAN2. Default: 2.
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
- narrow (float): Narrow ratio for channels. Default: 1.0.
- """
-
- def __init__(
- self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- lr_mlp=0.01,
- narrow=1,
- interpolation_mode="bilinear",
- ):
- super(StyleGAN2GeneratorBilinear, self).__init__()
- # Style MLP layers
- self.num_style_feat = num_style_feat
- style_mlp_layers = [NormStyleCode()]
- for i in range(num_mlp):
- style_mlp_layers.append(
- EqualLinear(
- num_style_feat,
- num_style_feat,
- bias=True,
- bias_init_val=0,
- lr_mul=lr_mlp,
- activation="fused_lrelu",
- )
- )
- self.style_mlp = nn.Sequential(*style_mlp_layers)
-
- channels = {
- "4": int(512 * narrow),
- "8": int(512 * narrow),
- "16": int(512 * narrow),
- "32": int(512 * narrow),
- "64": int(256 * channel_multiplier * narrow),
- "128": int(128 * channel_multiplier * narrow),
- "256": int(64 * channel_multiplier * narrow),
- "512": int(32 * channel_multiplier * narrow),
- "1024": int(16 * channel_multiplier * narrow),
- }
- self.channels = channels
-
- self.constant_input = ConstantInput(channels["4"], size=4)
- self.style_conv1 = StyleConv(
- channels["4"],
- channels["4"],
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- interpolation_mode=interpolation_mode,
- )
- self.to_rgb1 = ToRGB(
- channels["4"],
- num_style_feat,
- upsample=False,
- interpolation_mode=interpolation_mode,
- )
-
- self.log_size = int(math.log(out_size, 2))
- self.num_layers = (self.log_size - 2) * 2 + 1
- self.num_latent = self.log_size * 2 - 2
-
- self.style_convs = nn.ModuleList()
- self.to_rgbs = nn.ModuleList()
- self.noises = nn.Module()
-
- in_channels = channels["4"]
- # noise
- for layer_idx in range(self.num_layers):
- resolution = 2 ** ((layer_idx + 5) // 2)
- shape = [1, 1, resolution, resolution]
- self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
- # style convs and to_rgbs
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.style_convs.append(
- StyleConv(
- in_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode="upsample",
- interpolation_mode=interpolation_mode,
- )
- )
- self.style_convs.append(
- StyleConv(
- out_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- interpolation_mode=interpolation_mode,
- )
- )
- self.to_rgbs.append(
- ToRGB(
- out_channels,
- num_style_feat,
- upsample=True,
- interpolation_mode=interpolation_mode,
- )
- )
- in_channels = out_channels
-
- def make_noise(self):
- """Make noise for noise injection."""
- device = self.constant_input.weight.device
- noises = [torch.randn(1, 1, 4, 4, device=device)]
-
- for i in range(3, self.log_size + 1):
- for _ in range(2):
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
-
- return noises
-
- def get_latent(self, x):
- return self.style_mlp(x)
-
- def mean_latent(self, num_latent):
- latent_in = torch.randn(
- num_latent, self.num_style_feat, device=self.constant_input.weight.device
- )
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
- return latent
-
- def forward(
- self,
- styles,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2Generator.
- Args:
- styles (list[Tensor]): Sample codes of styles.
- input_is_latent (bool): Whether input is latent style.
- Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
- randomize_noise (bool): Randomize noise, used when 'noise' is
- None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor.
- Default: None.
- inject_index (int | None): The injection index for mixing noise.
- Default: None.
- return_latents (bool): Whether to return style latents.
- Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latent with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip)
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
-
-
-class ScaledLeakyReLU(nn.Module):
- """Scaled LeakyReLU.
- Args:
- negative_slope (float): Negative slope. Default: 0.2.
- """
-
- def __init__(self, negative_slope=0.2):
- super(ScaledLeakyReLU, self).__init__()
- self.negative_slope = negative_slope
-
- def forward(self, x):
- out = F.leaky_relu(x, negative_slope=self.negative_slope)
- return out * math.sqrt(2)
-
-
-class EqualConv2d(nn.Module):
- """Equalized Linear as StyleGAN2.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- stride (int): Stride of the convolution. Default: 1
- padding (int): Zero-padding added to both sides of the input.
- Default: 0.
- bias (bool): If ``True``, adds a learnable bias to the output.
- Default: ``True``.
- bias_init_val (float): Bias initialized value. Default: 0.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- bias=True,
- bias_init_val=0,
- ):
- super(EqualConv2d, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.stride = stride
- self.padding = padding
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
-
- self.weight = nn.Parameter(
- torch.randn(out_channels, in_channels, kernel_size, kernel_size)
- )
- if bias:
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
- else:
- self.register_parameter("bias", None)
-
- def forward(self, x):
- out = F.conv2d(
- x,
- self.weight * self.scale,
- bias=self.bias,
- stride=self.stride,
- padding=self.padding,
- )
-
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, "
- f"out_channels={self.out_channels}, "
- f"kernel_size={self.kernel_size},"
- f" stride={self.stride}, padding={self.padding}, "
- f"bias={self.bias is not None})"
- )
-
-
-class ConvLayer(nn.Sequential):
- """Conv Layer used in StyleGAN2 Discriminator.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Kernel size.
- downsample (bool): Whether to downsample by a factor of 2.
- Default: False.
- bias (bool): Whether to use bias. Default: True.
- activate (bool): Whether to use activation. Default: True.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- downsample=False,
- bias=True,
- activate=True,
- interpolation_mode="bilinear",
- ):
- layers = []
- self.interpolation_mode = interpolation_mode
- # downsample
- if downsample:
- if self.interpolation_mode == "nearest":
- self.align_corners = None
- else:
- self.align_corners = False
-
- layers.append(
- torch.nn.Upsample(
- scale_factor=0.5,
- mode=interpolation_mode,
- align_corners=self.align_corners,
- )
- )
- stride = 1
- self.padding = kernel_size // 2
- # conv
- layers.append(
- EqualConv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=self.padding,
- bias=bias and not activate,
- )
- )
- # activation
- if activate:
- if bias:
- layers.append(FusedLeakyReLU(out_channels))
- else:
- layers.append(ScaledLeakyReLU(0.2))
-
- super(ConvLayer, self).__init__(*layers)
-
-
-class ResBlock(nn.Module):
- """Residual block used in StyleGAN2 Discriminator.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- """
-
- def __init__(self, in_channels, out_channels, interpolation_mode="bilinear"):
- super(ResBlock, self).__init__()
-
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
- self.conv2 = ConvLayer(
- in_channels,
- out_channels,
- 3,
- downsample=True,
- interpolation_mode=interpolation_mode,
- bias=True,
- activate=True,
- )
- self.skip = ConvLayer(
- in_channels,
- out_channels,
- 1,
- downsample=True,
- interpolation_mode=interpolation_mode,
- bias=False,
- activate=False,
- )
-
- def forward(self, x):
- out = self.conv1(x)
- out = self.conv2(out)
- skip = self.skip(x)
- out = (out + skip) / math.sqrt(2)
- return out
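For context on what this file provided, a hedged usage sketch of the deleted bilinear generator (import path assumed; with random weights the output is noise, not a face):

import torch
from stylegan2_bilinear_arch import StyleGAN2GeneratorBilinear  # assumed import path

gen = StyleGAN2GeneratorBilinear(out_size=256, num_style_feat=512, num_mlp=8)
z = torch.randn(1, 512)                        # one style code per sample
image, _ = gen([z], input_is_latent=False, randomize_noise=True)
print(image.shape)                             # torch.Size([1, 3, 256, 256])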
diff --git a/comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py b/comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py
deleted file mode 100644
index c48de9af690..00000000000
--- a/comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py
+++ /dev/null
@@ -1,453 +0,0 @@
-# pylint: skip-file
-# type: ignore
-import math
-import random
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn import init
-from torch.nn.modules.batchnorm import _BatchNorm
-
-
-@torch.no_grad()
-def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
- """Initialize network weights.
- Args:
- module_list (list[nn.Module] | nn.Module): Modules to be initialized.
- scale (float): Scale initialized weights, especially for residual
- blocks. Default: 1.
- bias_fill (float): The value to fill bias. Default: 0
- kwargs (dict): Other arguments for initialization function.
- """
- if not isinstance(module_list, list):
- module_list = [module_list]
- for module in module_list:
- for m in module.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, **kwargs)
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.fill_(bias_fill)
- elif isinstance(m, nn.Linear):
- init.kaiming_normal_(m.weight, **kwargs)
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.fill_(bias_fill)
- elif isinstance(m, _BatchNorm):
- init.constant_(m.weight, 1)
- if m.bias is not None:
- m.bias.data.fill_(bias_fill)
-
-
-class NormStyleCode(nn.Module):
- def forward(self, x):
- """Normalize the style codes.
- Args:
- x (Tensor): Style codes with shape (b, c).
- Returns:
- Tensor: Normalized tensor.
- """
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
-
-
-class ModulatedConv2d(nn.Module):
- """Modulated Conv2d used in StyleGAN2.
- There is no bias in ModulatedConv2d.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer. Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
- eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- eps=1e-8,
- ):
- super(ModulatedConv2d, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.demodulate = demodulate
- self.sample_mode = sample_mode
- self.eps = eps
-
- # modulation inside each modulated conv
- self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
- # initialization
- default_init_weights(
- self.modulation,
- scale=1,
- bias_fill=1,
- a=0,
- mode="fan_in",
- nonlinearity="linear",
- )
-
- self.weight = nn.Parameter(
- torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
- / math.sqrt(in_channels * kernel_size**2)
- )
- self.padding = kernel_size // 2
-
- def forward(self, x, style):
- """Forward function.
- Args:
- x (Tensor): Tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
- Returns:
- Tensor: Modulated tensor after convolution.
- """
- b, c, h, w = x.shape # c = c_in
- # weight modulation
- style = self.modulation(style).view(b, 1, c, 1, 1)
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
- weight = self.weight * style # (b, c_out, c_in, k, k)
-
- if self.demodulate:
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
-
- weight = weight.view(
- b * self.out_channels, c, self.kernel_size, self.kernel_size
- )
-
- # upsample or downsample if necessary
- if self.sample_mode == "upsample":
- x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
- elif self.sample_mode == "downsample":
- x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
-
- b, c, h, w = x.shape
- x = x.view(1, b * c, h, w)
- # weight: (b*c_out, c_in, k, k), groups=b
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
- out = out.view(b, self.out_channels, *out.shape[2:4])
-
- return out
-
- def __repr__(self):
- return (
- f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
- f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
- )
-
-
-class StyleConv(nn.Module):
- """Style conv used in StyleGAN2.
- Args:
- in_channels (int): Channel number of the input.
- out_channels (int): Channel number of the output.
- kernel_size (int): Size of the convolving kernel.
- num_style_feat (int): Channel number of style features.
- demodulate (bool): Whether to demodulate in the conv layer. Default: True.
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- ):
- super(StyleConv, self).__init__()
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=demodulate,
- sample_mode=sample_mode,
- )
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
- self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
- self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x, style, noise=None):
- # modulate
- out = self.modulated_conv(x, style) * 2**0.5 # for conversion
- # noise injection
- if noise is None:
- b, _, h, w = out.shape
- noise = out.new_empty(b, 1, h, w).normal_()
- out = out + self.weight * noise
- # add bias
- out = out + self.bias
- # activation
- out = self.activate(out)
- return out
-
-
-class ToRGB(nn.Module):
- """To RGB (image space) from features.
- Args:
- in_channels (int): Channel number of input.
- num_style_feat (int): Channel number of style features.
- upsample (bool): Whether to upsample. Default: True.
- """
-
- def __init__(self, in_channels, num_style_feat, upsample=True):
- super(ToRGB, self).__init__()
- self.upsample = upsample
- self.modulated_conv = ModulatedConv2d(
- in_channels,
- 3,
- kernel_size=1,
- num_style_feat=num_style_feat,
- demodulate=False,
- sample_mode=None,
- )
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
-
- def forward(self, x, style, skip=None):
- """Forward function.
- Args:
- x (Tensor): Feature tensor with shape (b, c, h, w).
- style (Tensor): Tensor with shape (b, num_style_feat).
- skip (Tensor): Base/skip tensor. Default: None.
- Returns:
- Tensor: RGB images.
- """
- out = self.modulated_conv(x, style)
- out = out + self.bias
- if skip is not None:
- if self.upsample:
- skip = F.interpolate(
- skip, scale_factor=2, mode="bilinear", align_corners=False
- )
- out = out + skip
- return out
-
-
-class ConstantInput(nn.Module):
- """Constant input.
- Args:
- num_channel (int): Channel number of constant input.
- size (int): Spatial size of constant input.
- """
-
- def __init__(self, num_channel, size):
- super(ConstantInput, self).__init__()
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
-
- def forward(self, batch):
- out = self.weight.repeat(batch, 1, 1, 1)
- return out
-
-
-class StyleGAN2GeneratorClean(nn.Module):
- """Clean version of StyleGAN2 Generator.
- Args:
- out_size (int): The spatial size of outputs.
- num_style_feat (int): Channel number of style features. Default: 512.
- num_mlp (int): Layer number of MLP style layers. Default: 8.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- narrow (float): Narrow ratio for channels. Default: 1.0.
- """
-
- def __init__(
- self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
- ):
- super(StyleGAN2GeneratorClean, self).__init__()
- # Style MLP layers
- self.num_style_feat = num_style_feat
- style_mlp_layers = [NormStyleCode()]
- for i in range(num_mlp):
- style_mlp_layers.extend(
- [
- nn.Linear(num_style_feat, num_style_feat, bias=True),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- ]
- )
- self.style_mlp = nn.Sequential(*style_mlp_layers)
- # initialization
- default_init_weights(
- self.style_mlp,
- scale=1,
- bias_fill=0,
- a=0.2,
- mode="fan_in",
- nonlinearity="leaky_relu",
- )
-
- # channel list
- channels = {
- "4": int(512 * narrow),
- "8": int(512 * narrow),
- "16": int(512 * narrow),
- "32": int(512 * narrow),
- "64": int(256 * channel_multiplier * narrow),
- "128": int(128 * channel_multiplier * narrow),
- "256": int(64 * channel_multiplier * narrow),
- "512": int(32 * channel_multiplier * narrow),
- "1024": int(16 * channel_multiplier * narrow),
- }
- self.channels = channels
-
- self.constant_input = ConstantInput(channels["4"], size=4)
- self.style_conv1 = StyleConv(
- channels["4"],
- channels["4"],
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- )
- self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
-
- self.log_size = int(math.log(out_size, 2))
- self.num_layers = (self.log_size - 2) * 2 + 1
- self.num_latent = self.log_size * 2 - 2
-
- self.style_convs = nn.ModuleList()
- self.to_rgbs = nn.ModuleList()
- self.noises = nn.Module()
-
- in_channels = channels["4"]
- # noise
- for layer_idx in range(self.num_layers):
- resolution = 2 ** ((layer_idx + 5) // 2)
- shape = [1, 1, resolution, resolution]
- self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
- # style convs and to_rgbs
- for i in range(3, self.log_size + 1):
- out_channels = channels[f"{2**i}"]
- self.style_convs.append(
- StyleConv(
- in_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode="upsample",
- )
- )
- self.style_convs.append(
- StyleConv(
- out_channels,
- out_channels,
- kernel_size=3,
- num_style_feat=num_style_feat,
- demodulate=True,
- sample_mode=None,
- )
- )
- self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
- in_channels = out_channels
-
- def make_noise(self):
- """Make noise for noise injection."""
- device = self.constant_input.weight.device
- noises = [torch.randn(1, 1, 4, 4, device=device)]
-
- for i in range(3, self.log_size + 1):
- for _ in range(2):
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
-
- return noises
-
- def get_latent(self, x):
- return self.style_mlp(x)
-
- def mean_latent(self, num_latent):
- latent_in = torch.randn(
- num_latent, self.num_style_feat, device=self.constant_input.weight.device
- )
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
- return latent
-
- def forward(
- self,
- styles,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False,
- ):
- """Forward function for StyleGAN2GeneratorClean.
- Args:
- styles (list[Tensor]): Sample codes of styles.
- input_is_latent (bool): Whether input is latent style. Default: False.
- noise (Tensor | None): Input noise or None. Default: None.
- randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
- truncation (float): The truncation ratio. Default: 1.
- truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
- inject_index (int | None): The injection index for mixing noise. Default: None.
- return_latents (bool): Whether to return style latents. Default: False.
- """
- # style codes -> latents with Style MLP layer
- if not input_is_latent:
- styles = [self.style_mlp(s) for s in styles]
- # noises
- if noise is None:
- if randomize_noise:
- noise = [None] * self.num_layers # for each style conv layer
- else: # use the stored noise
- noise = [
- getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
- ]
- # style truncation
- if truncation < 1:
- style_truncation = []
- for style in styles:
- style_truncation.append(
- truncation_latent + truncation * (style - truncation_latent)
- )
- styles = style_truncation
- # get style latents with injection
- if len(styles) == 1:
- inject_index = self.num_latent
-
- if styles[0].ndim < 3:
- # repeat latent code for all the layers
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- else: # used for encoder with different latent code for each layer
- latent = styles[0]
- elif len(styles) == 2: # mixing noises
- if inject_index is None:
- inject_index = random.randint(1, self.num_latent - 1)
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
- latent2 = (
- styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
- )
- latent = torch.cat([latent1, latent2], 1)
-
- # main generation
- out = self.constant_input(latent.shape[0])
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
- skip = self.to_rgb1(out, latent[:, 1])
-
- i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(
- self.style_convs[::2],
- self.style_convs[1::2],
- noise[1::2],
- noise[2::2],
- self.to_rgbs,
- ):
- out = conv1(out, latent[:, i], noise=noise1)
- out = conv2(out, latent[:, i + 1], noise=noise2)
- skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
- i += 2
-
- image = skip
-
- if return_latents:
- return image, latent
- else:
- return image, None
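The "clean" variant above replaces the custom fused/upfirdn ops with plain nn.Linear, LeakyReLU and F.interpolate; a hedged sketch of how it was driven (import path assumed):

import torch
from stylegan2_clean_arch import StyleGAN2GeneratorClean  # assumed import path

gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512, num_mlp=8)
z = torch.randn(2, 512)
images, latents = gen([z], return_latents=True)
print(images.shape)   # torch.Size([2, 3, 256, 256])
print(latents.shape)  # torch.Size([2, 14, 512]) -- one latent per style conv / to_rgb slot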
diff --git a/comfy_extras/chainner_models/architecture/face/upfirdn2d.py b/comfy_extras/chainner_models/architecture/face/upfirdn2d.py
deleted file mode 100644
index 4ea4541513f..00000000000
--- a/comfy_extras/chainner_models/architecture/face/upfirdn2d.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# pylint: skip-file
-# type: ignore
-# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
-
-import os
-
-import torch
-from torch.autograd import Function
-from torch.nn import functional as F
-
-upfirdn2d_ext = None
-
-
-class UpFirDn2dBackward(Function):
- @staticmethod
- def forward(
- ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
- ):
- up_x, up_y = up
- down_x, down_y = down
- g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
-
- grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
-
- grad_input = upfirdn2d_ext.upfirdn2d(
- grad_output,
- grad_kernel,
- down_x,
- down_y,
- up_x,
- up_y,
- g_pad_x0,
- g_pad_x1,
- g_pad_y0,
- g_pad_y1,
- )
- grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
-
- ctx.save_for_backward(kernel)
-
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
- ctx.up_x = up_x
- ctx.up_y = up_y
- ctx.down_x = down_x
- ctx.down_y = down_y
- ctx.pad_x0 = pad_x0
- ctx.pad_x1 = pad_x1
- ctx.pad_y0 = pad_y0
- ctx.pad_y1 = pad_y1
- ctx.in_size = in_size
- ctx.out_size = out_size
-
- return grad_input
-
- @staticmethod
- def backward(ctx, gradgrad_input):
- (kernel,) = ctx.saved_tensors
-
- gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
-
- gradgrad_out = upfirdn2d_ext.upfirdn2d(
- gradgrad_input,
- kernel,
- ctx.up_x,
- ctx.up_y,
- ctx.down_x,
- ctx.down_y,
- ctx.pad_x0,
- ctx.pad_x1,
- ctx.pad_y0,
- ctx.pad_y1,
- )
- # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
- # ctx.out_size[1], ctx.in_size[3])
- gradgrad_out = gradgrad_out.view(
- ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
- )
-
- return gradgrad_out, None, None, None, None, None, None, None, None
-
-
-class UpFirDn2d(Function):
- @staticmethod
- def forward(ctx, input, kernel, up, down, pad):
- up_x, up_y = up
- down_x, down_y = down
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
- kernel_h, kernel_w = kernel.shape
- _, channel, in_h, in_w = input.shape
- ctx.in_size = input.shape
-
- input = input.reshape(-1, in_h, in_w, 1)
-
- ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
-
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
- ctx.out_size = (out_h, out_w)
-
- ctx.up = (up_x, up_y)
- ctx.down = (down_x, down_y)
- ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
-
- g_pad_x0 = kernel_w - pad_x0 - 1
- g_pad_y0 = kernel_h - pad_y0 - 1
- g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
- g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
-
- ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
-
- out = upfirdn2d_ext.upfirdn2d(
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
- )
- # out = out.view(major, out_h, out_w, minor)
- out = out.view(-1, channel, out_h, out_w)
-
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- kernel, grad_kernel = ctx.saved_tensors
-
- grad_input = UpFirDn2dBackward.apply(
- grad_output,
- kernel,
- grad_kernel,
- ctx.up,
- ctx.down,
- ctx.pad,
- ctx.g_pad,
- ctx.in_size,
- ctx.out_size,
- )
-
- return grad_input, None, None, None, None
-
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
- if input.device.type == "cpu":
- out = upfirdn2d_native(
- input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
- )
- else:
- out = UpFirDn2d.apply(
- input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
- )
-
- return out
-
-
-def upfirdn2d_native(
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-):
- _, channel, in_h, in_w = input.shape
- input = input.reshape(-1, in_h, in_w, 1)
-
- _, in_h, in_w, minor = input.shape
- kernel_h, kernel_w = kernel.shape
-
- out = input.view(-1, in_h, 1, in_w, 1, minor)
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
- out = F.pad(
- out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
- )
- out = out[
- :,
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
- :,
- ]
-
- out = out.permute(0, 3, 1, 2)
- out = out.reshape(
- [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
- )
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
- out = F.conv2d(out, w)
- out = out.reshape(
- -1,
- minor,
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
- )
- out = out.permute(0, 2, 3, 1)
- out = out[:, ::down_y, ::down_x, :]
-
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
- return out.view(-1, channel, out_h, out_w)
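A hedged sketch of the pure-PyTorch path of the deleted upfirdn2d helper, blur-downsampling a feature map with the (1, 3, 3, 1) kernel used throughout these architectures (import path assumed):

import torch
from upfirdn2d import upfirdn2d_native  # assumed import path

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]        # outer product: 1D kernel -> 2D smoothing kernel
kernel = kernel / kernel.sum()

x = torch.randn(1, 16, 64, 64)
# up=1, down=2 with padding (1, 2) on each axis so the output is exactly half the input resolution
out = upfirdn2d_native(x, kernel, 1, 1, 2, 2, 1, 2, 1, 2)
print(out.shape)  # torch.Size([1, 16, 32, 32]) -- smoothed and downsampled by 2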
diff --git a/comfy_extras/chainner_models/architecture/timm/LICENSE b/comfy_extras/chainner_models/architecture/timm/LICENSE
deleted file mode 100644
index b4e9438bd1e..00000000000
--- a/comfy_extras/chainner_models/architecture/timm/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "{}"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2019 Ross Wightman
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
\ No newline at end of file
diff --git a/comfy_extras/chainner_models/architecture/timm/drop.py b/comfy_extras/chainner_models/architecture/timm/drop.py
deleted file mode 100644
index 14f0da914b2..00000000000
--- a/comfy_extras/chainner_models/architecture/timm/drop.py
+++ /dev/null
@@ -1,223 +0,0 @@
-""" DropBlock, DropPath
-
-PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
-
-Papers:
-DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
-
-Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
-
-Code:
-DropBlock impl inspired by two Tensorflow impl that I liked:
- - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
- - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def drop_block_2d(
- x,
- drop_prob: float = 0.1,
- block_size: int = 7,
- gamma_scale: float = 1.0,
- with_noise: bool = False,
- inplace: bool = False,
- batchwise: bool = False,
-):
- """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
-
- DropBlock with an experimental Gaussian noise option. This layer has been tested on a few training
- runs with success, but needs further validation and possibly optimization for lower runtime impact.
- """
- _, C, H, W = x.shape
- total_size = W * H
- clipped_block_size = min(block_size, min(W, H))
- # seed_drop_rate, the gamma parameter
- gamma = (
- gamma_scale
- * drop_prob
- * total_size
- / clipped_block_size**2
- / ((W - block_size + 1) * (H - block_size + 1))
- )
-
- # Forces the block to be inside the feature map.
- w_i, h_i = torch.meshgrid(
- torch.arange(W).to(x.device), torch.arange(H).to(x.device)
- )
- valid_block = (
- (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)
- ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
- valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
-
- if batchwise:
- # one mask for whole batch, quite a bit faster
- uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
- else:
- uniform_noise = torch.rand_like(x)
- block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
- block_mask = -F.max_pool2d(
- -block_mask,
- kernel_size=clipped_block_size, # block_size,
- stride=1,
- padding=clipped_block_size // 2,
- )
-
- if with_noise:
- normal_noise = (
- torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
- if batchwise
- else torch.randn_like(x)
- )
- if inplace:
- x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
- else:
- x = x * block_mask + normal_noise * (1 - block_mask)
- else:
- normalize_scale = (
- block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
- ).to(x.dtype)
- if inplace:
- x.mul_(block_mask * normalize_scale)
- else:
- x = x * block_mask * normalize_scale
- return x
-
-
-def drop_block_fast_2d(
- x: torch.Tensor,
- drop_prob: float = 0.1,
- block_size: int = 7,
- gamma_scale: float = 1.0,
- with_noise: bool = False,
- inplace: bool = False,
-):
- """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
-
- DropBlock with an experimental Gaussian noise option. Simplified from the version above, without
- concern for a valid block mask at the edges.
- """
- _, _, H, W = x.shape
- total_size = W * H
- clipped_block_size = min(block_size, min(W, H))
- gamma = (
- gamma_scale
- * drop_prob
- * total_size
- / clipped_block_size**2
- / ((W - block_size + 1) * (H - block_size + 1))
- )
-
- block_mask = torch.empty_like(x).bernoulli_(gamma)
- block_mask = F.max_pool2d(
- block_mask.to(x.dtype),
- kernel_size=clipped_block_size,
- stride=1,
- padding=clipped_block_size // 2,
- )
-
- if with_noise:
- normal_noise = torch.empty_like(x).normal_()
- if inplace:
- x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
- else:
- x = x * (1.0 - block_mask) + normal_noise * block_mask
- else:
- block_mask = 1 - block_mask
- normalize_scale = (
- block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)
- ).to(dtype=x.dtype)
- if inplace:
- x.mul_(block_mask * normalize_scale)
- else:
- x = x * block_mask * normalize_scale
- return x
-
-
-class DropBlock2d(nn.Module):
- """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
-
- def __init__(
- self,
- drop_prob: float = 0.1,
- block_size: int = 7,
- gamma_scale: float = 1.0,
- with_noise: bool = False,
- inplace: bool = False,
- batchwise: bool = False,
- fast: bool = True,
- ):
- super(DropBlock2d, self).__init__()
- self.drop_prob = drop_prob
- self.gamma_scale = gamma_scale
- self.block_size = block_size
- self.with_noise = with_noise
- self.inplace = inplace
- self.batchwise = batchwise
- self.fast = fast # FIXME finish comparisons of fast vs not
-
- def forward(self, x):
- if not self.training or not self.drop_prob:
- return x
- if self.fast:
- return drop_block_fast_2d(
- x,
- self.drop_prob,
- self.block_size,
- self.gamma_scale,
- self.with_noise,
- self.inplace,
- )
- else:
- return drop_block_2d(
- x,
- self.drop_prob,
- self.block_size,
- self.gamma_scale,
- self.with_noise,
- self.inplace,
- self.batchwise,
- )
-
-
-def drop_path(
- x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
- This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
- the original name is misleading, as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mixing DropConnect as a layer name and using
- 'survival rate' as the argument.
-
- """
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (
- x.ndim - 1
- ) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
- def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
- def extra_repr(self):
- return f"drop_prob={round(self.drop_prob,3):0.3f}"
diff --git a/comfy_extras/chainner_models/architecture/timm/helpers.py b/comfy_extras/chainner_models/architecture/timm/helpers.py
deleted file mode 100644
index cdafee07091..00000000000
--- a/comfy_extras/chainner_models/architecture/timm/helpers.py
+++ /dev/null
@@ -1,31 +0,0 @@
-""" Layer/Module Helpers
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-import collections.abc
-from itertools import repeat
-
-
-# From PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return x
- return tuple(repeat(x, n))
-
- return parse
-
-
-to_1tuple = _ntuple(1)
-to_2tuple = _ntuple(2)
-to_3tuple = _ntuple(3)
-to_4tuple = _ntuple(4)
-to_ntuple = _ntuple
-
-
-def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
- min_value = min_value or divisor
- new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
- # Make sure that round down does not go down by more than 10%.
- if new_v < round_limit * v:
- new_v += divisor
- return new_v
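The deleted helpers are small utilities; a quick sketch of their behaviour (import path assumed):

from helpers import to_2tuple, make_divisible  # assumed import path

print(to_2tuple(7))        # (7, 7)  -- scalars are broadcast to a pair
print(to_2tuple((3, 5)))   # (3, 5)  -- non-string iterables pass through unchanged
print(make_divisible(37))  # 40      -- rounded to a multiple of 8, never dropping below 90% of the input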
diff --git a/comfy_extras/chainner_models/architecture/timm/weight_init.py b/comfy_extras/chainner_models/architecture/timm/weight_init.py
deleted file mode 100644
index b0169774657..00000000000
--- a/comfy_extras/chainner_models/architecture/timm/weight_init.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import math
-import warnings
-
-import torch
-from torch.nn.init import _calculate_fan_in_and_fan_out
-
-
-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- def norm_cdf(x):
- # Computes standard normal cumulative distribution function
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
-
- if (mean < a - 2 * std) or (mean > b + 2 * std):
- warnings.warn(
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
- "The distribution of values may be incorrect.",
- stacklevel=2,
- )
-
- with torch.no_grad():
- # Values are generated by using a truncated uniform distribution and
- # then using the inverse CDF for the normal distribution.
- # Get upper and lower cdf values
- l = norm_cdf((a - mean) / std)
- u = norm_cdf((b - mean) / std)
-
- # Uniformly fill tensor with values from [l, u], then translate to
- # [2l-1, 2u-1].
- tensor.uniform_(2 * l - 1, 2 * u - 1)
-
- # Use inverse cdf transform for normal distribution to get truncated
- # standard normal
- tensor.erfinv_()
-
- # Transform to proper mean, std
- tensor.mul_(std * math.sqrt(2.0))
- tensor.add_(mean)
-
- # Clamp to ensure it's in the proper range
- tensor.clamp_(min=a, max=b)
- return tensor
-
-
-def trunc_normal_(
- tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
-) -> torch.Tensor:
- r"""Fills the input Tensor with values drawn from a truncated
- normal distribution. The values are effectively drawn from the
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
- with values outside :math:`[a, b]` redrawn until they are within
- the bounds. The method used for generating the random values works
- best when :math:`a \leq \text{mean} \leq b`.
-
- NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
- applied while sampling the normal with mean/std applied, therefore a, b args
- should be adjusted to match the range of mean, std args.
-
- Args:
- tensor: an n-dimensional `torch.Tensor`
- mean: the mean of the normal distribution
- std: the standard deviation of the normal distribution
- a: the minimum cutoff value
- b: the maximum cutoff value
- Examples:
- >>> w = torch.empty(3, 5)
- >>> nn.init.trunc_normal_(w)
- """
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
-
-
-def trunc_normal_tf_(
- tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
-) -> torch.Tensor:
- r"""Fills the input Tensor with values drawn from a truncated
- normal distribution. The values are effectively drawn from the
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
- with values outside :math:`[a, b]` redrawn until they are within
- the bounds. The method used for generating the random values works
- best when :math:`a \leq \text{mean} \leq b`.
-
- NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
- bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
- and the result is subsquently scaled and shifted by the mean and std args.
-
- Args:
- tensor: an n-dimensional `torch.Tensor`
- mean: the mean of the normal distribution
- std: the standard deviation of the normal distribution
- a: the minimum cutoff value
- b: the maximum cutoff value
- Examples:
- >>> w = torch.empty(3, 5)
- >>> nn.init.trunc_normal_(w)
- """
- _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
- with torch.no_grad():
- tensor.mul_(std).add_(mean)
- return tensor
-
-
-def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
- if mode == "fan_in":
- denom = fan_in
- elif mode == "fan_out":
- denom = fan_out
- elif mode == "fan_avg":
- denom = (fan_in + fan_out) / 2
-
- variance = scale / denom # type: ignore
-
- if distribution == "truncated_normal":
- # constant is stddev of standard normal truncated to (-2, 2)
- trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
- elif distribution == "normal":
- tensor.normal_(std=math.sqrt(variance))
- elif distribution == "uniform":
- bound = math.sqrt(3 * variance)
- # pylint: disable=invalid-unary-operand-type
- tensor.uniform_(-bound, bound)
- else:
- raise ValueError(f"invalid distribution {distribution}")
-
-
-def lecun_normal_(tensor):
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py
index e000871c1bf..d48bc238ccc 100644
--- a/comfy_extras/chainner_models/model_loading.py
+++ b/comfy_extras/chainner_models/model_loading.py
@@ -1,99 +1,5 @@
-import logging as logger
+from spandrel import ModelLoader
-from .architecture.DAT import DAT
-from .architecture.face.codeformer import CodeFormer
-from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
-from .architecture.face.restoreformer_arch import RestoreFormer
-from .architecture.HAT import HAT
-from .architecture.LaMa import LaMa
-from .architecture.OmniSR.OmniSR import OmniSR
-from .architecture.RRDB import RRDBNet as ESRGAN
-from .architecture.SCUNet import SCUNet
-from .architecture.SPSR import SPSRNet as SPSR
-from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
-from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
-from .architecture.Swin2SR import Swin2SR
-from .architecture.SwinIR import SwinIR
-from .types import PyTorchModel
-
-
-class UnsupportedModel(Exception):
- pass
-
-
-def load_state_dict(state_dict) -> PyTorchModel:
- logger.debug(f"Loading state dict into pytorch model arch")
-
- state_dict_keys = list(state_dict.keys())
-
- if "params_ema" in state_dict_keys:
- state_dict = state_dict["params_ema"]
- elif "params-ema" in state_dict_keys:
- state_dict = state_dict["params-ema"]
- elif "params" in state_dict_keys:
- state_dict = state_dict["params"]
-
- state_dict_keys = list(state_dict.keys())
- # SRVGGNet Real-ESRGAN (v2)
- if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
- model = RealESRGANv2(state_dict)
- # SPSR (ESRGAN with lots of extra layers)
- elif "f_HR_conv1.0.weight" in state_dict:
- model = SPSR(state_dict)
- # Swift-SRGAN
- elif (
- "model" in state_dict_keys
- and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
- ):
- model = SwiftSRGAN(state_dict)
- # SwinIR, Swin2SR, HAT
- elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
- if (
- "layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
- in state_dict_keys
- ):
- model = HAT(state_dict)
- elif "patch_embed.proj.weight" in state_dict_keys:
- model = Swin2SR(state_dict)
- else:
- model = SwinIR(state_dict)
- # GFPGAN
- elif (
- "toRGB.0.weight" in state_dict_keys
- and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
- ):
- model = GFPGANv1Clean(state_dict)
- # RestoreFormer
- elif (
- "encoder.conv_in.weight" in state_dict_keys
- and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
- ):
- model = RestoreFormer(state_dict)
- elif (
- "encoder.blocks.0.weight" in state_dict_keys
- and "quantize.embedding.weight" in state_dict_keys
- ):
- model = CodeFormer(state_dict)
- # LaMa
- elif (
- "model.model.1.bn_l.running_mean" in state_dict_keys
- or "generator.model.1.bn_l.running_mean" in state_dict_keys
- ):
- model = LaMa(state_dict)
- # Omni-SR
- elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
- model = OmniSR(state_dict)
- # SCUNet
- elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
- model = SCUNet(state_dict)
- # DAT
- elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
- model = DAT(state_dict)
- # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
- else:
- try:
- model = ESRGAN(state_dict)
- except:
- # pylint: disable=raise-missing-from
- raise UnsupportedModel
- return model
+def load_state_dict(state_dict):
+ print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
+ return ModelLoader().load_from_state_dict(state_dict).eval()
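For anyone still calling the old chainner_models loader directly, here is a minimal sketch of the equivalent spandrel flow. The checkpoint path is hypothetical; `load_torch_file`, the `module.` prefix cleanup, and the `ImageModelDescriptor` check mirror the `UpscaleModelLoader` changes later in this diff.

```python
import comfy.utils
from spandrel import ModelLoader, ImageModelDescriptor

# Hypothetical checkpoint path under the usual models folder.
sd = comfy.utils.load_torch_file("models/upscale_models/4x_example.safetensors", safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
    sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})

model = ModelLoader().load_from_state_dict(sd).eval()
if not isinstance(model, ImageModelDescriptor):
    raise Exception("Upscale model must be a single-image model.")
print(type(model.model).__name__, model.scale)
```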
diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py
deleted file mode 100644
index 193333b9e80..00000000000
--- a/comfy_extras/chainner_models/types.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from typing import Union
-
-from .architecture.DAT import DAT
-from .architecture.face.codeformer import CodeFormer
-from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
-from .architecture.face.restoreformer_arch import RestoreFormer
-from .architecture.HAT import HAT
-from .architecture.LaMa import LaMa
-from .architecture.OmniSR.OmniSR import OmniSR
-from .architecture.RRDB import RRDBNet as ESRGAN
-from .architecture.SCUNet import SCUNet
-from .architecture.SPSR import SPSRNet as SPSR
-from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
-from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
-from .architecture.Swin2SR import Swin2SR
-from .architecture.SwinIR import SwinIR
-
-PyTorchSRModels = (
- RealESRGANv2,
- SPSR,
- SwiftSRGAN,
- ESRGAN,
- SwinIR,
- Swin2SR,
- HAT,
- OmniSR,
- SCUNet,
- DAT,
-)
-PyTorchSRModel = Union[
- RealESRGANv2,
- SPSR,
- SwiftSRGAN,
- ESRGAN,
- SwinIR,
- Swin2SR,
- HAT,
- OmniSR,
- SCUNet,
- DAT,
-]
-
-
-def is_pytorch_sr_model(model: object):
- return isinstance(model, PyTorchSRModels)
-
-
-PyTorchFaceModels = (GFPGANv1Clean, RestoreFormer, CodeFormer)
-PyTorchFaceModel = Union[GFPGANv1Clean, RestoreFormer, CodeFormer]
-
-
-def is_pytorch_face_model(model: object):
- return isinstance(model, PyTorchFaceModels)
-
-
-PyTorchInpaintModels = (LaMa,)
-PyTorchInpaintModel = Union[LaMa]
-
-
-def is_pytorch_inpaint_model(model: object):
- return isinstance(model, PyTorchInpaintModels)
-
-
-PyTorchModels = (*PyTorchSRModels, *PyTorchFaceModels, *PyTorchInpaintModels)
-PyTorchModel = Union[PyTorchSRModel, PyTorchFaceModel, PyTorchInpaintModel]
-
-
-def is_pytorch_model(model: object):
- return isinstance(model, PyTorchModels)
diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py
new file mode 100644
index 00000000000..d973def816b
--- /dev/null
+++ b/comfy_extras/nodes_advanced_samplers.py
@@ -0,0 +1,61 @@
+import comfy.samplers
+import comfy.utils
+import torch
+import numpy as np
+from tqdm.auto import trange, tqdm
+import math
+
+
+@torch.no_grad()
+def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None):
+ extra_args = {} if extra_args is None else extra_args
+
+ if upscale_steps is None:
+ upscale_steps = max(len(sigmas) // 2 + 1, 2)
+ else:
+ upscale_steps += 1
+ upscale_steps = min(upscale_steps, len(sigmas) + 1)
+
+ upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
+
+ orig_shape = x.size()
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+ x = denoised
+ if i < len(upscales):
+ x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
+
+ if sigmas[i + 1] > 0:
+ x += sigmas[i + 1] * torch.randn_like(x)
+ return x
+
+
+class SamplerLCMUpscale:
+ upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
+ "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
+ "upscale_method": (s.upscale_methods,),
+ }
+ }
+ RETURN_TYPES = ("SAMPLER",)
+ CATEGORY = "sampling/custom_sampling/samplers"
+
+ FUNCTION = "get_sampler"
+
+ def get_sampler(self, scale_ratio, scale_steps, upscale_method):
+ if scale_steps < 0:
+ scale_steps = None
+ sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
+ return (sampler, )
+
+NODE_CLASS_MAPPINGS = {
+ "SamplerLCMUpscale": SamplerLCMUpscale,
+}
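A small standalone sketch of how the progressive upscale schedule above unfolds, assuming a hypothetical 512x512 starting resolution. As in `sample_lcm_upscale`, each target size is computed relative to the original shape rather than the previous step.

```python
import numpy as np

def lcm_upscale_schedule(num_sigmas, total_upscale=2.0, upscale_steps=None, base_hw=(512, 512)):
    """Mirror the schedule logic of sample_lcm_upscale above (sketch only)."""
    if upscale_steps is None:
        upscale_steps = max(num_sigmas // 2 + 1, 2)
    else:
        upscale_steps += 1
    upscale_steps = min(upscale_steps, num_sigmas + 1)

    # Upscale factors applied after steps 0 .. len(upscales) - 1.
    upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]

    h, w = base_hw
    sizes = []
    for i in range(num_sigmas - 1):  # one denoise step per consecutive sigma pair
        if i < len(upscales):
            sizes.append((round(h * upscales[i]), round(w * upscales[i])))
        else:
            sizes.append(sizes[-1] if sizes else (h, w))
    return sizes

# e.g. an 8-sigma LCM schedule upscaling to 2x:
print(lcm_upscale_schedule(8, total_upscale=2.0))
```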
diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py
new file mode 100644
index 00000000000..3ffe5318785
--- /dev/null
+++ b/comfy_extras/nodes_align_your_steps.py
@@ -0,0 +1,53 @@
+#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
+import numpy as np
+import torch
+
+def loglinear_interp(t_steps, num_steps):
+ """
+ Performs log-linear interpolation of a given array of decreasing numbers.
+ """
+ xs = np.linspace(0, 1, len(t_steps))
+ ys = np.log(t_steps[::-1])
+
+ new_xs = np.linspace(0, 1, num_steps)
+ new_ys = np.interp(new_xs, xs, ys)
+
+ interped_ys = np.exp(new_ys)[::-1].copy()
+ return interped_ys
+
+NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
+ "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
+ "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
+
+class AlignYourStepsScheduler:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"model_type": (["SD1", "SDXL", "SVD"], ),
+ "steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS",)
+ CATEGORY = "sampling/custom_sampling/schedulers"
+
+ FUNCTION = "get_sigmas"
+
+ def get_sigmas(self, model_type, steps, denoise):
+ total_steps = steps
+ if denoise < 1.0:
+ if denoise <= 0.0:
+ return (torch.FloatTensor([]),)
+ total_steps = round(steps * denoise)
+
+ sigmas = NOISE_LEVELS[model_type][:]
+ if (steps + 1) != len(sigmas):
+ sigmas = loglinear_interp(sigmas, steps + 1)
+
+ sigmas = sigmas[-(total_steps + 1):]
+ sigmas[-1] = 0
+ return (torch.FloatTensor(sigmas), )
+
+NODE_CLASS_MAPPINGS = {
+ "AlignYourStepsScheduler": AlignYourStepsScheduler,
+}
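To see what the scheduler produces for step counts other than 10, here is a runnable sketch; it re-implements `loglinear_interp` so it is self-contained, copies the SD1 row of the table above, and assumes denoise is 1.0 so no trailing slice is taken.

```python
import numpy as np

# SD1 row copied from the NOISE_LEVELS table above.
SD1_SIGMAS = [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520,
              1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016,
              0.3977456272, 0.1515232662, 0.0291671582]

def loglinear_interp(t_steps, num_steps):
    """Log-linear interpolation of a decreasing sigma list (same as the node above)."""
    xs = np.linspace(0, 1, len(t_steps))
    ys = np.log(t_steps[::-1])
    new_ys = np.interp(np.linspace(0, 1, num_steps), xs, ys)
    return np.exp(new_ys)[::-1].copy()

# A 20-step schedule needs 21 sigma values; the scheduler then zeroes the final entry.
sigmas = list(loglinear_interp(SD1_SIGMAS, 21))
sigmas[-1] = 0.0
print(len(sigmas), round(sigmas[0], 4), sigmas[-1])  # 21 14.6146 0.0
```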
diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py
new file mode 100644
index 00000000000..4747eb39568
--- /dev/null
+++ b/comfy_extras/nodes_attention_multiply.py
@@ -0,0 +1,120 @@
+
+def attention_multiply(attn, model, q, k, v, out):
+ m = model.clone()
+ sd = model.model_state_dict()
+
+ for key in sd:
+ if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
+ m.add_patches({key: (None,)}, 0.0, q)
+ if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
+ m.add_patches({key: (None,)}, 0.0, k)
+ if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
+ m.add_patches({key: (None,)}, 0.0, v)
+ if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
+ m.add_patches({key: (None,)}, 0.0, out)
+
+ return m
+
+
+class UNetSelfAttentionMultiply:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing/attention_experiments"
+
+ def patch(self, model, q, k, v, out):
+ m = attention_multiply("attn1", model, q, k, v, out)
+ return (m, )
+
+class UNetCrossAttentionMultiply:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing/attention_experiments"
+
+ def patch(self, model, q, k, v, out):
+ m = attention_multiply("attn2", model, q, k, v, out)
+ return (m, )
+
+class CLIPAttentionMultiply:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "clip": ("CLIP",),
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("CLIP",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing/attention_experiments"
+
+ def patch(self, clip, q, k, v, out):
+ m = clip.clone()
+ sd = m.patcher.model_state_dict()
+
+ for key in sd:
+ if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
+ m.add_patches({key: (None,)}, 0.0, q)
+ if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
+ m.add_patches({key: (None,)}, 0.0, k)
+ if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
+ m.add_patches({key: (None,)}, 0.0, v)
+ if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
+ m.add_patches({key: (None,)}, 0.0, out)
+ return (m, )
+
+class UNetTemporalAttentionMultiply:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing/attention_experiments"
+
+ def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
+ m = model.clone()
+ sd = model.model_state_dict()
+
+ for k in sd:
+ if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
+ if '.time_stack.' in k:
+ m.add_patches({k: (None,)}, 0.0, self_temporal)
+ else:
+ m.add_patches({k: (None,)}, 0.0, self_structural)
+ elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
+ if '.time_stack.' in k:
+ m.add_patches({k: (None,)}, 0.0, cross_temporal)
+ else:
+ m.add_patches({k: (None,)}, 0.0, cross_structural)
+ return (m, )
+
+NODE_CLASS_MAPPINGS = {
+ "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
+ "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
+ "CLIPAttentionMultiply": CLIPAttentionMultiply,
+ "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
+}
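The four patch classes above all reduce to the same suffix test on state-dict keys. A toy sketch of which keys `attention_multiply` would touch for self- vs. cross-attention; the key names are illustrative, not taken from a real checkpoint.

```python
# Toy key list in the shape of a diffusion UNet state dict (illustrative names only).
keys = [
    "diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
    "diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight",
    "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight",
]

def matches(attn, key):
    """Same endswith() test that attention_multiply applies per key."""
    suffixes = ("to_q.bias", "to_q.weight", "to_k.bias", "to_k.weight",
                "to_v.bias", "to_v.weight", "to_out.0.bias", "to_out.0.weight")
    return any(key.endswith("{}.{}".format(attn, s)) for s in suffixes)

print([k for k in keys if matches("attn1", k)])  # self-attention projections -> q and out
print([k for k in keys if matches("attn2", k)])  # cross-attention projections -> k
```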
diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py
index fab2ab7ac73..d85e6b85691 100644
--- a/comfy_extras/nodes_canny.py
+++ b/comfy_extras/nodes_canny.py
@@ -1,10 +1,5 @@
-import math
-
-import torch
-import torch.nn.functional as F
-import comfy.model_management
-
from kornia.filters import canny
+import comfy.model_management
class Canny:
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index 06238f89222..47f08bf60d9 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -39,8 +39,8 @@ class KarrasScheduler:
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
@@ -58,8 +58,8 @@ class ExponentialScheduler:
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
@@ -76,8 +76,8 @@ class PolyexponentialScheduler:
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
@@ -107,8 +107,7 @@ def INPUT_TYPES(s):
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
- comfy.model_management.load_models_gpu([model])
- sigmas = model.model.model_sampling.sigma(timesteps)
+ sigmas = model.get_model_object("model_sampling").sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
@@ -117,8 +116,8 @@ class VPScheduler:
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values
- "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
+ "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values
+ "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
}
}
@@ -140,6 +139,7 @@ def INPUT_TYPES(s):
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
+ RETURN_NAMES = ("high_sigmas", "low_sigmas")
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
@@ -149,6 +149,27 @@ def get_sigmas(self, sigmas, step):
sigmas2 = sigmas[step:]
return (sigmas1, sigmas2)
+class SplitSigmasDenoise:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"sigmas": ("SIGMAS", ),
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS","SIGMAS")
+ RETURN_NAMES = ("high_sigmas", "low_sigmas")
+ CATEGORY = "sampling/custom_sampling/sigmas"
+
+ FUNCTION = "get_sigmas"
+
+ def get_sigmas(self, sigmas, denoise):
+ steps = max(sigmas.shape[-1] - 1, 0)
+ total_steps = round(steps * denoise)
+ sigmas1 = sigmas[:-(total_steps)]
+ sigmas2 = sigmas[-(total_steps + 1):]
+ return (sigmas1, sigmas2)
+
class FlipSigmas:
@classmethod
def INPUT_TYPES(s):
@@ -600,6 +621,7 @@ def add_noise(self, model, noise, sigmas, latent_image):
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
+ "SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"CFGGuider": CFGGuider,
diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py
index 6f1d87bf354..c5ebcf26fd6 100644
--- a/comfy_extras/nodes_freelunch.py
+++ b/comfy_extras/nodes_freelunch.py
@@ -42,7 +42,7 @@ def patch(self, model, b1, b2, s1, s2):
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
- scale = scale_dict.get(h.shape[1], None)
+ scale = scale_dict.get(int(h.shape[1]), None)
if scale is not None:
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
if hsp.device not in on_cpu_devices:
@@ -81,7 +81,7 @@ def patch(self, model, b1, b2, s1, s2):
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
- scale = scale_dict.get(h.shape[1], None)
+ scale = scale_dict.get(int(h.shape[1]), None)
if scale is not None:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py
index 48bcc689273..58b5073ec08 100644
--- a/comfy_extras/nodes_model_downscale.py
+++ b/comfy_extras/nodes_model_downscale.py
@@ -20,8 +20,9 @@ def INPUT_TYPES(s):
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
- sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
- sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
+ model_sampling = model.get_model_object("model_sampling")
+ sigma_start = model_sampling.percent_to_sigma(start_percent)
+ sigma_end = model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 2a431f65da9..bb15112f4e9 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -175,9 +175,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
- metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
+ if isinstance(model.model, comfy.model_base.SDXL_instructpix2pix):
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
+ else:
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
+ elif isinstance(model.model, comfy.model_base.SVD_img2vid):
+ metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
else:
enable_modelspec = False
@@ -262,7 +267,7 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
- comfy.model_management.load_models_gpu([clip.load_model()])
+ comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py
index 306cf9cd0f6..546276aa154 100644
--- a/comfy_extras/nodes_perpneg.py
+++ b/comfy_extras/nodes_perpneg.py
@@ -61,12 +61,38 @@ def set_cfg(self, cfg, neg_scale):
self.neg_scale = neg_scale
def predict_noise(self, x, timestep, model_options={}, seed=None):
+ # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
+ # but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
+
positive_cond = self.conds.get("positive", None)
negative_cond = self.conds.get("negative", None)
empty_cond = self.conds.get("empty_negative_prompt", None)
- out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, positive_cond, empty_cond], x, timestep, model_options)
- return perp_neg(x, out[1], out[0], out[2], self.neg_scale, self.cfg)
+ (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
+ comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
+ cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
+
+ # normally this would be done in cfg_function, but we skipped
+ # that for efficiency: we can compute the noise predictions in
+ # a single call to calc_cond_batch() (rather than two)
+ # so we replicate the hook here
+ for fn in model_options.get("sampler_post_cfg_function", []):
+ args = {
+ "denoised": cfg_result,
+ "cond": positive_cond,
+ "uncond": negative_cond,
+ "model": self.inner_model,
+ "uncond_denoised": noise_pred_neg,
+ "cond_denoised": noise_pred_pos,
+ "sigma": timestep,
+ "model_options": model_options,
+ "input": x,
+ # not in the original call in samplers.py:cfg_function, but made available for future hooks
+ "empty_cond": empty_cond,
+ "empty_cond_denoised": noise_pred_empty,}
+ cfg_result = fn(args)
+
+ return cfg_result
class PerpNegGuider:
@classmethod
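Since the guider now replays the `sampler_post_cfg_function` hooks itself, here is a minimal sketch of what such a hook looks like from the caller's side. The nudge factor and the direct `model_options` mutation are illustrative assumptions, not part of this diff.

```python
def my_post_cfg_hook(args):
    # args is the dict assembled in predict_noise above; "denoised" holds the current CFG result.
    denoised = args["denoised"]
    cond_denoised = args["cond_denoised"]
    # Toy adjustment: nudge the result slightly toward the positive prediction.
    return denoised + 0.05 * (cond_denoised - denoised)

# Hypothetical registration: append to the list the guider iterates over.
model_options = {}
model_options.setdefault("sampler_post_cfg_function", []).append(my_post_cfg_hook)
```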
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 0110b472f37..68f6ef51e79 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -5,6 +5,7 @@
import math
import comfy.utils
+import comfy.model_management
class Blend:
@@ -102,6 +103,7 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
+ image = image.to(comfy.model_management.get_torch_device())
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
@@ -112,7 +114,7 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1)
- return (blurred,)
+ return (blurred.to(comfy.model_management.intermediate_device()),)
class Quantize:
def __init__(self):
@@ -225,6 +227,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
return (image,)
batch_size, height, width, channels = image.shape
+ image = image.to(comfy.model_management.get_torch_device())
kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
@@ -239,7 +242,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
result = torch.clamp(sharpened, 0, 1)
- return (result,)
+ return (result.to(comfy.model_management.intermediate_device()),)
class ImageScaleToTotalPixels:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py
index 69084e91db1..010e9974496 100644
--- a/comfy_extras/nodes_sag.py
+++ b/comfy_extras/nodes_sag.py
@@ -4,13 +4,12 @@
import math
from einops import rearrange, repeat
-import os
-from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
+from comfy.ldm.modules.attention import optimized_attention
import comfy.samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
-def attention_basic_with_sim(q, k, v, heads, mask=None):
+def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@@ -26,7 +25,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
)
# force cast to fp32 to avoid overflowing
- if _ATTN_PRECISION =="fp32":
+ if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
@@ -121,13 +120,13 @@ def attn_and_record(q, k, v, extra_options):
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
- (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
+ (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
- return optimized_attention(q, k, v, heads=heads)
+ return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
def post_cfg_function(args):
nonlocal attn_scores
diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py
index 28c1cb0f171..bba67e8ddff 100644
--- a/comfy_extras/nodes_sdupscale.py
+++ b/comfy_extras/nodes_sdupscale.py
@@ -1,5 +1,4 @@
import torch
-import nodes
import comfy.utils
class SD_4XUpscale_Conditioning:
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index 2b5e49a55c2..bca79ef2e13 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -1,10 +1,19 @@
import os
-from comfy_extras.chainner_models import model_loading
+import logging
+from spandrel import ModelLoader, ImageModelDescriptor
from comfy import model_management
import torch
import comfy.utils
import folder_paths
+try:
+ from spandrel_extra_arches import EXTRA_REGISTRY
+ from spandrel import MAIN_REGISTRY
+ MAIN_REGISTRY.add(*EXTRA_REGISTRY)
+ logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
+except:
+ pass
+
class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
@@ -20,7 +29,11 @@ def load_model(self, model_name):
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
- out = model_loading.load_state_dict(sd).eval()
+ out = ModelLoader().load_from_state_dict(sd).eval()
+
+ if not isinstance(out, ImageModelDescriptor):
+ raise Exception("Upscale model must be a single-image model.")
+
return (out, )
@@ -37,9 +50,14 @@ def INPUT_TYPES(s):
def upscale(self, upscale_model, image):
device = model_management.get_torch_device()
+
+ memory_required = model_management.module_size(upscale_model.model)
+ memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is a rough estimate of how much memory some of these models use, TODO: make it more accurate
+ memory_required += image.nelement() * image.element_size()
+ model_management.free_memory(memory_required, device)
+
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
- free_memory = model_management.get_free_memory(device)
tile = 512
overlap = 32
@@ -56,7 +74,7 @@ def upscale(self, upscale_model, image):
if tile < 128:
raise e
- upscale_model.cpu()
+ upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)
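A back-of-the-envelope reading of the new `memory_required` heuristic, using hypothetical numbers: 64 MB of weights, a 4x model, and a 1024x1024 RGB float32 input.

```python
def estimate_upscale_memory(model_bytes, scale, img_h, img_w, channels=3, element_size=4):
    """Reproduce the memory_required heuristic from ImageUpscaleWithModel.upscale above."""
    mem = model_bytes                                                # model weights on the device
    mem += (512 * 512 * 3) * element_size * max(scale, 1.0) * 384.0  # rough per-model activation estimate
    mem += img_h * img_w * channels * element_size                   # the input image tensor itself
    return mem

# Hypothetical: 64 MB of weights, a 4x model, and a 1024x1024 RGB float32 input.
gib = estimate_upscale_memory(64 * 1024**2, 4.0, 1024, 1024) / 1024**3
print(f"~{gib:.2f} GiB requested from free_memory()")  # roughly 4.6 GiB
```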
diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py
new file mode 100644
index 00000000000..32a0ba2f67b
--- /dev/null
+++ b/comfy_extras/nodes_webcam.py
@@ -0,0 +1,33 @@
+import nodes
+import folder_paths
+
+MAX_RESOLUTION = nodes.MAX_RESOLUTION
+
+
+class WebcamCapture(nodes.LoadImage):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("WEBCAM", {}),
+ "width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+ "height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+ "capture_on_queue": ("BOOLEAN", {"default": True}),
+ }
+ }
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "load_capture"
+
+ CATEGORY = "image"
+
+ def load_capture(s, image, **kwargs):
+ return super().load_image(folder_paths.get_annotated_filepath(image))
+
+
+NODE_CLASS_MAPPINGS = {
+ "WebcamCapture": WebcamCapture,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "WebcamCapture": "Webcam Capture",
+}
\ No newline at end of file
diff --git a/execution.py b/execution.py
index f269936f257..c0379222b69 100644
--- a/execution.py
+++ b/execution.py
@@ -200,7 +200,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
- logging.error("!!! Exception during processing !!!")
+ logging.error(f"!!! Exception during processing!!! {ex}")
logging.error(traceback.format_exc())
error_details = {
@@ -646,8 +646,27 @@ def full_type_name(klass):
def validate_prompt(prompt):
outputs = set()
for x in prompt:
- class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
- if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
+ if 'class_type' not in prompt[x]:
+ error = {
+ "type": "invalid_prompt",
+ "message": f"Cannot execute because a node is missing the class_type property.",
+ "details": f"Node ID '#{x}'",
+ "extra_info": {}
+ }
+ return (False, error, [], [])
+
+ class_type = prompt[x]['class_type']
+ class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
+ if class_ is None:
+ error = {
+ "type": "invalid_prompt",
+ "message": f"Cannot execute because node {class_type} does not exist.",
+ "details": f"Node ID '#{x}'",
+ "extra_info": {}
+ }
+ return (False, error, [], [])
+
+ if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if len(outputs) == 0:
diff --git a/folder_paths.py b/folder_paths.py
index 3f9a32b4d59..234b734095e 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -2,7 +2,7 @@
import time
import logging
-supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
+supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
folder_names_and_paths = {}
@@ -258,7 +258,7 @@ def compute_vars(input, image_width, image_height):
raise Exception(err)
try:
- counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
+ counter = max(filter(lambda a: os.path.normcase(a[1][:-1]) == os.path.normcase(filename) and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
diff --git a/latent_preview.py b/latent_preview.py
index e5107ea3b9b..75fcedd4824 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -4,6 +4,7 @@
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
+import comfy.model_management
import folder_paths
import comfy.utils
from comfy import model_management
@@ -11,6 +12,13 @@
MAX_PREVIEW_RESOLUTION = 512
+def preview_to_image(latent_image):
+ latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
+ .mul(0xFF) # to 0..255
+ ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
+
+ return Image.fromarray(latents_ubyte.numpy())
+
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
@@ -25,13 +33,8 @@ def __init__(self, taesd, device):
self.device = device
def decode_latent_to_preview(self, x0):
- x_sample = self.taesd.decode(x0[:1].to(self.device))[0].detach()
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
-
- preview_image = Image.fromarray(x_sample)
- return preview_image
+ x_sample = self.taesd.decode(x0[:1].to(self.device))[0].movedim(0, 2)
+ return preview_to_image(x_sample)
class Latent2RGBPreviewer(LatentPreviewer):
@@ -39,14 +42,9 @@ def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def decode_latent_to_preview(self, x0):
- latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
-
- latents_ubyte = (((latent_image + 1) / 2)
- .clamp(0, 1) # change scale from -1..1 to 0..1
- .mul(0xFF) # to 0..255
- .byte()).cpu()
-
- return Image.fromarray(latents_ubyte.numpy())
+ self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
+ latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
+ return preview_to_image(latent_image)
def get_previewer(device, latent_format):
@@ -67,8 +65,6 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
- if taesd_decoder_path:
- method = LatentPreviewMethod.TAESD
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
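A runnable sketch of the refactored Latent2RGB path: project the latent with the RGB factors and push it through `preview_to_image`. The 4-channel latent and the projection matrix here are random placeholders (the real per-model `latent_rgb_factors` come from the latent_format definitions, which this diff does not touch), and the `non_blocking` transfer detail is omitted.

```python
import torch
from PIL import Image

def preview_to_image(latent_image):
    """Same -1..1 -> uint8 conversion as preview_to_image above (non_blocking detail omitted)."""
    latents_ubyte = ((latent_image + 1.0) / 2.0).clamp(0, 1).mul(0xFF).to(device="cpu", dtype=torch.uint8)
    return Image.fromarray(latents_ubyte.numpy())

# Placeholder 4-channel 64x64 latent and a random projection matrix.
x0 = torch.randn(1, 4, 64, 64)
latent_rgb_factors = torch.randn(4, 3) * 0.3
latent_image = x0[0].permute(1, 2, 0) @ latent_rgb_factors   # (64, 64, 3)
preview = preview_to_image(latent_image)
print(preview.size, preview.mode)  # (64, 64) RGB
```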
diff --git a/main.py b/main.py
index b3a3ebea8cf..a374f2b124a 100644
--- a/main.py
+++ b/main.py
@@ -243,11 +243,11 @@ def load_extra_path_config(yaml_path):
call_on_start = None
if args.auto_launch:
- def startup_server(address, port):
+ def startup_server(scheme, address, port):
import webbrowser
if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1'
- webbrowser.open(f"http://{address}:{port}")
+ webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server
try:
diff --git a/node_helpers.py b/node_helpers.py
index 8828a4ec9d0..43b9e829f59 100644
--- a/node_helpers.py
+++ b/node_helpers.py
@@ -1,3 +1,4 @@
+from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}):
c = []
@@ -8,3 +9,16 @@ def conditioning_set_values(conditioning, values={}):
c.append(n)
return c
+
+def pillow(fn, arg):
+ prev_value = None
+ try:
+ x = fn(arg)
+ except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
+ prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ x = fn(arg)
+ finally:
+ if prev_value is not None:
+ ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
+ return x
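Usage follows the same pattern the LoadImage nodes switch to later in this diff; a short sketch (the file path is hypothetical, and this assumes it runs from the ComfyUI root so `node_helpers` is importable):

```python
from PIL import Image, ImageOps
import node_helpers

# Open and EXIF-rotate an image; node_helpers.pillow() retries with
# ImageFile.LOAD_TRUNCATED_IMAGES enabled only if PIL first raises one of the handled errors.
img = node_helpers.pillow(Image.open, "input/example.png")
img = node_helpers.pillow(ImageOps.exif_transpose, img)
print(img.size, img.mode)
```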
diff --git a/nodes.py b/nodes.py
index ea1e32030e0..34821ca3f5b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -10,14 +10,14 @@
import random
import logging
-from PIL import Image, ImageOps, ImageSequence
+from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL.PngImagePlugin import PngInfo
+
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
-
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
@@ -583,8 +583,8 @@ def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": (folder_paths.get_filename_list("loras"), ),
- "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
- "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL", "CLIP")
FUNCTION = "load_lora"
@@ -617,7 +617,7 @@ class LoraLoaderModelOnly(LoraLoader):
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
- "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
@@ -1456,14 +1456,29 @@ def INPUT_TYPES(s):
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
- img = Image.open(image_path)
+
+ img = node_helpers.pillow(Image.open, image_path)
+
output_images = []
output_masks = []
+ w, h = None, None
+
+ excluded_formats = ['MPO']
+
for i in ImageSequence.Iterator(img):
- i = ImageOps.exif_transpose(i)
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
+
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
+
+ if len(output_images) == 0:
+ w = image.size[0]
+ h = image.size[1]
+
+ if image.size[0] != w or image.size[1] != h:
+ continue
+
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
@@ -1474,7 +1489,7 @@ def load_image(self, image):
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
- if len(output_images) > 1:
+ if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
@@ -1515,8 +1530,8 @@ def INPUT_TYPES(s):
FUNCTION = "load_image"
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
- i = Image.open(image_path)
- i = ImageOps.exif_transpose(i)
+ i = node_helpers.pillow(Image.open, image_path)
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
@@ -1943,6 +1958,10 @@ def init_custom_nodes():
"nodes_ip2p.py",
"nodes_model_merging_model_specific.py",
"nodes_pag.py",
+ "nodes_align_your_steps.py",
+ "nodes_attention_multiply.py",
+ "nodes_advanced_samplers.py",
+ "nodes_webcam.py",
]
import_failed = []
diff --git a/requirements.txt b/requirements.txt
index 33b89f4dc54..1d982eb2194 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,7 @@ scipy
tqdm
psutil
gputil
+
+#non essential dependencies:
kornia>=0.7.1
+spandrel
diff --git a/server.py b/server.py
index 5642bd5e2c4..bab3b06000d 100644
--- a/server.py
+++ b/server.py
@@ -11,6 +11,7 @@
import json
import glob
import struct
+import ssl
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -623,14 +624,22 @@ async def publish_loop(self):
async def start(self, address, port, verbose=True, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None)
await runner.setup()
- site = web.TCPSite(runner, address, port)
+ ssl_ctx = None
+ scheme = "http"
+ if args.tls_keyfile and args.tls_certfile:
+ ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
+ ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
+ keyfile=args.tls_keyfile)
+ scheme = "https"
+
+ site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
if verbose:
logging.info("Starting server\n")
- logging.info("To see the GUI go to: http://{}:{}".format(address, port))
+ logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None:
- call_on_start(address, port)
+ call_on_start(scheme, address, port)
def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler)
diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js
index 8bf8c5d8c4a..97be7aa72f0 100644
--- a/tests-ui/utils/ezgraph.js
+++ b/tests-ui/utils/ezgraph.js
@@ -204,7 +204,7 @@ export class EzWidget {
convertToWidget() {
if (!this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
- var menu = this.node.menu["Convert 🔘 to widget.."].item.submenu.options;
+ var menu = this.node.menu["Convert Input to Widget"].item.submenu.options;
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`);
menu[index].callback.call();
}
@@ -212,7 +212,7 @@ export class EzWidget {
convertToInput() {
if (this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
- var menu = this.node.menu["Convert input to 🔘.."].item.submenu.options;
+ var menu = this.node.menu["Convert Widget to Input"].item.submenu.options;
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`);
menu[index].callback.call();
}
diff --git a/web/extensions/core/keybinds.js b/web/extensions/core/keybinds.js
index cf698ea5a66..ac367c116f8 100644
--- a/web/extensions/core/keybinds.js
+++ b/web/extensions/core/keybinds.js
@@ -21,7 +21,6 @@ app.registerExtension({
s: "#comfy-save-button",
o: "#comfy-file-input",
Backspace: "#comfy-clear-button",
- Delete: "#comfy-clear-button",
d: "#comfy-load-default-button",
};
diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js
index 4f69ac7607c..36f7496e711 100644
--- a/web/extensions/core/maskeditor.js
+++ b/web/extensions/core/maskeditor.js
@@ -164,6 +164,41 @@ class MaskEditorDialog extends ComfyDialog {
return divElement;
}
+ createOpacitySlider(self, name, callback) {
+ const divElement = document.createElement('div');
+ divElement.id = "maskeditor-opacity-slider";
+ divElement.style.cssFloat = "left";
+ divElement.style.fontFamily = "sans-serif";
+ divElement.style.marginRight = "4px";
+ divElement.style.color = "var(--input-text)";
+ divElement.style.backgroundColor = "var(--comfy-input-bg)";
+ divElement.style.borderRadius = "8px";
+ divElement.style.borderColor = "var(--border-color)";
+ divElement.style.borderStyle = "solid";
+ divElement.style.fontSize = "15px";
+ divElement.style.height = "21px";
+ divElement.style.padding = "1px 6px";
+ divElement.style.display = "flex";
+ divElement.style.position = "relative";
+ divElement.style.top = "2px";
+ divElement.style.pointerEvents = "auto";
+ self.opacity_slider_input = document.createElement('input');
+ self.opacity_slider_input.setAttribute('type', 'range');
+ self.opacity_slider_input.setAttribute('min', '0.1');
+ self.opacity_slider_input.setAttribute('max', '1.0');
+ self.opacity_slider_input.setAttribute('step', '0.01');
+ self.opacity_slider_input.setAttribute('value', '0.7');
+ const labelElement = document.createElement("label");
+ labelElement.textContent = name;
+
+ divElement.appendChild(labelElement);
+ divElement.appendChild(self.opacity_slider_input);
+
+ self.opacity_slider_input.addEventListener("input", callback);
+
+ return divElement;
+ }
+
setlayout(imgCanvas, maskCanvas) {
const self = this;
@@ -203,6 +238,13 @@ class MaskEditorDialog extends ComfyDialog {
self.updateBrushPreview(self, null, null);
});
+ this.brush_opacity_slider = this.createOpacitySlider(self, "Opacity", (event) => {
+ self.brush_opacity = event.target.value;
+ if (self.brush_color_mode !== "negative") {
+ self.maskCanvas.style.opacity = self.brush_opacity;
+ }
+ });
+
this.colorButton = this.createLeftButton(this.getColorButtonText(), () => {
if (self.brush_color_mode === "black") {
self.brush_color_mode = "white";
@@ -237,6 +279,7 @@ class MaskEditorDialog extends ComfyDialog {
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(this.brush_size_slider);
+ bottom_panel.appendChild(this.brush_opacity_slider);
bottom_panel.appendChild(this.colorButton);
imgCanvas.style.position = "absolute";
@@ -472,7 +515,7 @@ class MaskEditorDialog extends ComfyDialog {
else {
return {
mixBlendMode: "initial",
- opacity: "0.7",
+ opacity: this.brush_opacity,
};
}
}
@@ -538,6 +581,7 @@ class MaskEditorDialog extends ComfyDialog {
this.maskCtx.putImageData(maskData, 0, 0);
}
+ brush_opacity = 0.7;
brush_size = 10;
brush_color_mode = "black";
drawing_mode = false;
diff --git a/web/extensions/core/snapToGrid.js b/web/extensions/core/snapToGrid.js
index dc534d6edf9..aac01774840 100644
--- a/web/extensions/core/snapToGrid.js
+++ b/web/extensions/core/snapToGrid.js
@@ -2,6 +2,13 @@ import { app } from "../../scripts/app.js";
// Shift + drag/resize to snap to grid
+/** Rounds a Vector2 in-place to the current CANVAS_GRID_SIZE. */
+function roundVectorToGrid(vec) {
+ vec[0] = LiteGraph.CANVAS_GRID_SIZE * Math.round(vec[0] / LiteGraph.CANVAS_GRID_SIZE);
+ vec[1] = LiteGraph.CANVAS_GRID_SIZE * Math.round(vec[1] / LiteGraph.CANVAS_GRID_SIZE);
+ return vec;
+}
+
app.registerExtension({
name: "Comfy.SnapToGrid",
init() {
@@ -43,10 +50,7 @@ app.registerExtension({
const onResize = node.onResize;
node.onResize = function () {
if (app.shiftDown) {
- const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE);
- const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE);
- node.size[0] = w;
- node.size[1] = h;
+ roundVectorToGrid(node.size);
}
return onResize?.apply(this, arguments);
};
@@ -57,9 +61,7 @@ app.registerExtension({
const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) {
if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) {
- const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE);
- const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE);
-
+ const [x, y] = roundVectorToGrid([...node.pos]);
const shiftX = x - node.pos[0];
let shiftY = y - node.pos[1];
@@ -85,5 +87,85 @@ app.registerExtension({
return origDrawNode.apply(this, arguments);
};
+
+
+
+ /**
+ * The currently moving, selected group only. Set after the `selected_group` has actually started
+ * moving.
+ */
+ let selectedAndMovingGroup = null;
+
+ /**
+ * Handles moving a group; tracking when a group has been moved (to show the ghost in `drawGroups`
+ * below) as well as handle the last move call from LiteGraph's `processMouseUp`.
+ */
+ const groupMove = LGraphGroup.prototype.move;
+ LGraphGroup.prototype.move = function(deltax, deltay, ignore_nodes) {
+ const v = groupMove.apply(this, arguments);
+ // When we've started moving, set `selectedAndMovingGroup` as LiteGraph sets `selected_group`
+ // too eagerly and we don't want to behave like we're moving until we get a delta.
+ if (!selectedAndMovingGroup && app.canvas.selected_group === this && (deltax || deltay)) {
+ selectedAndMovingGroup = this;
+ }
+
+ // LiteGraph will call group.move both on mouse-move as well as mouse-up though we only want
+ // to snap on a mouse-up which we can determine by checking if `app.canvas.last_mouse_dragging`
+ // has been set to `false`. Essentially, this check is the equivalent of calling
+ // `LGraphGroup.prototype.onNodeMoved`, if such a hook existed.
+ if (app.canvas.last_mouse_dragging === false && app.shiftDown) {
+ // After moving a group (while app.shiftDown), snap all the child nodes and, finally,
+ // align the group itself.
+ this.recomputeInsideNodes();
+ for (const node of this._nodes) {
+ node.alignToGrid();
+ }
+ LGraphNode.prototype.alignToGrid.apply(this);
+ }
+ return v;
+ };
+
+ /**
+ * Handles drawing a group: snaps the size while one is actively being resized and/or draws a
+ * ghost box while one is actively being moved. This mimics the node snapping behavior for
+ * both.
+ */
+ const drawGroups = LGraphCanvas.prototype.drawGroups;
+ LGraphCanvas.prototype.drawGroups = function (canvas, ctx) {
+ if (this.selected_group && app.shiftDown) {
+ if (this.selected_group_resizing) {
+ roundVectorToGrid(this.selected_group.size);
+ } else if (selectedAndMovingGroup) {
+ const [x, y] = roundVectorToGrid([...selectedAndMovingGroup.pos]);
+ const f = ctx.fillStyle;
+ const s = ctx.strokeStyle;
+ ctx.fillStyle = "rgba(100, 100, 100, 0.33)";
+ ctx.strokeStyle = "rgba(100, 100, 100, 0.66)";
+ ctx.rect(x, y, ...selectedAndMovingGroup.size);
+ ctx.fill();
+ ctx.stroke();
+ ctx.fillStyle = f;
+ ctx.strokeStyle = s;
+ }
+ } else if (!this.selected_group) {
+ selectedAndMovingGroup = null;
+ }
+ return drawGroups.apply(this, arguments);
+ };
+
+
+ /** Handles adding a group in a snapping-enabled state. */
+ const onGroupAdd = LGraphCanvas.onGroupAdd;
+ LGraphCanvas.onGroupAdd = function() {
+ const v = onGroupAdd.apply(app.canvas, arguments);
+ if (app.shiftDown) {
+ const lastGroup = app.graph._groups[app.graph._groups.length - 1];
+ if (lastGroup) {
+ roundVectorToGrid(lastGroup.pos);
+ roundVectorToGrid(lastGroup.size);
+ }
+ }
+ return v;
+ };
},
});
diff --git a/web/extensions/core/webcamCapture.js b/web/extensions/core/webcamCapture.js
new file mode 100644
index 00000000000..dd5725bd4fb
--- /dev/null
+++ b/web/extensions/core/webcamCapture.js
@@ -0,0 +1,126 @@
+import { app } from "../../scripts/app.js";
+import { api } from "../../scripts/api.js";
+
+const WEBCAM_READY = Symbol();
+
+app.registerExtension({
+ name: "Comfy.WebcamCapture",
+ getCustomWidgets(app) {
+ return {
+ WEBCAM(node, inputName) {
+ let res;
+ node[WEBCAM_READY] = new Promise((resolve) => (res = resolve));
+
+ const container = document.createElement("div");
+ container.style.background = "rgba(0,0,0,0.25)";
+ container.style.textAlign = "center";
+
+ const video = document.createElement("video");
+ video.style.height = video.style.width = "100%";
+
+ const loadVideo = async () => {
+ try {
+ const stream = await navigator.mediaDevices.getUserMedia({ video: true, audio: false });
+ container.replaceChildren(video);
+
+ setTimeout(() => res(video), 500); // Fallback as loadedmetadata doesn't always fire
+ video.addEventListener("loadedmetadata", () => res(video), false);
+ video.srcObject = stream;
+ video.play();
+ } catch (error) {
+ const label = document.createElement("div");
+ label.style.color = "red";
+ label.style.overflow = "auto";
+ label.style.maxHeight = "100%";
+ label.style.whiteSpace = "pre-wrap";
+
+ if (window.isSecureContext) {
+ label.textContent = "Unable to load webcam, please ensure access is granted:\n" + error.message;
+ } else {
+ label.textContent = "Unable to load webcam. A secure context is required, if you are not accessing ComfyUI on localhost (127.0.0.1) you will have to enable TLS (https)\n\n" + error.message;
+ }
+
+ container.replaceChildren(label);
+ }
+ };
+
+ loadVideo();
+
+ return { widget: node.addDOMWidget(inputName, "WEBCAM", container) };
+ },
+ };
+ },
+ nodeCreated(node) {
+ if (node.constructor.comfyClass !== "WebcamCapture") return;
+
+ let video;
+ const camera = node.widgets.find((w) => w.name === "image");
+ const w = node.widgets.find((w) => w.name === "width");
+ const h = node.widgets.find((w) => w.name === "height");
+ const captureOnQueue = node.widgets.find((w) => w.name === "capture_on_queue");
+
+ const canvas = document.createElement("canvas");
+
+ const capture = () => {
+ canvas.width = w.value;
+ canvas.height = h.value;
+ const ctx = canvas.getContext("2d");
+ ctx.drawImage(video, 0, 0, w.value, h.value);
+ const data = canvas.toDataURL("image/png");
+
+ const img = new Image();
+ img.onload = () => {
+ node.imgs = [img];
+ app.graph.setDirtyCanvas(true);
+ requestAnimationFrame(() => {
+ node.setSizeForImage?.();
+ });
+ };
+ img.src = data;
+ };
+
+ const btn = node.addWidget("button", "waiting for camera...", "capture", capture);
+ btn.disabled = true;
+ btn.serializeValue = () => undefined;
+
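+ // When the prompt is queued, upload the captured frame to the temp folder and serialize a reference to it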
+ camera.serializeValue = async () => {
+ if (captureOnQueue.value) {
+ capture();
+ } else if (!node.imgs?.length) {
+ const err = `No webcam image captured`;
+ alert(err);
+ throw new Error(err);
+ }
+
+ // Upload image to temp storage
+ const blob = await new Promise((r) => canvas.toBlob(r));
+ const name = `${+new Date()}.png`;
+ const file = new File([blob], name);
+ const body = new FormData();
+ body.append("image", file);
+ body.append("subfolder", "webcam");
+ body.append("type", "temp");
+ const resp = await api.fetchApi("/upload/image", {
+ method: "POST",
+ body,
+ });
+ if (resp.status !== 200) {
+ const err = `Error uploading camera image: ${resp.status} - ${resp.statusText}`;
+ alert(err);
+ throw new Error(err);
+ }
+ return `webcam/${name} [temp]`;
+ };
+
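+ // Once the webcam stream is ready, default the size widgets to the stream resolution and enable the capture button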
+ node[WEBCAM_READY].then((v) => {
+ video = v;
+ // If width isn't specified then use video output resolution
+ if (!w.value) {
+ w.value = video.videoWidth || 640;
+ h.value = video.videoHeight || 480;
+ }
+ btn.disabled = false;
+ btn.label = "capture";
+ });
+ },
+});
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index e6db9f71ab1..f1a1d22cd93 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -256,8 +256,18 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
return { customConfig };
}
+let useConversionSubmenusSetting;
app.registerExtension({
name: "Comfy.WidgetInputs",
+ init() {
+ useConversionSubmenusSetting = app.ui.settings.addSetting({
+ id: "Comfy.NodeInputConversionSubmenus",
+ name: "Node widget/input conversion sub-menus",
+ tooltip: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
+ type: "boolean",
+ defaultValue: true,
+ });
+ },
async beforeRegisterNodeDef(nodeType, nodeData, app) {
// Add menu options to convert to/from widgets
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
@@ -295,20 +305,28 @@ app.registerExtension({
//Convert.. main menu
if (toInput.length) {
- options.push({
- content: `Convert input to 🔘..`,
- submenu: {
- options: toInput,
- },
- });
+ if (useConversionSubmenusSetting.value) {
+ options.push({
+ content: "Convert Widget to Input",
+ submenu: {
+ options: toInput,
+ },
+ });
+ } else {
+ options.push(...toInput, null);
+ }
}
if (toWidget.length) {
- options.push({
- content: `Convert 🔘 to widget..`,
- submenu: {
- options: toWidget,
- },
- });
+ if (useConversionSubmenusSetting.value) {
+ options.push({
+ content: "Convert Input to Widget",
+ submenu: {
+ options: toWidget,
+ },
+ });
+ } else {
+ options.push(...toWidget, null);
+ }
}
}
diff --git a/web/scripts/app.js b/web/scripts/app.js
index b16e5079565..2cc10d38eb3 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -262,6 +262,36 @@ export class ComfyApp {
})
);
}
+
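+ /**
+ * Stores the canvas offset/zoom in serialized workflows and restores it on load when enabled
+ */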
+ #addRestoreWorkflowView() {
+ const serialize = LGraph.prototype.serialize;
+ const self = this;
+ LGraph.prototype.serialize = function() {
+ const workflow = serialize.apply(this, arguments);
+
+ // Store the canvas offset & scale in the serialized workflow if the setting is enabled
+ if (self.enableWorkflowViewRestore.value) {
+ if (!workflow.extra) {
+ workflow.extra = {};
+ }
+ workflow.extra.ds = {
+ scale: self.canvas.ds.scale,
+ offset: self.canvas.ds.offset,
+ };
+ } else if (workflow.extra?.ds) {
+ // Clear any old view data
+ delete workflow.extra.ds;
+ }
+
+ return workflow;
+ };
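+ // Setting that controls whether view data (offset/zoom) is written to and restored from workflows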
+ this.enableWorkflowViewRestore = this.ui.settings.addSetting({
+ id: "Comfy.EnableWorkflowViewRestore",
+ name: "Save and restore canvas position and zoom level in workflows",
+ type: "boolean",
+ defaultValue: true
+ });
+ }
/**
* Adds special context menu handling for nodes
@@ -953,6 +983,12 @@ export class ComfyApp {
const origProcessMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function(e) {
+ // ctrl+shift + drag: start a zoom gesture, remembering the pointer position and current scale
+ if(e.ctrlKey && e.shiftKey && e.buttons) {
+ self.zoom_drag_start = [e.x, e.y, this.ds.scale];
+ return;
+ }
+
const res = origProcessMouseDown.apply(this, arguments);
this.selected_group_moving = false;
@@ -973,6 +1009,26 @@ export class ComfyApp {
const origProcessMouseMove = LGraphCanvas.prototype.processMouseMove;
LGraphCanvas.prototype.processMouseMove = function(e) {
+ // handle ctrl+shift drag
+ if(e.ctrlKey && e.shiftKey && self.zoom_drag_start) {
+ // button released: end the zoom gesture
+ if(!e.buttons) {
+ self.zoom_drag_start = null;
+ return;
+ }
+
+ // calculate delta
+ let deltaY = e.y - self.zoom_drag_start[1];
+ let startScale = self.zoom_drag_start[2];
+
+ let scale = startScale - deltaY/100;
+
+ this.ds.changeScale(scale, [this.ds.element.width/2, this.ds.element.height/2]);
+ this.graph.change();
+
+ return;
+ }
+
const orig_selected_group = this.selected_group;
if (this.selected_group && !this.selected_group_resizing && !this.selected_group_moving) {
@@ -1059,6 +1115,20 @@ export class ComfyApp {
// Trigger onPaste
return true;
}
+
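+ // Alt + "+" / Alt + "-": zoom in/out by 10% around the centre of the canvas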
+ if((e.key === '+') && e.altKey) {
+ block_default = true;
+ let scale = this.ds.scale * 1.1;
+ this.ds.changeScale(scale, [this.ds.element.width/2, this.ds.element.height/2]);
+ this.graph.change();
+ }
+
+ if((e.key === '-') && e.altKey) {
+ block_default = true;
+ let scale = this.ds.scale * 1 / 1.1;
+ this.ds.changeScale(scale, [this.ds.element.width/2, this.ds.element.height/2]);
+ this.graph.change();
+ }
}
this.graph.change();
@@ -1465,6 +1535,7 @@ export class ComfyApp {
this.#addProcessKeyHandler();
this.#addConfigureHandler();
this.#addApiUpdateHandlers();
+ this.#addRestoreWorkflowView();
this.graph = new LGraph();
@@ -1770,6 +1841,10 @@ export class ComfyApp {
try {
this.graph.configure(graphData);
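+ // Restore the saved canvas offset and zoom if the workflow stored them and the setting is enabled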
+ if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
+ this.canvas.ds.offset = graphData.extra.ds.offset;
+ this.canvas.ds.scale = graphData.extra.ds.scale;
+ }
} catch (error) {
let errorHint = [];
// Try extracting filename to see if it was caused by an extension script
@@ -2087,6 +2162,14 @@ export class ComfyApp {
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
}
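+ /**
+ * Shows a dialog indicating that no workflow data could be found in the given file
+ * @param {File} file
+ */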
+ showErrorOnFileLoad(file) {
+ this.ui.dialog.show(
+ $el("div", [
+ $el("p", {textContent: `Unable to find workflow in ${file.name}`})
+ ]).outerHTML
+ );
+ }
+
/**
* Loads workflow data from the specified file
* @param {File} file
@@ -2094,27 +2177,27 @@ export class ComfyApp {
async handleFile(file) {
if (file.type === "image/png") {
const pngInfo = await getPngMetadata(file);
- if (pngInfo) {
- if (pngInfo.workflow) {
- await this.loadGraphData(JSON.parse(pngInfo.workflow));
- } else if (pngInfo.prompt) {
- this.loadApiJson(JSON.parse(pngInfo.prompt));
- } else if (pngInfo.parameters) {
- importA1111(this.graph, pngInfo.parameters);
- }
+ if (pngInfo?.workflow) {
+ await this.loadGraphData(JSON.parse(pngInfo.workflow));
+ } else if (pngInfo?.prompt) {
+ this.loadApiJson(JSON.parse(pngInfo.prompt));
+ } else if (pngInfo?.parameters) {
+ importA1111(this.graph, pngInfo.parameters);
+ } else {
+ this.showErrorOnFileLoad(file);
}
} else if (file.type === "image/webp") {
const pngInfo = await getWebpMetadata(file);
- if (pngInfo) {
- if (pngInfo.workflow) {
- this.loadGraphData(JSON.parse(pngInfo.workflow));
- } else if (pngInfo.Workflow) {
- this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
- } else if (pngInfo.prompt) {
- this.loadApiJson(JSON.parse(pngInfo.prompt));
- } else if (pngInfo.Prompt) {
- this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node.
- }
+ // The webp metadata custom node capitalizes the keys, so accept both casings.
+ const workflow = pngInfo?.workflow || pngInfo?.Workflow;
+ const prompt = pngInfo?.prompt || pngInfo?.Prompt;
+
+ if (workflow) {
+ this.loadGraphData(JSON.parse(workflow));
+ } else if (prompt) {
+ this.loadApiJson(JSON.parse(prompt));
+ } else {
+ this.showErrorOnFileLoad(file);
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
@@ -2135,7 +2218,11 @@ export class ComfyApp {
await this.loadGraphData(JSON.parse(info.workflow));
} else if (info.prompt) {
this.loadApiJson(JSON.parse(info.prompt));
+ } else {
+ this.showErrorOnFileLoad(file);
}
+ } else {
+ this.showErrorOnFileLoad(file);
}
}
@@ -2156,6 +2243,7 @@ export class ComfyApp {
const data = apiData[id];
const node = LiteGraph.createNode(data.class_type);
node.id = isNaN(+id) ? id : +id;
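+ // Use the node title exported in the API-format _meta data, if present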
+ node.title = data._meta?.title ?? node.title;
graph.add(node);
}
@@ -2243,6 +2331,12 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
}
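+ /**
+ * Resets the canvas view to the default position and zoom level
+ */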
+ resetView() {
+ app.canvas.ds.scale = 1;
+ app.canvas.ds.offset = [0, 0];
+ app.graph.setDirtyCanvas(true, true);
+ }
+
/**
* Clean current state
*/
diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js
index d5eeebdbd39..b7f437ad269 100644
--- a/web/scripts/domWidget.js
+++ b/web/scripts/domWidget.js
@@ -11,9 +11,10 @@ function intersect(a, b) {
else return null;
}
-function getClipPath(node, element, elRect) {
+function getClipPath(node, element) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
+ const elRect = element.getBoundingClientRect();
const MARGIN = 7;
const scale = app.canvas.ds.scale;
@@ -269,7 +270,7 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
});
if (enableDomClipping) {
- element.style.clipPath = getClipPath(node, element, elRect);
+ element.style.clipPath = getClipPath(node, element);
element.style.willChange = "clip-path";
}
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index d0fa46efbb5..36fed323837 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -597,16 +597,23 @@ export class ComfyUI {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
+ app.resetView();
}
}
}),
$el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
+ app.resetView();
await app.loadGraphData()
}
}
}),
+ $el("button", {
+ id: "comfy-reset-view-button", textContent: "Reset View", onclick: async () => {
+ app.resetView();
+ }
+ }),
]);
const devMode = this.settings.addSetting({
diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js
index 678b1b8ec7a..6a689970545 100644
--- a/web/scripts/widgets.js
+++ b/web/scripts/widgets.js
@@ -229,7 +229,11 @@ function createIntWidget(node, inputName, inputData, app, isSeedInput) {
val,
function (v) {
const s = this.options.step / 10;
- this.value = Math.round(v / s) * s;
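+ // Offset the rounding by (min % step) so values stay aligned with the configured minimum rather than with zero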
+ let sh = this.options.min % s;
+ if (isNaN(sh)) {
+ sh = 0;
+ }
+ this.value = Math.round((v - sh) / s) * s + sh;
},
config
),
@@ -307,7 +311,9 @@ export const ComfyWidgets = {
return { widget: node.addWidget(widgetType, inputName, val,
function (v) {
if (config.round) {
- this.value = Math.round(v/config.round)*config.round;
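+ // Nudge by EPSILON to avoid floating-point rounding bias, then clamp the result to the configured range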
+ this.value = Math.round((v + Number.EPSILON)/config.round)*config.round;
+ if (this.value > config.max) this.value = config.max;
+ if (this.value < config.min) this.value = config.min;
} else {
this.value = v;
}