Skip to content

Commit

Permalink
Update models for managed install and cloud
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Nov 7, 2024
1 parent 2022bf1 commit 26e0914
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 28 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita"""

__version__ = "1.27.1"
__version__ = "1.28.0"

import importlib.util

Expand Down
36 changes: 23 additions & 13 deletions ai_diffusion/cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from base64 import b64encode
from datetime import datetime
from dataclasses import dataclass
from itertools import chain

from .api import WorkflowInput, WorkflowKind
from .client import Client, ClientEvent, ClientMessage, ClientModels, DeviceInfo, CheckpointInfo
Expand All @@ -18,7 +19,7 @@
from .settings import PerformanceSettings, settings
from .localization import translate as _
from .util import clamp, ensure, client_logger as log
from . import __version__ as plugin_version
from . import resources, __version__ as plugin_version


@dataclass
Expand Down Expand Up @@ -346,22 +347,31 @@ def _base64_size(size: int):
return math.ceil(size / 3) * 4


def _checkpoint_info(id: str, arch: Arch):
models = chain(resources.default_checkpoints, resources.deprecated_models)
res = next(m for m in models if m.id.identifier == id and m.arch == arch)
return (res.filename, CheckpointInfo(res.filename, res.arch))


_poll_interval = 0.5 # seconds

models = ClientModels()
models.checkpoints = {
"dreamshaper_8.safetensors": CheckpointInfo("dreamshaper_8.safetensors", Arch.sd15),
"realisticVisionV51_v51VAE.safetensors": CheckpointInfo(
"realisticVisionV51_v51VAE.safetensors", Arch.sd15
),
"flat2DAnimerge_v45Sharp.safetensors": CheckpointInfo(
"flat2DAnimerge_v45Sharp.safetensors", Arch.sd15
),
"juggernautXL_version6Rundiffusion.safetensors": CheckpointInfo(
"juggernautXL_version6Rundiffusion.safetensors", Arch.sdxl
),
"zavychromaxl_v80.safetensors": CheckpointInfo("zavychromaxl_v80.safetensors", Arch.sdxl),
"flux1-schnell-fp8.safetensors": CheckpointInfo("flux1-schnell-fp8.safetensors", Arch.flux),
filename: info
for filename, info in (
_checkpoint_info(name, arch)
for name, arch in [
("dreamshaper", Arch.sd15),
("realistic-vision", Arch.sd15),
("serenity", Arch.sd15),
("flat2d-animerge", Arch.sd15),
("juggernaut", Arch.sdxl),
("realvis", Arch.sdxl),
("zavychroma", Arch.sdxl),
("pixelwave", Arch.sdxl),
("flux-schnell", Arch.flux),
]
)
}
models.vae = []
models.loras = [
Expand Down
55 changes: 41 additions & 14 deletions ai_diffusion/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
version = "1.28.0"

comfy_url = "https://github.com/comfyanonymous/ComfyUI"
comfy_version = "52810907e20e11b126642f5b4917406e7043e70a"
comfy_version = "5e29e7a488b3f48afc6c4a3cb8ed110976d0ebb8"


class CustomNode(NamedTuple):
Expand Down Expand Up @@ -46,7 +46,7 @@ class CustomNode(NamedTuple):
"Inpaint Nodes",
"comfyui-inpaint-nodes",
"https://github.com/Acly/comfyui-inpaint-nodes",
"146a2f17b1f91eb155011ab36aa349c696b6e38b",
"422eccd86685e084b551fb7e14bc025d77a64cc2",
["INPAINT_LoadFooocusInpaint", "INPAINT_ApplyFooocusInpaint", "INPAINT_ExpandMask"],
),
]
Expand All @@ -56,7 +56,7 @@ class CustomNode(NamedTuple):
"GGUF",
"ComfyUI-GGUF",
"https://github.com/city96/ComfyUI-GGUF",
"98333480059a2ccafb4718924ebcb9cdcb9b1f43",
"8e898fad4caab59bf4144e0cf11978b893de7e54",
["UnetLoaderGGUF", "DualCLIPLoaderGGUF"],
)
]
Expand Down Expand Up @@ -437,16 +437,16 @@ def __hash__(self):

default_checkpoints = [
ModelResource(
"Realistic Vision (Photography)",
ResourceId(ResourceKind.checkpoint, Arch.sd15, "realistic-vision"),
"Serenity (SD1.5 - Photography)",
ResourceId(ResourceKind.checkpoint, Arch.sd15, "serenity"),
{
Path(
"models/checkpoints/realisticVisionV51_v51VAE.safetensors"
): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors",
"models/checkpoints/serenity_v21Safetensors.safetensors"
): "https://huggingface.co/Acly/SD-Checkpoints/resolve/main/serenity_v21Safetensors.safetensors"
},
),
ModelResource(
"DreamShaper (Artwork)",
"DreamShaper (SD1.5 - Artwork)",
ResourceId(ResourceKind.checkpoint, Arch.sd15, "dreamshaper"),
{
Path(
Expand All @@ -455,7 +455,7 @@ def __hash__(self):
},
),
ModelResource(
"Flat2D AniMerge (Cartoon/Anime)",
"Flat2D AniMerge (SD1.5 - Cartoon/Anime)",
ResourceId(ResourceKind.checkpoint, Arch.sd15, "flat2d-animerge"),
{
Path(
Expand All @@ -464,23 +464,32 @@ def __hash__(self):
},
),
ModelResource(
"Juggernaut XL",
ResourceId(ResourceKind.checkpoint, Arch.sdxl, "juggernaut"),
"RealVis (SDXL - Photography)",
ResourceId(ResourceKind.checkpoint, Arch.sdxl, "realvis"),
{
Path(
"models/checkpoints/juggernautXL_version6Rundiffusion.safetensors"
): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors"
"models/checkpoints/RealVisXL_V5.0_fp16.safetensors"
): "https://huggingface.co/SG161222/RealVisXL_V5.0/resolve/main/RealVisXL_V5.0_fp16.safetensors"
},
),
ModelResource(
"ZavyChroma XL",
"ZavyChroma (SDXL - Artwork)",
ResourceId(ResourceKind.checkpoint, Arch.sdxl, "zavychroma"),
{
Path(
"models/checkpoints/zavychromaxl_v80.safetensors"
): "https://huggingface.co/misri/zavychromaxl_v80/resolve/main/zavychromaxl_v80.safetensors"
},
),
ModelResource(
"Pixelwave (SDXL - Artwork)",
ResourceId(ResourceKind.checkpoint, Arch.sdxl, "pixelwave"),
{
Path(
"models/checkpoints/pixelwave_11.safetensors"
): "https://huggingface.co/Acly/SD-Checkpoints/resolve/main/pixelwave_11.safetensors"
},
),
ModelResource(
"Flux [dev]",
ResourceId(ResourceKind.checkpoint, Arch.flux, "flux-dev"),
Expand Down Expand Up @@ -768,6 +777,24 @@ def __hash__(self):
): "https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors",
},
),
ModelResource(
"Realistic Vision (SD1.5 - Photography)",
ResourceId(ResourceKind.checkpoint, Arch.sd15, "realistic-vision"),
{
Path(
"models/checkpoints/realisticVisionV51_v51VAE.safetensors"
): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors",
},
),
ModelResource(
"Juggernaut XL (Old)",
ResourceId(ResourceKind.checkpoint, Arch.sdxl, "juggernaut"),
{
Path(
"models/checkpoints/juggernautXL_version6Rundiffusion.safetensors"
): "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors"
},
),
]


Expand Down
1 change: 1 addition & 0 deletions ai_diffusion/styles/cinematic-photo-xl.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"version": 2,
"architecture": "auto",
"checkpoints": [
"RealVisXL_V5.0_fp16.safetensors",
"juggernautXL_juggXIByRundiffusion.safetensors",
"Juggernaut_X_RunDiffusion.safetensors",
"juggernautXL_v9Rundiffusionphoto2.safetensors",
Expand Down
5 changes: 5 additions & 0 deletions scripts/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ async def main(
checkpoints=[],
controlnet=False,
prefetch=False,
deprecated=False,
minimal=False,
recommended=False,
all=False,
Expand Down Expand Up @@ -151,6 +152,8 @@ async def main(
models.update([m for m in optional_models if m.kind in kinds and m.arch in versions])
if prefetch or all:
models.update(resources.prefetch_models)
if deprecated:
models.update([m for m in resources.deprecated_models if m.arch in versions])

models = models - set([m for m in models if m.id.string in exclude])

Expand Down Expand Up @@ -199,6 +202,7 @@ async def main(
parser.add_argument("--checkpoint", action="append", choices=checkpoint_names, dest="checkpoint_list", help="download a specific checkpoint (can specify multiple times)")
parser.add_argument("--upscalers", action="store_true", help="download additional upscale models")
parser.add_argument("--prefetch", action="store_true", help="download models which would be automatically downloaded on first use")
parser.add_argument("--deprecated", action="store_true", help="download old models which will be removed in the near future")
parser.add_argument("--retry-attempts", type=int, default=5, metavar="N", help="number of retry attempts for downloading a model")
parser.add_argument("--continue-on-error", action="store_true", help="continue downloading models even if an error occurs")
# fmt: on
Expand All @@ -222,6 +226,7 @@ async def main(
checkpoints=checkpoints,
controlnet=args.controlnet,
prefetch=args.prefetch,
deprecated=args.deprecated,
minimal=args.minimal,
recommended=args.recommended,
all=args.all,
Expand Down

0 comments on commit 26e0914

Please sign in to comment.