From fb88d86be2095de03205ac533149e2a347c9725e Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 8 Dec 2025 11:12:33 +0000 Subject: [PATCH 01/12] add ucache --- examples/cli/README.md | 1 + examples/cli/main.cpp | 1906 ++++++++++++++++++++++++++++++++++-- examples/common/common.hpp | 271 +++-- stable-diffusion.cpp | 154 ++- stable-diffusion.h | 10 + ucache.hpp | 286 ++++++ 6 files changed, 2437 insertions(+), 191 deletions(-) create mode 100644 ucache.hpp diff --git a/examples/cli/README.md b/examples/cli/README.md index 8531b2aed..ba9a0b278 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -127,4 +127,5 @@ Generation Options: --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) --easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95) + --ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95) ``` diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 42b909e4f..7810ee8fc 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -15,10 +15,38 @@ // #include "preprocessing.hpp" #include "stable-diffusion.h" -#include "common/common.hpp" +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" + +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" #include "avi_writer.h" +#if defined(_WIN32) +#define NOMINMAX +#include +#endif // _WIN32 + +#define SAFE_STR(s) ((s) ? (s) : "") +#define BOOL_STR(b) ((b) ? "true" : "false") + +namespace fs = std::filesystem; + +const char* modes_str[] = { + "img_gen", + "vid_gen", + "convert", + "upscale", +}; +#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale" + const char* previews_str[] = { "none", "proj", @@ -26,6 +54,271 @@ const char* previews_str[] = { "vae", }; +enum SDMode { + IMG_GEN, + VID_GEN, + CONVERT, + UPSCALE, + MODE_COUNT +}; + +#if defined(_WIN32) +static std::string utf16_to_utf8(const std::wstring& wstr) { + if (wstr.empty()) + return {}; + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), + nullptr, 0, nullptr, nullptr); + if (size_needed <= 0) + throw std::runtime_error("UTF-16 to UTF-8 conversion failed"); + + std::string utf8(size_needed, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), + (char*)utf8.data(), size_needed, nullptr, nullptr); + return utf8; +} + +static std::string argv_to_utf8(int index, const char** argv) { + int argc; + wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc); + if (!argv_w) + throw std::runtime_error("Failed to parse command line"); + + std::string result; + if (index < argc) { + result = utf16_to_utf8(argv_w[index]); + } + LocalFree(argv_w); + return result; +} + +#else // Linux / macOS +static std::string argv_to_utf8(int index, const char** argv) { + return std::string(argv[index]); +} + +#endif + +struct StringOption { + std::string short_name; + std::string long_name; + std::string desc; + std::string* target; +}; + +struct IntOption { + std::string short_name; + std::string long_name; + std::string desc; + int* target; +}; + +struct FloatOption { + std::string short_name; + std::string long_name; + std::string desc; + float* target; +}; + +struct BoolOption { + std::string short_name; + std::string long_name; + std::string desc; + bool keep_true; + bool* target; +}; + +struct ManualOption { + std::string short_name; + std::string long_name; + std::string desc; + std::function cb; +}; + +struct ArgOptions { + std::vector string_options; + std::vector int_options; + std::vector float_options; + std::vector bool_options; + std::vector manual_options; + + static std::string wrap_text(const std::string& text, size_t width, size_t indent) { + std::ostringstream oss; + size_t line_len = 0; + size_t pos = 0; + + while (pos < text.size()) { + // Preserve manual newlines + if (text[pos] == '\n') { + oss << '\n' + << std::string(indent, ' '); + line_len = indent; + ++pos; + continue; + } + + // Add the character + oss << text[pos]; + ++line_len; + ++pos; + + // If the current line exceeds width, try to break at the last space + if (line_len >= width) { + std::string current = oss.str(); + size_t back = current.size(); + + // Find the last space (for a clean break) + while (back > 0 && current[back - 1] != ' ' && current[back - 1] != '\n') + --back; + + // If found a space to break on + if (back > 0 && current[back - 1] != '\n') { + std::string before = current.substr(0, back - 1); + std::string after = current.substr(back); + oss.str(""); + oss.clear(); + oss << before << "\n" + << std::string(indent, ' ') << after; + } else { + // If no space found, just break at width + oss << "\n" + << std::string(indent, ' '); + } + line_len = indent; + } + } + + return oss.str(); + } + + void print() const { + constexpr size_t max_line_width = 120; + + struct Entry { + std::string names; + std::string desc; + }; + std::vector entries; + + auto add_entry = [&](const std::string& s, const std::string& l, + const std::string& desc, const std::string& hint = "") { + std::ostringstream ss; + if (!s.empty()) + ss << s; + if (!s.empty() && !l.empty()) + ss << ", "; + if (!l.empty()) + ss << l; + if (!hint.empty()) + ss << " " << hint; + entries.push_back({ss.str(), desc}); + }; + + for (auto& o : string_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : int_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : float_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : bool_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : manual_options) + add_entry(o.short_name, o.long_name, o.desc); + + size_t max_name_width = 0; + for (auto& e : entries) + max_name_width = std::max(max_name_width, e.names.size()); + + for (auto& e : entries) { + size_t indent = 2 + max_name_width + 4; + size_t desc_width = (max_line_width > indent ? max_line_width - indent : 40); + std::string wrapped_desc = wrap_text(e.desc, max_line_width, indent); + std::cout << " " << std::left << std::setw(static_cast(max_name_width) + 4) + << e.names << wrapped_desc << "\n"; + } + } +}; + +bool parse_options(int argc, const char** argv, const std::vector& options_list) { + bool invalid_arg = false; + std::string arg; + + auto match_and_apply = [&](auto& opts, auto&& apply_fn) -> bool { + for (auto& option : opts) { + if ((option.short_name.size() > 0 && arg == option.short_name) || + (option.long_name.size() > 0 && arg == option.long_name)) { + apply_fn(option); + return true; + } + } + return false; + }; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + bool found_arg = false; + + for (auto& options : options_list) { + if (match_and_apply(options.string_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = argv_to_utf8(i, argv); + found_arg = true; + })) + break; + + if (match_and_apply(options.int_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = std::stoi(argv[i]); + found_arg = true; + })) + break; + + if (match_and_apply(options.float_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = std::stof(argv[i]); + found_arg = true; + })) + break; + + if (match_and_apply(options.bool_options, [&](auto& option) { + *option.target = option.keep_true ? true : false; + found_arg = true; + })) + break; + + if (match_and_apply(options.manual_options, [&](auto& option) { + int ret = option.cb(argc, argv, i); + if (ret < 0) { + invalid_arg = true; + return; + } + i += ret; + found_arg = true; + })) + break; + } + + if (invalid_arg) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + return false; + } + if (!found_arg) { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + return false; + } + } + + return true; +} + struct SDCliParams { SDMode mode = IMG_GEN; std::string output_path = "output.png"; @@ -59,132 +352,1470 @@ struct SDCliParams { options.int_options = { {"", - "--preview-interval", - "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)", - &preview_interval}, + "--preview-interval", + "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)", + &preview_interval}, + }; + + options.bool_options = { + {"", + "--canny", + "apply canny preprocessor (edge detection)", + true, &canny_preprocess}, + {"-v", + "--verbose", + "print extra info", + true, &verbose}, + {"", + "--color", + "colors the logging tags according to level", + true, &color}, + {"", + "--taesd-preview-only", + std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")", + true, &taesd_preview}, + {"", + "--preview-noisy", + "enables previewing noisy inputs of the models rather than the denoised outputs", + true, &preview_noisy}, + + }; + + auto on_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* mode_c_str = argv[index]; + if (mode_c_str != nullptr) { + int mode_found = -1; + for (int i = 0; i < MODE_COUNT; i++) { + if (!strcmp(mode_c_str, modes_str[i])) { + mode_found = i; + } + } + if (mode_found == -1) { + LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", + mode_c_str, SD_ALL_MODES_STR); + exit(1); + } + mode = (SDMode)mode_found; + } + return 1; + }; + + auto on_preview_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* preview = argv[index]; + int preview_found = -1; + for (int m = 0; m < PREVIEW_COUNT; m++) { + if (!strcmp(preview, previews_str[m])) { + preview_found = m; + } + } + if (preview_found == -1) { + LOG_ERROR("error: preview method %s", preview); + return -1; + } + preview_method = (preview_t)preview_found; + return 1; + }; + + auto on_help_arg = [&](int argc, const char** argv, int index) { + normal_exit = true; + return -1; + }; + + options.manual_options = { + {"-M", + "--mode", + "run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen", + on_mode_arg}, + {"", + "--preview", + std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")", + on_preview_arg}, + {"-h", + "--help", + "show this help message and exit", + on_help_arg}, + }; + + return options; + }; + + bool process_and_check() { + if (output_path.length() == 0) { + LOG_ERROR("error: the following arguments are required: output_path"); + return false; + } + + if (mode == CONVERT) { + if (output_path == "output.png") { + output_path = "output.gguf"; + } + } + return true; + } + + std::string to_string() const { + std::ostringstream oss; + oss << "SDCliParams {\n" + << " mode: " << modes_str[mode] << ",\n" + << " output_path: \"" << output_path << "\",\n" + << " verbose: " << (verbose ? "true" : "false") << ",\n" + << " color: " << (color ? "true" : "false") << ",\n" + << " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n" + << " preview_method: " << previews_str[preview_method] << ",\n" + << " preview_interval: " << preview_interval << ",\n" + << " preview_path: \"" << preview_path << "\",\n" + << " preview_fps: " << preview_fps << ",\n" + << " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n" + << " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n" + << "}"; + return oss.str(); + } +}; + +struct SDContextParams { + int n_threads = -1; + std::string model_path; + std::string clip_l_path; + std::string clip_g_path; + std::string clip_vision_path; + std::string t5xxl_path; + std::string llm_path; + std::string llm_vision_path; + std::string diffusion_model_path; + std::string high_noise_diffusion_model_path; + std::string vae_path; + std::string taesd_path; + std::string esrgan_path; + std::string control_net_path; + std::string embedding_dir; + std::string photo_maker_path; + sd_type_t wtype = SD_TYPE_COUNT; + std::string tensor_type_rules; + std::string lora_model_dir; + + std::map embedding_map; + std::vector embedding_vec; + + rng_type_t rng_type = CUDA_RNG; + rng_type_t sampler_rng_type = RNG_TYPE_COUNT; + bool offload_params_to_cpu = false; + bool control_net_cpu = false; + bool clip_on_cpu = false; + bool vae_on_cpu = false; + bool diffusion_flash_attn = false; + bool diffusion_conv_direct = false; + bool vae_conv_direct = false; + + bool chroma_use_dit_mask = true; + bool chroma_use_t5_mask = false; + int chroma_t5_mask_pad = 1; + + prediction_t prediction = PREDICTION_COUNT; + lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; + + sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; + bool force_sdxl_vae_conv_scale = false; + + float flow_shift = INFINITY; + + ArgOptions get_options() { + ArgOptions options; + options.string_options = { + {"-m", + "--model", + "path to full model", + &model_path}, + {"", + "--clip_l", + "path to the clip-l text encoder", &clip_l_path}, + {"", "--clip_g", + "path to the clip-g text encoder", + &clip_g_path}, + {"", + "--clip_vision", + "path to the clip-vision encoder", + &clip_vision_path}, + {"", + "--t5xxl", + "path to the t5xxl text encoder", + &t5xxl_path}, + {"", + "--llm", + "path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)", + &llm_path}, + {"", + "--llm_vision", + "path to the llm vit", + &llm_vision_path}, + {"", + "--qwen2vl", + "alias of --llm. Deprecated.", + &llm_path}, + {"", + "--qwen2vl_vision", + "alias of --llm_vision. Deprecated.", + &llm_vision_path}, + {"", + "--diffusion-model", + "path to the standalone diffusion model", + &diffusion_model_path}, + {"", + "--high-noise-diffusion-model", + "path to the standalone high noise diffusion model", + &high_noise_diffusion_model_path}, + {"", + "--vae", + "path to standalone vae model", + &vae_path}, + {"", + "--taesd", + "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", + &taesd_path}, + {"", + "--control-net", + "path to control net model", + &control_net_path}, + {"", + "--embd-dir", + "embeddings directory", + &embedding_dir}, + {"", + "--lora-model-dir", + "lora model directory", + &lora_model_dir}, + + {"", + "--tensor-type-rules", + "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")", + &tensor_type_rules}, + {"", + "--photo-maker", + "path to PHOTOMAKER model", + &photo_maker_path}, + {"", + "--upscale-model", + "path to esrgan model.", + &esrgan_path}, + }; + + options.int_options = { + {"-t", + "--threads", + "number of threads to use during computation (default: -1). " + "If threads <= 0, then threads will be set to the number of CPU physical cores", + &n_threads}, + {"", + "--chroma-t5-mask-pad", + "t5 mask pad size of chroma", + &chroma_t5_mask_pad}, + }; + + options.float_options = { + {"", + "--vae-tile-overlap", + "tile overlap for vae tiling, in fraction of tile size (default: 0.5)", + &vae_tiling_params.target_overlap}, + {"", + "--flow-shift", + "shift value for Flow models like SD3.x or WAN (default: auto)", + &flow_shift}, + }; + + options.bool_options = { + {"", + "--vae-tiling", + "process vae in tiles to reduce memory usage", + true, &vae_tiling_params.enabled}, + {"", + "--force-sdxl-vae-conv-scale", + "force use of conv scale on sdxl vae", + true, &force_sdxl_vae_conv_scale}, + {"", + "--offload-to-cpu", + "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed", + true, &offload_params_to_cpu}, + {"", + "--control-net-cpu", + "keep controlnet in cpu (for low vram)", + true, &control_net_cpu}, + {"", + "--clip-on-cpu", + "keep clip in cpu (for low vram)", + true, &clip_on_cpu}, + {"", + "--vae-on-cpu", + "keep vae in cpu (for low vram)", + true, &vae_on_cpu}, + {"", + "--diffusion-fa", + "use flash attention in the diffusion model", + true, &diffusion_flash_attn}, + {"", + "--diffusion-conv-direct", + "use ggml_conv2d_direct in the diffusion model", + true, &diffusion_conv_direct}, + {"", + "--vae-conv-direct", + "use ggml_conv2d_direct in the vae model", + true, &vae_conv_direct}, + {"", + "--chroma-disable-dit-mask", + "disable dit mask for chroma", + false, &chroma_use_dit_mask}, + {"", + "--chroma-enable-t5-mask", + "enable t5 mask for chroma", + true, &chroma_use_t5_mask}, + }; + + auto on_type_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + wtype = str_to_sd_type(arg); + if (wtype == SD_TYPE_COUNT) { + fprintf(stderr, "error: invalid weight format %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_rng_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + rng_type = str_to_rng_type(arg); + if (rng_type == RNG_TYPE_COUNT) { + fprintf(stderr, "error: invalid rng type %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + sampler_rng_type = str_to_rng_type(arg); + if (sampler_rng_type == RNG_TYPE_COUNT) { + fprintf(stderr, "error: invalid sampler rng type %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_prediction_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + prediction = str_to_prediction(arg); + if (prediction == PREDICTION_COUNT) { + fprintf(stderr, "error: invalid prediction type %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + lora_apply_mode = str_to_lora_apply_mode(arg); + if (lora_apply_mode == LORA_APPLY_MODE_COUNT) { + fprintf(stderr, "error: invalid lora apply model %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string tile_size_str = argv[index]; + size_t x_pos = tile_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string tile_x_str = tile_size_str.substr(0, x_pos); + std::string tile_y_str = tile_size_str.substr(x_pos + 1); + vae_tiling_params.tile_size_x = std::stoi(tile_x_str); + vae_tiling_params.tile_size_y = std::stoi(tile_y_str); + } else { + vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str); + } + } catch (const std::invalid_argument&) { + return -1; + } catch (const std::out_of_range&) { + return -1; + } + return 1; + }; + + auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string rel_size_str = argv[index]; + size_t x_pos = rel_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string rel_x_str = rel_size_str.substr(0, x_pos); + std::string rel_y_str = rel_size_str.substr(x_pos + 1); + vae_tiling_params.rel_size_x = std::stof(rel_x_str); + vae_tiling_params.rel_size_y = std::stof(rel_y_str); + } else { + vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str); + } + } catch (const std::invalid_argument&) { + return -1; + } catch (const std::out_of_range&) { + return -1; + } + return 1; + }; + + options.manual_options = { + {"", + "--type", + "weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). " + "If not specified, the default is the type of the weight file", + on_type_arg}, + {"", + "--rng", + "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)", + on_rng_arg}, + {"", + "--sampler-rng", + "sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng", + on_sampler_rng_arg}, + {"", + "--prediction", + "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]", + on_prediction_arg}, + {"", + "--lora-apply-mode", + "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. " + "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used." + "The immediately mode may have precision and compatibility issues with quantized parameters, " + "but it usually offers faster inference speed and, in some cases, lower memory usage. " + "The at_runtime mode, on the other hand, is exactly the opposite.", + on_lora_apply_mode_arg}, + {"", + "--vae-tile-size", + "tile size for vae tiling, format [X]x[Y] (default: 32x32)", + on_tile_size_arg}, + {"", + "--vae-relative-tile-size", + "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)", + on_relative_tile_size_arg}, + }; + + return options; + } + + void build_embedding_map() { + static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; + + if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) { + return; + } + + for (auto& p : fs::directory_iterator(embedding_dir)) { + if (!p.is_regular_file()) + continue; + + auto path = p.path(); + std::string ext = path.extension().string(); + + bool valid = false; + for (auto& e : valid_ext) { + if (ext == e) { + valid = true; + break; + } + } + if (!valid) + continue; + + std::string key = path.stem().string(); + std::string value = path.string(); + + embedding_map[key] = value; + } + } + + bool process_and_check(SDMode mode) { + if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) { + fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); + return false; + } + + if (mode == UPSCALE) { + if (esrgan_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n"); + return false; + } + } + + if (n_threads <= 0) { + n_threads = sd_get_num_physical_cores(); + } + + build_embedding_map(); + + return true; + } + + std::string to_string() const { + std::ostringstream emb_ss; + emb_ss << "{\n"; + for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) { + emb_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != embedding_map.end()) { + emb_ss << ","; + } + emb_ss << "\n"; + } + emb_ss << " }"; + + std::string embeddings_str = emb_ss.str(); + std::ostringstream oss; + oss << "SDContextParams {\n" + << " n_threads: " << n_threads << ",\n" + << " model_path: \"" << model_path << "\",\n" + << " clip_l_path: \"" << clip_l_path << "\",\n" + << " clip_g_path: \"" << clip_g_path << "\",\n" + << " clip_vision_path: \"" << clip_vision_path << "\",\n" + << " t5xxl_path: \"" << t5xxl_path << "\",\n" + << " llm_path: \"" << llm_path << "\",\n" + << " llm_vision_path: \"" << llm_vision_path << "\",\n" + << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" + << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" + << " vae_path: \"" << vae_path << "\",\n" + << " taesd_path: \"" << taesd_path << "\",\n" + << " esrgan_path: \"" << esrgan_path << "\",\n" + << " control_net_path: \"" << control_net_path << "\",\n" + << " embedding_dir: \"" << embedding_dir << "\",\n" + << " embeddings: " << embeddings_str << "\n" + << " wtype: " << sd_type_name(wtype) << ",\n" + << " tensor_type_rules: \"" << tensor_type_rules << "\",\n" + << " lora_model_dir: \"" << lora_model_dir << "\",\n" + << " photo_maker_path: \"" << photo_maker_path << "\",\n" + << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" + << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" + << " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n" + << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n" + << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" + << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" + << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" + << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" + << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" + << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" + << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" + << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n" + << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" + << " prediction: " << sd_prediction_name(prediction) << ",\n" + << " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n" + << " vae_tiling_params: { " + << vae_tiling_params.enabled << ", " + << vae_tiling_params.tile_size_x << ", " + << vae_tiling_params.tile_size_y << ", " + << vae_tiling_params.target_overlap << ", " + << vae_tiling_params.rel_size_x << ", " + << vae_tiling_params.rel_size_y << " },\n" + << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n" + << "}"; + return oss.str(); + } + + sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { + embedding_vec.clear(); + embedding_vec.reserve(embedding_map.size()); + for (const auto& kv : embedding_map) { + sd_embedding_t item; + item.name = kv.first.c_str(); + item.path = kv.second.c_str(); + embedding_vec.emplace_back(item); + } + + sd_ctx_params_t sd_ctx_params = { + model_path.c_str(), + clip_l_path.c_str(), + clip_g_path.c_str(), + clip_vision_path.c_str(), + t5xxl_path.c_str(), + llm_path.c_str(), + llm_vision_path.c_str(), + diffusion_model_path.c_str(), + high_noise_diffusion_model_path.c_str(), + vae_path.c_str(), + taesd_path.c_str(), + control_net_path.c_str(), + lora_model_dir.c_str(), + embedding_vec.data(), + static_cast(embedding_vec.size()), + photo_maker_path.c_str(), + tensor_type_rules.c_str(), + vae_decode_only, + free_params_immediately, + n_threads, + wtype, + rng_type, + sampler_rng_type, + prediction, + lora_apply_mode, + offload_params_to_cpu, + clip_on_cpu, + control_net_cpu, + vae_on_cpu, + diffusion_flash_attn, + taesd_preview, + diffusion_conv_direct, + vae_conv_direct, + force_sdxl_vae_conv_scale, + chroma_use_dit_mask, + chroma_use_t5_mask, + chroma_t5_mask_pad, + flow_shift, + }; + return sd_ctx_params; + } +}; + +template +static std::string vec_to_string(const std::vector& v) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < v.size(); i++) { + oss << v[i]; + if (i + 1 < v.size()) + oss << ", "; + } + oss << "]"; + return oss.str(); +} + +static std::string vec_str_to_string(const std::vector& v) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < v.size(); i++) { + oss << "\"" << v[i] << "\""; + if (i + 1 < v.size()) + oss << ", "; + } + oss << "]"; + return oss.str(); +} + +static bool is_absolute_path(const std::string& p) { +#ifdef _WIN32 + // Windows: C:/path or C:\path + return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':'; +#else + return !p.empty() && p[0] == '/'; +#endif +} + +struct SDGenerationParams { + std::string prompt; + std::string negative_prompt; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; + std::string init_image_path; + std::string end_image_path; + std::string mask_image_path; + std::string control_image_path; + std::vector ref_image_paths; + std::string control_video_path; + bool auto_resize_ref_image = true; + bool increase_ref_index = false; + + std::vector skip_layers = {7, 8, 9}; + sd_sample_params_t sample_params; + + std::vector high_noise_skip_layers = {7, 8, 9}; + sd_sample_params_t high_noise_sample_params; + + std::string easycache_option; + sd_easycache_params_t easycache_params; + + std::string ucache_option; + sd_ucache_params_t ucache_params; + + float moe_boundary = 0.875f; + int video_frames = 1; + int fps = 16; + float vace_strength = 1.f; + + float strength = 0.75f; + float control_strength = 0.9f; + + int64_t seed = 42; + + // Photo Maker + std::string pm_id_images_dir; + std::string pm_id_embed_path; + float pm_style_strength = 20.f; + + int upscale_repeats = 1; + + std::map lora_map; + std::map high_noise_lora_map; + std::vector lora_vec; + + SDGenerationParams() { + sd_sample_params_init(&sample_params); + sd_sample_params_init(&high_noise_sample_params); + } + + ArgOptions get_options() { + ArgOptions options; + options.string_options = { + {"-p", + "--prompt", + "the prompt to render", + &prompt}, + {"-n", + "--negative-prompt", + "the negative prompt (default: \"\")", + &negative_prompt}, + {"-i", + "--init-img", + "path to the init image", + &init_image_path}, + {"", + "--end-img", + "path to the end image, required by flf2v", + &end_image_path}, + {"", + "--mask", + "path to the mask image", + &mask_image_path}, + {"", + "--control-image", + "path to control image, control net", + &control_image_path}, + {"", + "--control-video", + "path to control video frames, It must be a directory path. The video frames inside should be stored as images in " + "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images " + "such as 00.png, 01.png, ... etc.", + &control_video_path}, + {"", + "--pm-id-images-dir", + "path to PHOTOMAKER input id images dir", + &pm_id_images_dir}, + {"", + "--pm-id-embed-path", + "path to PHOTOMAKER v2 id embed", + &pm_id_embed_path}, + }; + + options.int_options = { + {"-H", + "--height", + "image height, in pixel space (default: 512)", + &height}, + {"-W", + "--width", + "image width, in pixel space (default: 512)", + &width}, + {"", + "--steps", + "number of sample steps (default: 20)", + &sample_params.sample_steps}, + {"", + "--high-noise-steps", + "(high noise) number of sample steps (default: -1 = auto)", + &high_noise_sample_params.sample_steps}, + {"", + "--clip-skip", + "ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). " + "<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x", + &clip_skip}, + {"-b", + "--batch-count", + "batch count", + &batch_count}, + {"", + "--video-frames", + "video frames (default: 1)", + &video_frames}, + {"", + "--fps", + "fps (default: 24)", + &fps}, + {"", + "--timestep-shift", + "shift timestep for NitroFusion models (default: 0). " + "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant", + &sample_params.shifted_timestep}, + {"", + "--upscale-repeats", + "Run the ESRGAN upscaler this many times (default: 1)", + &upscale_repeats}, + }; + + options.float_options = { + {"", + "--cfg-scale", + "unconditional guidance scale: (default: 7.0)", + &sample_params.guidance.txt_cfg}, + {"", + "--img-cfg-scale", + "image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)", + &sample_params.guidance.img_cfg}, + {"", + "--guidance", + "distilled guidance scale for models with guidance input (default: 3.5)", + &sample_params.guidance.distilled_guidance}, + {"", + "--slg-scale", + "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 medium", + &sample_params.guidance.slg.scale}, + {"", + "--skip-layer-start", + "SLG enabling point (default: 0.01)", + &sample_params.guidance.slg.layer_start}, + {"", + "--skip-layer-end", + "SLG disabling point (default: 0.2)", + &sample_params.guidance.slg.layer_end}, + {"", + "--eta", + "eta in DDIM, only for DDIM and TCD (default: 0)", + &sample_params.eta}, + {"", + "--high-noise-cfg-scale", + "(high noise) unconditional guidance scale: (default: 7.0)", + &high_noise_sample_params.guidance.txt_cfg}, + {"", + "--high-noise-img-cfg-scale", + "(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)", + &high_noise_sample_params.guidance.img_cfg}, + {"", + "--high-noise-guidance", + "(high noise) distilled guidance scale for models with guidance input (default: 3.5)", + &high_noise_sample_params.guidance.distilled_guidance}, + {"", + "--high-noise-slg-scale", + "(high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)", + &high_noise_sample_params.guidance.slg.scale}, + {"", + "--high-noise-skip-layer-start", + "(high noise) SLG enabling point (default: 0.01)", + &high_noise_sample_params.guidance.slg.layer_start}, + {"", + "--high-noise-skip-layer-end", + "(high noise) SLG disabling point (default: 0.2)", + &high_noise_sample_params.guidance.slg.layer_end}, + {"", + "--high-noise-eta", + "(high noise) eta in DDIM, only for DDIM and TCD (default: 0)", + &high_noise_sample_params.eta}, + {"", + "--strength", + "strength for noising/unnoising (default: 0.75)", + &strength}, + {"", + "--pm-style-strength", + "", + &pm_style_strength}, + {"", + "--control-strength", + "strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image", + &control_strength}, + {"", + "--moe-boundary", + "timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1", + &moe_boundary}, + {"", + "--vace-strength", + "wan vace strength", + &vace_strength}, }; options.bool_options = { {"", - "--canny", - "apply canny preprocessor (edge detection)", - true, &canny_preprocess}, - {"-v", - "--verbose", - "print extra info", - true, &verbose}, - {"", - "--color", - "colors the logging tags according to level", - true, &color}, - {"", - "--taesd-preview-only", - std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")", - true, &taesd_preview}, + "--increase-ref-index", + "automatically increase the indices of references images based on the order they are listed (starting with 1).", + true, + &increase_ref_index}, {"", - "--preview-noisy", - "enables previewing noisy inputs of the models rather than the denoised outputs", - true, &preview_noisy}, + "--disable-auto-resize-ref-image", + "disable auto resize of ref images", + false, + &auto_resize_ref_image}, + }; + auto on_seed_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + seed = std::stoll(argv[index]); + return 1; }; - auto on_mode_arg = [&](int argc, const char** argv, int index) { + auto on_sample_method_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; } - const char* mode_c_str = argv[index]; - if (mode_c_str != nullptr) { - int mode_found = -1; - for (int i = 0; i < MODE_COUNT; i++) { - if (!strcmp(mode_c_str, modes_str[i])) { - mode_found = i; - } - } - if (mode_found == -1) { - LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", - mode_c_str, SD_ALL_MODES_STR); - exit(1); + const char* arg = argv[index]; + sample_params.sample_method = str_to_sample_method(arg); + if (sample_params.sample_method == SAMPLE_METHOD_COUNT) { + fprintf(stderr, "error: invalid sample method %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + high_noise_sample_params.sample_method = str_to_sample_method(arg); + if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { + fprintf(stderr, "error: invalid high noise sample method %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_scheduler_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + sample_params.scheduler = str_to_scheduler(arg); + if (sample_params.scheduler == SCHEDULER_COUNT) { + fprintf(stderr, "error: invalid scheduler %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_skip_layers_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string layers_str = argv[index]; + if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { + return -1; + } + + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument&) { + return -1; } - mode = (SDMode)mode_found; } + skip_layers = layers; return 1; }; - auto on_preview_arg = [&](int argc, const char** argv, int index) { + auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; } - const char* preview = argv[index]; - int preview_found = -1; - for (int m = 0; m < PREVIEW_COUNT; m++) { - if (!strcmp(preview, previews_str[m])) { - preview_found = m; + std::string layers_str = argv[index]; + if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { + return -1; + } + + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument&) { + return -1; } } - if (preview_found == -1) { - LOG_ERROR("error: preview method %s", preview); + high_noise_skip_layers = layers; + return 1; + }; + + auto on_ref_image_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { return -1; } - preview_method = (preview_t)preview_found; + ref_image_paths.push_back(argv[index]); return 1; }; - auto on_help_arg = [&](int argc, const char** argv, int index) { - normal_exit = true; - return -1; + auto on_easycache_arg = [&](int argc, const char** argv, int index) { + const std::string default_values = "0.2,0.15,0.95"; + auto looks_like_value = [](const std::string& token) { + if (token.empty()) { + return false; + } + if (token[0] != '-') { + return true; + } + if (token.size() == 1) { + return false; + } + unsigned char next = static_cast(token[1]); + return std::isdigit(next) || token[1] == '.'; + }; + + std::string option_value; + int consumed = 0; + if (index + 1 < argc) { + std::string next_arg = argv[index + 1]; + if (looks_like_value(next_arg)) { + option_value = argv_to_utf8(index + 1, argv); + consumed = 1; + } + } + if (option_value.empty()) { + option_value = default_values; + } + easycache_option = option_value; + return consumed; + }; + + auto on_ucache_arg = [&](int argc, const char** argv, int index) { + const std::string default_values = "1.0,0.15,0.95"; + auto looks_like_value = [](const std::string& token) { + if (token.empty()) { + return false; + } + if (token[0] != '-') { + return true; + } + if (token.size() == 1) { + return false; + } + unsigned char next = static_cast(token[1]); + return std::isdigit(next) || token[1] == '.'; + }; + + std::string option_value; + int consumed = 0; + if (index + 1 < argc) { + std::string next_arg = argv[index + 1]; + if (looks_like_value(next_arg)) { + option_value = argv_to_utf8(index + 1, argv); + consumed = 1; + } + } + if (option_value.empty()) { + option_value = default_values; + } + ucache_option = option_value; + return consumed; }; options.manual_options = { - {"-M", - "--mode", - "run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen", - on_mode_arg}, + {"-s", + "--seed", + "RNG seed (default: 42, use random seed for < 0)", + on_seed_arg}, {"", - "--preview", - std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")", - on_preview_arg}, - {"-h", - "--help", - "show this help message and exit", - on_help_arg}, + "--sampling-method", + "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] " + "(default: euler for Flux/SD3/Wan, euler_a otherwise)", + on_sample_method_arg}, + {"", + "--high-noise-sampling-method", + "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" + " default: euler for Flux/SD3/Wan, euler_a otherwise", + on_high_noise_sample_method_arg}, + {"", + "--scheduler", + "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", + on_scheduler_arg}, + {"", + "--skip-layers", + "layers to skip for SLG steps (default: [7,8,9])", + on_skip_layers_arg}, + {"", + "--high-noise-skip-layers", + "(high noise) layers to skip for SLG steps (default: [7,8,9])", + on_high_noise_skip_layers_arg}, + {"-r", + "--ref-image", + "reference image for Flux Kontext models (can be used multiple times)", + on_ref_image_arg}, + {"", + "--easycache", + "enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)", + on_easycache_arg}, + {"", + "--ucache", + "enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)", + on_ucache_arg}, + }; return options; - }; + } - bool process_and_check() { - if (output_path.length() == 0) { - LOG_ERROR("error: the following arguments are required: output_path"); + void extract_and_remove_lora(const std::string& lora_model_dir) { + static const std::regex re(R"(]+):([^>]+)>)"); + static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; + std::smatch m; + + std::string tmp = prompt; + + while (std::regex_search(tmp, m, re)) { + std::string raw_path = m[1].str(); + const std::string raw_mul = m[2].str(); + + float mul = 0.f; + try { + mul = std::stof(raw_mul); + } catch (...) { + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + + bool is_high_noise = false; + static const std::string prefix = "|high_noise|"; + if (raw_path.rfind(prefix, 0) == 0) { + raw_path.erase(0, prefix.size()); + is_high_noise = true; + } + + fs::path final_path; + if (is_absolute_path(raw_path)) { + final_path = raw_path; + } else { + final_path = fs::path(lora_model_dir) / raw_path; + } + if (!fs::exists(final_path)) { + bool found = false; + for (const auto& ext : valid_ext) { + fs::path try_path = final_path; + try_path += ext; + if (fs::exists(try_path)) { + final_path = try_path; + found = true; + break; + } + } + if (!found) { + printf("can not found lora %s\n", final_path.lexically_normal().string().c_str()); + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + } + + const std::string key = final_path.lexically_normal().string(); + + if (is_high_noise) + high_noise_lora_map[key] += mul; + else + lora_map[key] += mul; + + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + + tmp = m.suffix().str(); + } + + for (const auto& kv : lora_map) { + sd_lora_t item; + item.is_high_noise = false; + item.path = kv.first.c_str(); + item.multiplier = kv.second; + lora_vec.emplace_back(item); + } + + for (const auto& kv : high_noise_lora_map) { + sd_lora_t item; + item.is_high_noise = true; + item.path = kv.first.c_str(); + item.multiplier = kv.second; + lora_vec.emplace_back(item); + } + } + + bool process_and_check(SDMode mode, const std::string& lora_model_dir) { + if (width <= 0) { + fprintf(stderr, "error: the width must be greater than 0\n"); return false; } - if (mode == CONVERT) { - if (output_path == "output.png") { - output_path = "output.gguf"; + if (height <= 0) { + fprintf(stderr, "error: the height must be greater than 0\n"); + return false; + } + + if (sample_params.sample_steps <= 0) { + fprintf(stderr, "error: the sample_steps must be greater than 0\n"); + return false; + } + + if (high_noise_sample_params.sample_steps <= 0) { + high_noise_sample_params.sample_steps = -1; + } + + if (strength < 0.f || strength > 1.f) { + fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n"); + return false; + } + + if (!easycache_option.empty()) { + float values[3] = {0.0f, 0.0f, 0.0f}; + std::stringstream ss(easycache_option); + std::string token; + int idx = 0; + while (std::getline(ss, token, ',')) { + auto trim = [](std::string& s) { + const char* whitespace = " \t\r\n"; + auto start = s.find_first_not_of(whitespace); + if (start == std::string::npos) { + s.clear(); + return; + } + auto end = s.find_last_not_of(whitespace); + s = s.substr(start, end - start + 1); + }; + trim(token); + if (token.empty()) { + fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str()); + return false; + } + if (idx >= 3) { + fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + return false; + } + try { + values[idx] = std::stof(token); + } catch (const std::exception&) { + fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str()); + return false; + } + idx++; + } + if (idx != 3) { + fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + return false; + } + if (values[0] < 0.0f) { + fprintf(stderr, "error: easycache threshold must be non-negative\n"); + return false; + } + if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { + fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); + return false; + } + easycache_params.enabled = true; + easycache_params.reuse_threshold = values[0]; + easycache_params.start_percent = values[1]; + easycache_params.end_percent = values[2]; + } else { + easycache_params.enabled = false; + } + + if (!ucache_option.empty()) { + float values[3] = {0.0f, 0.0f, 0.0f}; + std::stringstream ss(ucache_option); + std::string token; + int idx = 0; + while (std::getline(ss, token, ',')) { + auto trim = [](std::string& s) { + const char* whitespace = " \t\r\n"; + auto start = s.find_first_not_of(whitespace); + if (start == std::string::npos) { + s.clear(); + return; + } + auto end = s.find_last_not_of(whitespace); + s = s.substr(start, end - start + 1); + }; + trim(token); + if (token.empty()) { + fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str()); + return false; + } + if (idx >= 3) { + fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n"); + return false; + } + try { + values[idx] = std::stof(token); + } catch (const std::exception&) { + fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str()); + return false; + } + idx++; + } + if (idx != 3) { + fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n"); + return false; + } + if (values[0] < 0.0f) { + fprintf(stderr, "error: ucache threshold must be non-negative\n"); + return false; + } + if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { + fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); + return false; + } + ucache_params.enabled = true; + ucache_params.reuse_threshold = values[0]; + ucache_params.start_percent = values[1]; + ucache_params.end_percent = values[2]; + } else { + ucache_params.enabled = false; + } + + sample_params.guidance.slg.layers = skip_layers.data(); + sample_params.guidance.slg.layer_count = skip_layers.size(); + high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data(); + high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); + + if (mode == VID_GEN && video_frames <= 0) { + return false; + } + + if (mode == VID_GEN && fps <= 0) { + return false; + } + + if (sample_params.shifted_timestep < 0 || sample_params.shifted_timestep > 1000) { + return false; + } + + if (upscale_repeats < 1) { + return false; + } + + if (mode == UPSCALE) { + if (init_image_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n"); + return false; } } + + if (seed < 0) { + srand((int)time(nullptr)); + seed = rand(); + } + + extract_and_remove_lora(lora_model_dir); + return true; } std::string to_string() const { + char* sample_params_str = sd_sample_params_to_str(&sample_params); + char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params); + + std::ostringstream lora_ss; + lora_ss << "{\n"; + for (auto it = lora_map.begin(); it != lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string loras_str = lora_ss.str(); + + lora_ss = std::ostringstream(); + ; + lora_ss << "{\n"; + for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != high_noise_lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string high_noise_loras_str = lora_ss.str(); + std::ostringstream oss; - oss << "SDCliParams {\n" - << " mode: " << modes_str[mode] << ",\n" - << " output_path: \"" << output_path << "\",\n" - << " verbose: " << (verbose ? "true" : "false") << ",\n" - << " color: " << (color ? "true" : "false") << ",\n" - << " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n" - << " preview_method: " << previews_str[preview_method] << ",\n" - << " preview_interval: " << preview_interval << ",\n" - << " preview_path: \"" << preview_path << "\",\n" - << " preview_fps: " << preview_fps << ",\n" - << " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n" - << " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n" + oss << "SDGenerationParams {\n" + << " loras: \"" << loras_str << "\",\n" + << " high_noise_loras: \"" << high_noise_loras_str << "\",\n" + << " prompt: \"" << prompt << "\",\n" + << " negative_prompt: \"" << negative_prompt << "\",\n" + << " clip_skip: " << clip_skip << ",\n" + << " width: " << width << ",\n" + << " height: " << height << ",\n" + << " batch_count: " << batch_count << ",\n" + << " init_image_path: \"" << init_image_path << "\",\n" + << " end_image_path: \"" << end_image_path << "\",\n" + << " mask_image_path: \"" << mask_image_path << "\",\n" + << " control_image_path: \"" << control_image_path << "\",\n" + << " ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n" + << " control_video_path: \"" << control_video_path << "\",\n" + << " auto_resize_ref_image: " << (auto_resize_ref_image ? "true" : "false") << ",\n" + << " increase_ref_index: " << (increase_ref_index ? "true" : "false") << ",\n" + << " pm_id_images_dir: \"" << pm_id_images_dir << "\",\n" + << " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n" + << " pm_style_strength: " << pm_style_strength << ",\n" + << " skip_layers: " << vec_to_string(skip_layers) << ",\n" + << " sample_params: " << sample_params_str << ",\n" + << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" + << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" + << " easycache_option: \"" << easycache_option << "\",\n" + << " easycache: " + << (easycache_params.enabled ? "enabled" : "disabled") + << " (threshold=" << easycache_params.reuse_threshold + << ", start=" << easycache_params.start_percent + << ", end=" << easycache_params.end_percent << "),\n" + << " moe_boundary: " << moe_boundary << ",\n" + << " video_frames: " << video_frames << ",\n" + << " fps: " << fps << ",\n" + << " vace_strength: " << vace_strength << ",\n" + << " strength: " << strength << ",\n" + << " control_strength: " << control_strength << ",\n" + << " seed: " << seed << ",\n" + << " upscale_repeats: " << upscale_repeats << ",\n" << "}"; + free(sample_params_str); + free(high_noise_sample_params_str); return oss.str(); } }; +static std::string version_string() { + return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); +} + void print_usage(int argc, const char* argv[], const std::vector& options_list) { std::cout << version_string() << "\n"; std::cout << "Usage: " << argv[0] << " [options]\n\n"; @@ -213,7 +1844,7 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP } std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { - std::string parameter_string = gen_params.prompt_with_lora + "\n"; + std::string parameter_string = gen_params.prompt + "\n"; if (gen_params.negative_prompt.size() != 0) { parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n"; } @@ -239,15 +1870,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; } parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); - if (!gen_params.custom_sigmas.empty()) { - parameter_string += ", Custom Sigmas: ["; - for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) { - std::ostringstream oss; - oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i]; - parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", "); - } - parameter_string += "]"; - } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas + if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); } parameter_string += ", "; @@ -274,6 +1897,94 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { log_print(level, log, cli_params->verbose, cli_params->color); } +uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { + int c = 0; + uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel); + if (image_buffer == nullptr) { + fprintf(stderr, "load image from '%s' failed\n", image_path); + return nullptr; + } + if (c < expected_channel) { + fprintf(stderr, + "the number of channels for the input image must be >= %d," + "but got %d channels, image_path = %s\n", + expected_channel, + c, + image_path); + free(image_buffer); + return nullptr; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return nullptr; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return nullptr; + } + + // Resize input image ... + if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) { + float dst_aspect = (float)expected_width / (float)expected_height; + float src_aspect = (float)width / (float)height; + + int crop_x = 0, crop_y = 0; + int crop_w = width, crop_h = height; + + if (src_aspect > dst_aspect) { + crop_w = (int)(height * dst_aspect); + crop_x = (width - crop_w) / 2; + } else if (src_aspect < dst_aspect) { + crop_h = (int)(width / dst_aspect); + crop_y = (height - crop_h) / 2; + } + + if (crop_x != 0 || crop_y != 0) { + printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); + uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); + if (cropped_image_buffer == nullptr) { + fprintf(stderr, "error: allocate memory for crop\n"); + free(image_buffer); + return nullptr; + } + for (int row = 0; row < crop_h; row++) { + uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; + uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; + memcpy(dst, src, crop_w * expected_channel); + } + + width = crop_w; + height = crop_h; + free(image_buffer); + image_buffer = cropped_image_buffer; + } + + printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); + int resized_height = expected_height; + int resized_width = expected_width; + + uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); + if (resized_image_buffer == nullptr) { + fprintf(stderr, "error: allocate memory for resize input image\n"); + free(image_buffer); + return nullptr; + } + stbir_resize(image_buffer, width, height, 0, + resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, + expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, + STBIR_FILTER_BOX, STBIR_FILTER_BOX, + STBIR_COLORSPACE_SRGB, nullptr); + width = resized_width; + height = resized_height; + free(image_buffer); + image_buffer = resized_image_buffer; + } + return image_buffer; +} + bool load_images_from_dir(const std::string dir, std::vector& images, int expected_width = 0, @@ -306,7 +2017,7 @@ bool load_images_from_dir(const std::string dir, LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str()); int width = 0; int height = 0; - uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height); + uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height); if (image_buffer == nullptr) { LOG_ERROR("load image from '%s' failed", path.c_str()); return false; @@ -439,7 +2150,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); + init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (init_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str()); release_all_resources(); @@ -452,7 +2163,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); + end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (end_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str()); release_all_resources(); @@ -464,7 +2175,7 @@ int main(int argc, const char* argv[]) { int c = 0; int width = 0; int height = 0; - mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); + mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); if (mask_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); release_all_resources(); @@ -483,7 +2194,7 @@ int main(int argc, const char* argv[]) { if (gen_params.control_image_path.size() > 0) { int width = 0; int height = 0; - control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); + control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (control_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); release_all_resources(); @@ -504,7 +2215,7 @@ int main(int argc, const char* argv[]) { for (auto& path : gen_params.ref_image_paths) { int width = 0; int height = 0; - uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height); + uint8_t* image_buffer = load_image(path.c_str(), width, height); if (image_buffer == nullptr) { LOG_ERROR("load image from '%s' failed", path.c_str()); release_all_resources(); @@ -611,6 +2322,7 @@ int main(int argc, const char* argv[]) { }, // pm_params ctx_params.vae_tiling_params, gen_params.easycache_params, + gen_params.ucache_params, }; results = generate_image(sd_ctx, &img_gen_params); @@ -636,6 +2348,7 @@ int main(int argc, const char* argv[]) { gen_params.video_frames, gen_params.vace_strength, gen_params.easycache_params, + gen_params.ucache_params, }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); @@ -655,8 +2368,7 @@ int main(int argc, const char* argv[]) { upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(), ctx_params.offload_params_to_cpu, ctx_params.diffusion_conv_direct, - ctx_params.n_threads, - gen_params.upscale_tile_size); + ctx_params.n_threads); if (upscaler_ctx == nullptr) { LOG_ERROR("new_upscaler_ctx failed"); @@ -752,4 +2464,4 @@ int main(int argc, const char* argv[]) { release_all_resources(); return 0; -} \ No newline at end of file +} diff --git a/examples/common/common.hpp b/examples/common/common.hpp index f3a561367..9ea14153b 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -997,8 +997,12 @@ struct SDGenerationParams { std::vector custom_sigmas; - std::string easycache_option; - sd_easycache_params_t easycache_params; + std::string cache_mode; + std::string cache_option; + std::string cache_preset; + std::string scm_mask; + bool scm_policy_dynamic = true; + sd_cache_params_t cache_params{}; float moe_boundary = 0.875f; int video_frames = 1; @@ -1360,36 +1364,64 @@ struct SDGenerationParams { return 1; }; - auto on_easycache_arg = [&](int argc, const char** argv, int index) { - const std::string default_values = "0.2,0.15,0.95"; - auto looks_like_value = [](const std::string& token) { - if (token.empty()) { - return false; - } - if (token[0] != '-') { - return true; - } - if (token.size() == 1) { - return false; - } - unsigned char next = static_cast(token[1]); - return std::isdigit(next) || token[1] == '.'; - }; - - std::string option_value; - int consumed = 0; - if (index + 1 < argc) { - std::string next_arg = argv[index + 1]; - if (looks_like_value(next_arg)) { - option_value = argv_to_utf8(index + 1, argv); - consumed = 1; - } + auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; } - if (option_value.empty()) { - option_value = default_values; + cache_mode = argv_to_utf8(index, argv); + if (cache_mode != "easycache" && cache_mode != "ucache" && + cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") { + fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str()); + return -1; } - easycache_option = option_value; - return consumed; + return 1; + }; + + auto on_cache_option_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_option = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + scm_mask = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string policy = argv_to_utf8(index, argv); + if (policy == "dynamic") { + scm_policy_dynamic = true; + } else if (policy == "static") { + scm_policy_dynamic = false; + } else { + fprintf(stderr, "error: invalid scm policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); + return -1; + } + return 1; + }; + + auto on_cache_preset_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_preset = argv_to_utf8(index, argv); + if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" && + cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" && + cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" && + cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") { + fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str()); + return -1; + } + return 1; }; options.manual_options = { @@ -1428,9 +1460,25 @@ struct SDGenerationParams { "reference image for Flux Kontext models (can be used multiple times)", on_ref_image_arg}, {"", - "--easycache", - "enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)", - on_easycache_arg}, + "--cache-mode", + "caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", + on_cache_mode_arg}, + {"", + "--cache-option", + "named cache params (key=value format, comma-separated):\n - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=\n - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=\n Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", + on_cache_option_arg}, + {"", + "--cache-preset", + "cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'", + on_cache_preset_arg}, + {"", + "--scm-mask", + "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", + on_scm_mask_arg}, + {"", + "--scm-policy", + "SCM policy: 'dynamic' (default) or 'static'", + on_scm_policy_arg}, }; @@ -1473,7 +1521,10 @@ struct SDGenerationParams { load_if_exists("prompt", prompt); load_if_exists("negative_prompt", negative_prompt); - load_if_exists("easycache_option", easycache_option); + load_if_exists("cache_mode", cache_mode); + load_if_exists("cache_option", cache_option); + load_if_exists("cache_preset", cache_preset); + load_if_exists("scm_mask", scm_mask); load_if_exists("clip_skip", clip_skip); load_if_exists("width", width); @@ -1613,57 +1664,118 @@ struct SDGenerationParams { return false; } - if (!easycache_option.empty()) { - float values[3] = {0.0f, 0.0f, 0.0f}; - std::stringstream ss(easycache_option); + sd_cache_params_init(&cache_params); + + auto parse_named_params = [&](const std::string& opt_str) -> bool { + std::stringstream ss(opt_str); std::string token; - int idx = 0; while (std::getline(ss, token, ',')) { - auto trim = [](std::string& s) { - const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { - s.clear(); - return; - } - auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); - }; - trim(token); - if (token.empty()) { - LOG_ERROR("error: invalid easycache option '%s'", easycache_option.c_str()); - return false; - } - if (idx >= 3) { - LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + size_t eq_pos = token.find('='); + if (eq_pos == std::string::npos) { + LOG_ERROR("error: cache option '%s' missing '=' separator", token.c_str()); return false; } + std::string key = token.substr(0, eq_pos); + std::string val = token.substr(eq_pos + 1); try { - values[idx] = std::stof(token); + if (key == "threshold") { + if (cache_mode == "easycache" || cache_mode == "ucache") { + cache_params.reuse_threshold = std::stof(val); + } else { + cache_params.residual_diff_threshold = std::stof(val); + } + } else if (key == "start") { + cache_params.start_percent = std::stof(val); + } else if (key == "end") { + cache_params.end_percent = std::stof(val); + } else if (key == "decay") { + cache_params.error_decay_rate = std::stof(val); + } else if (key == "relative") { + cache_params.use_relative_threshold = (std::stof(val) != 0.0f); + } else if (key == "reset") { + cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); + } else if (key == "Fn" || key == "fn") { + cache_params.Fn_compute_blocks = std::stoi(val); + } else if (key == "Bn" || key == "bn") { + cache_params.Bn_compute_blocks = std::stoi(val); + } else if (key == "warmup") { + cache_params.max_warmup_steps = std::stoi(val); + } else { + LOG_ERROR("error: unknown cache parameter '%s'", key.c_str()); + return false; + } } catch (const std::exception&) { - LOG_ERROR("error: invalid easycache value '%s'", token.c_str()); + LOG_ERROR("error: invalid value '%s' for parameter '%s'", val.c_str(), key.c_str()); return false; } - idx++; } - if (idx != 3) { - LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); - return false; + return true; + }; + + if (!cache_mode.empty()) { + if (cache_mode == "easycache") { + cache_params.mode = SD_CACHE_EASYCACHE; + cache_params.reuse_threshold = 0.2f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; + } else if (cache_mode == "ucache") { + cache_params.mode = SD_CACHE_UCACHE; + cache_params.reuse_threshold = 1.0f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; + } else if (cache_mode == "dbcache") { + cache_params.mode = SD_CACHE_DBCACHE; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "taylorseer") { + cache_params.mode = SD_CACHE_TAYLORSEER; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "cache-dit") { + cache_params.mode = SD_CACHE_CACHE_DIT; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; } - if (values[0] < 0.0f) { - LOG_ERROR("error: easycache threshold must be non-negative\n"); - return false; + + if (!cache_option.empty()) { + if (!parse_named_params(cache_option)) { + return false; + } } - if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - LOG_ERROR("error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); - return false; + + if (cache_mode == "easycache" || cache_mode == "ucache") { + if (cache_params.reuse_threshold < 0.0f) { + LOG_ERROR("error: cache threshold must be non-negative"); + return false; + } + if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || + cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || + cache_params.start_percent >= cache_params.end_percent) { + LOG_ERROR("error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0"); + return false; + } + } + } + + if (cache_params.mode == SD_CACHE_DBCACHE || + cache_params.mode == SD_CACHE_TAYLORSEER || + cache_params.mode == SD_CACHE_CACHE_DIT) { + if (!scm_mask.empty()) { + cache_params.scm_mask = scm_mask.c_str(); } - easycache_params.enabled = true; - easycache_params.reuse_threshold = values[0]; - easycache_params.start_percent = values[1]; - easycache_params.end_percent = values[2]; - } else { - easycache_params.enabled = false; + cache_params.scm_policy_dynamic = scm_policy_dynamic; } sample_params.guidance.slg.layers = skip_layers.data(); @@ -1765,12 +1877,13 @@ struct SDGenerationParams { << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" - << " easycache_option: \"" << easycache_option << "\",\n" - << " easycache: " - << (easycache_params.enabled ? "enabled" : "disabled") - << " (threshold=" << easycache_params.reuse_threshold - << ", start=" << easycache_params.start_percent - << ", end=" << easycache_params.end_percent << "),\n" + << " cache_mode: \"" << cache_mode << "\",\n" + << " cache_option: \"" << cache_option << "\",\n" + << " cache: " + << (cache_params.mode != SD_CACHE_DISABLED ? "enabled" : "disabled") + << " (threshold=" << cache_params.reuse_threshold + << ", start=" << cache_params.start_percent + << ", end=" << cache_params.end_percent << "),\n" << " moe_boundary: " << moe_boundary << ",\n" << " video_frames: " << video_frames << ",\n" << " fps: " << fps << ",\n" diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 44bd3ccac..326ad59ad 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -13,6 +13,7 @@ #include "diffusion_model.hpp" #include "easycache.hpp" #include "esrgan.hpp" +#include "ucache.hpp" #include "lora.hpp" #include "pmid.hpp" #include "tae.hpp" @@ -1486,7 +1487,8 @@ class StableDiffusionGGML { ggml_tensor* denoise_mask = nullptr, ggml_tensor* vace_context = nullptr, float vace_strength = 1.f, - const sd_easycache_params_t* easycache_params = nullptr) { + const sd_easycache_params_t* easycache_params = nullptr, + const sd_ucache_params_t* ucache_params = nullptr) { if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { LOG_WARN("timestep shifting is only supported for SDXL models!"); shifted_timestep = 0; @@ -1538,6 +1540,42 @@ class StableDiffusionGGML { } } + UCacheState ucache_state; + bool ucache_enabled = false; + if (ucache_params != nullptr && ucache_params->enabled) { + bool ucache_supported = sd_version_is_unet(version); + if (!ucache_supported) { + LOG_WARN("UCache requested but not supported for this model type (only UNET models)"); + } else { + UCacheConfig ucache_config; + ucache_config.enabled = true; + ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold); + ucache_config.start_percent = ucache_params->start_percent; + ucache_config.end_percent = ucache_params->end_percent; + bool percent_valid = ucache_config.start_percent >= 0.0f && + ucache_config.start_percent < 1.0f && + ucache_config.end_percent > 0.0f && + ucache_config.end_percent <= 1.0f && + ucache_config.start_percent < ucache_config.end_percent; + if (!percent_valid) { + LOG_WARN("UCache disabled due to invalid percent range (start=%.3f, end=%.3f)", + ucache_config.start_percent, + ucache_config.end_percent); + } else { + ucache_state.init(ucache_config, denoiser.get()); + if (ucache_state.enabled()) { + ucache_enabled = true; + LOG_INFO("UCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f", + ucache_config.reuse_threshold, + ucache_config.start_percent, + ucache_config.end_percent); + } else { + LOG_WARN("UCache requested but could not be initialized for this run"); + } + } + } + } + size_t steps = sigmas.size() - 1; struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); copy_ggml_tensor(x, init_latent); @@ -1641,6 +1679,57 @@ class StableDiffusionGGML { return easycache_step_active && easycache_state.is_step_skipped(); }; + const bool ucache_step_active = ucache_enabled && step > 0; + int ucache_step_index = ucache_step_active ? (step - 1) : -1; + if (ucache_step_active) { + ucache_state.begin_step(ucache_step_index, sigma); + } + + auto ucache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool { + if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) { + return false; + } + return ucache_state.before_condition(condition, + diffusion_params.x, + output_tensor, + sigma, + ucache_step_index); + }; + + auto ucache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) { + if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) { + return; + } + ucache_state.after_condition(condition, + diffusion_params.x, + output_tensor); + }; + + auto ucache_step_is_skipped = [&]() { + return ucache_step_active && ucache_state.is_step_skipped(); + }; + + auto cache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool { + if (easycache_step_active) { + return easycache_before_condition(condition, output_tensor); + } else if (ucache_step_active) { + return ucache_before_condition(condition, output_tensor); + } + return false; + }; + + auto cache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) { + if (easycache_step_active) { + easycache_after_condition(condition, output_tensor); + } else if (ucache_step_active) { + ucache_after_condition(condition, output_tensor); + } + }; + + auto cache_step_is_skipped = [&]() { + return easycache_step_is_skipped() || ucache_step_is_skipped(); + }; + std::vector scaling = denoiser->get_scalings(sigma); GGML_ASSERT(scaling.size() == 3); float c_skip = scaling[0]; @@ -1716,7 +1805,7 @@ class StableDiffusionGGML { active_condition = &id_cond; } - bool skip_model = easycache_before_condition(active_condition, *active_output); + bool skip_model = cache_before_condition(active_condition, *active_output); if (!skip_model) { if (!work_diffusion_model->compute(n_threads, diffusion_params, @@ -1724,10 +1813,10 @@ class StableDiffusionGGML { LOG_ERROR("diffusion model compute failed"); return nullptr; } - easycache_after_condition(active_condition, *active_output); + cache_after_condition(active_condition, *active_output); } - bool current_step_skipped = easycache_step_is_skipped(); + bool current_step_skipped = cache_step_is_skipped(); float* negative_data = nullptr; if (has_unconditioned) { @@ -1739,12 +1828,12 @@ class StableDiffusionGGML { LOG_ERROR("controlnet compute failed"); } } - current_step_skipped = easycache_step_is_skipped(); + current_step_skipped = cache_step_is_skipped(); diffusion_params.controls = controls; diffusion_params.context = uncond.c_crossattn; diffusion_params.c_concat = uncond.c_concat; diffusion_params.y = uncond.c_vector; - bool skip_uncond = easycache_before_condition(&uncond, out_uncond); + bool skip_uncond = cache_before_condition(&uncond, out_uncond); if (!skip_uncond) { if (!work_diffusion_model->compute(n_threads, diffusion_params, @@ -1752,7 +1841,7 @@ class StableDiffusionGGML { LOG_ERROR("diffusion model compute failed"); return nullptr; } - easycache_after_condition(&uncond, out_uncond); + cache_after_condition(&uncond, out_uncond); } negative_data = (float*)out_uncond->data; } @@ -1762,7 +1851,7 @@ class StableDiffusionGGML { diffusion_params.context = img_cond.c_crossattn; diffusion_params.c_concat = img_cond.c_concat; diffusion_params.y = img_cond.c_vector; - bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond); + bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond); if (!skip_img_cond) { if (!work_diffusion_model->compute(n_threads, diffusion_params, @@ -1770,7 +1859,7 @@ class StableDiffusionGGML { LOG_ERROR("diffusion model compute failed"); return nullptr; } - easycache_after_condition(&img_cond, out_img_cond); + cache_after_condition(&img_cond, out_img_cond); } img_cond_data = (float*)out_img_cond->data; } @@ -1780,7 +1869,7 @@ class StableDiffusionGGML { float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr; if (is_skiplayer_step) { LOG_DEBUG("Skipping layers at step %d\n", step); - if (!easycache_step_is_skipped()) { + if (!cache_step_is_skipped()) { // skip layer (same as conditioned) diffusion_params.context = cond.c_crossattn; diffusion_params.c_concat = cond.c_concat; @@ -1884,6 +1973,26 @@ class StableDiffusionGGML { } } + if (ucache_enabled) { + size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0; + if (ucache_state.total_steps_skipped > 0 && total_steps > 0) { + if (ucache_state.total_steps_skipped < static_cast(total_steps)) { + double speedup = static_cast(total_steps) / + static_cast(total_steps - ucache_state.total_steps_skipped); + LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)", + ucache_state.total_steps_skipped, + total_steps, + speedup); + } else { + LOG_INFO("UCache skipped %d/%zu steps", + ucache_state.total_steps_skipped, + total_steps); + } + } else if (total_steps > 0) { + LOG_INFO("UCache completed without skipping steps"); + } + } + if (inverse_noise_scaling) { x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); } @@ -2506,6 +2615,14 @@ void sd_easycache_params_init(sd_easycache_params_t* easycache_params) { easycache_params->end_percent = 0.95f; } +void sd_ucache_params_init(sd_ucache_params_t* ucache_params) { + *ucache_params = {}; + ucache_params->enabled = false; + ucache_params->reuse_threshold = 1.0f; + ucache_params->start_percent = 0.15f; + ucache_params->end_percent = 0.95f; +} + void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; @@ -2663,6 +2780,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_easycache_params_init(&sd_img_gen_params->easycache); + sd_ucache_params_init(&sd_img_gen_params->ucache); } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -2729,6 +2847,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; sd_easycache_params_init(&sd_vid_gen_params->easycache); + sd_ucache_params_init(&sd_vid_gen_params->ucache); } struct sd_ctx_t { @@ -2806,7 +2925,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, bool increase_ref_index, ggml_tensor* concat_latent = nullptr, ggml_tensor* denoise_mask = nullptr, - const sd_easycache_params_t* easycache_params = nullptr) { + const sd_easycache_params_t* easycache_params = nullptr, + const sd_ucache_params_t* ucache_params = nullptr) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -3095,7 +3215,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, denoise_mask, nullptr, 1.0f, - easycache_params); + easycache_params, + ucache_params); int64_t sampling_end = ggml_time_ms(); if (x_0 != nullptr) { // print_ggml_tensor(x_0); @@ -3429,7 +3550,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->increase_ref_index, concat_latent, denoise_mask, - &sd_img_gen_params->easycache); + &sd_img_gen_params->easycache, + &sd_img_gen_params->ucache); size_t t2 = ggml_time_ms(); @@ -3796,7 +3918,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache); + &sd_vid_gen_params->easycache, + &sd_vid_gen_params->ucache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); @@ -3833,7 +3956,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache); + &sd_vid_gen_params->easycache, + &sd_vid_gen_params->ucache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); diff --git a/stable-diffusion.h b/stable-diffusion.h index 9266ba437..a65fd7e84 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -242,6 +242,13 @@ typedef struct { float end_percent; } sd_easycache_params_t; +typedef struct { + bool enabled; + float reuse_threshold; + float start_percent; + float end_percent; +} sd_ucache_params_t; + typedef struct { bool is_high_noise; float multiplier; @@ -271,6 +278,7 @@ typedef struct { sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; sd_easycache_params_t easycache; + sd_ucache_params_t ucache; } sd_img_gen_params_t; typedef struct { @@ -293,6 +301,7 @@ typedef struct { int video_frames; float vace_strength; sd_easycache_params_t easycache; + sd_ucache_params_t ucache; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t; @@ -323,6 +332,7 @@ SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params); +SD_API void sd_ucache_params_init(sd_ucache_params_t* ucache_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/ucache.hpp b/ucache.hpp new file mode 100644 index 000000000..ca3875a59 --- /dev/null +++ b/ucache.hpp @@ -0,0 +1,286 @@ +#ifndef __UCACHE_HPP__ +#define __UCACHE_HPP__ + +#include +#include +#include +#include + +#include "denoiser.hpp" +#include "ggml_extend.hpp" + +struct UCacheConfig { + bool enabled = false; + float reuse_threshold = 1.0f; + float start_percent = 0.15f; + float end_percent = 0.95f; +}; + +struct UCacheCacheEntry { + std::vector diff; +}; + +struct UCacheState { + UCacheConfig config; + Denoiser* denoiser = nullptr; + float start_sigma = std::numeric_limits::max(); + float end_sigma = 0.0f; + bool initialized = false; + bool initial_step = true; + bool skip_current_step = false; + bool step_active = false; + const SDCondition* anchor_condition = nullptr; + std::unordered_map cache_diffs; + std::vector prev_input; + std::vector prev_output; + float output_prev_norm = 0.0f; + bool has_prev_input = false; + bool has_prev_output = false; + bool has_output_prev_norm = false; + bool has_relative_transformation_rate = false; + float relative_transformation_rate = 0.0f; + float cumulative_change_rate = 0.0f; + float last_input_change = 0.0f; + bool has_last_input_change = false; + int total_steps_skipped = 0; + int current_step_index = -1; + + void reset_runtime() { + initial_step = true; + skip_current_step = false; + step_active = false; + anchor_condition = nullptr; + cache_diffs.clear(); + prev_input.clear(); + prev_output.clear(); + output_prev_norm = 0.0f; + has_prev_input = false; + has_prev_output = false; + has_output_prev_norm = false; + has_relative_transformation_rate = false; + relative_transformation_rate = 0.0f; + cumulative_change_rate = 0.0f; + last_input_change = 0.0f; + has_last_input_change = false; + total_steps_skipped = 0; + current_step_index = -1; + } + + void init(const UCacheConfig& cfg, Denoiser* d) { + config = cfg; + denoiser = d; + initialized = cfg.enabled && d != nullptr; + reset_runtime(); + if (initialized) { + start_sigma = percent_to_sigma(config.start_percent); + end_sigma = percent_to_sigma(config.end_percent); + } + } + + bool enabled() const { + return initialized && config.enabled; + } + + float percent_to_sigma(float percent) const { + if (!denoiser) { + return 0.0f; + } + if (percent <= 0.0f) { + return std::numeric_limits::max(); + } + if (percent >= 1.0f) { + return 0.0f; + } + float t = (1.0f - percent) * (TIMESTEPS - 1); + return denoiser->t_to_sigma(t); + } + + void begin_step(int step_index, float sigma) { + if (!enabled()) { + return; + } + if (step_index == current_step_index) { + return; + } + current_step_index = step_index; + skip_current_step = false; + has_last_input_change = false; + step_active = false; + + if (sigma > start_sigma) { + return; + } + if (!(sigma > end_sigma)) { + return; + } + step_active = true; + } + + bool step_is_active() const { + return enabled() && step_active; + } + + bool is_step_skipped() const { + return enabled() && step_active && skip_current_step; + } + + bool has_cache(const SDCondition* cond) const { + auto it = cache_diffs.find(cond); + return it != cache_diffs.end() && !it->second.diff.empty(); + } + + void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) { + UCacheCacheEntry& entry = cache_diffs[cond]; + size_t ne = static_cast(ggml_nelements(output)); + entry.diff.resize(ne); + float* out_data = (float*)output->data; + float* in_data = (float*)input->data; + + for (size_t i = 0; i < ne; ++i) { + entry.diff[i] = out_data[i] - in_data[i]; + } + } + + void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) { + auto it = cache_diffs.find(cond); + if (it == cache_diffs.end() || it->second.diff.empty()) { + return; + } + + copy_ggml_tensor(output, input); + float* out_data = (float*)output->data; + const std::vector& diff = it->second.diff; + for (size_t i = 0; i < diff.size(); ++i) { + out_data[i] += diff[i]; + } + } + + bool before_condition(const SDCondition* cond, + ggml_tensor* input, + ggml_tensor* output, + float sigma, + int step_index) { + if (!enabled() || step_index < 0) { + return false; + } + if (step_index != current_step_index) { + begin_step(step_index, sigma); + } + if (!step_active) { + return false; + } + + if (initial_step) { + anchor_condition = cond; + initial_step = false; + } + + bool is_anchor = (cond == anchor_condition); + + if (skip_current_step) { + if (has_cache(cond)) { + apply_cache(cond, input, output); + return true; + } + return false; + } + + if (!is_anchor) { + return false; + } + + if (!has_prev_input || !has_prev_output || !has_cache(cond)) { + return false; + } + + size_t ne = static_cast(ggml_nelements(input)); + if (prev_input.size() != ne) { + return false; + } + + float* input_data = (float*)input->data; + last_input_change = 0.0f; + for (size_t i = 0; i < ne; ++i) { + last_input_change += std::fabs(input_data[i] - prev_input[i]); + } + if (ne > 0) { + last_input_change /= static_cast(ne); + } + has_last_input_change = true; + + if (has_output_prev_norm && has_relative_transformation_rate && + last_input_change > 0.0f && output_prev_norm > 0.0f) { + + float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm; + cumulative_change_rate += approx_output_change_rate; + + if (cumulative_change_rate < config.reuse_threshold) { + skip_current_step = true; + total_steps_skipped++; + apply_cache(cond, input, output); + return true; + } else { + cumulative_change_rate = 0.0f; + } + } + + return false; + } + + void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) { + if (!step_is_active()) { + return; + } + + update_cache(cond, input, output); + + if (cond != anchor_condition) { + return; + } + + size_t ne = static_cast(ggml_nelements(input)); + float* in_data = (float*)input->data; + prev_input.resize(ne); + for (size_t i = 0; i < ne; ++i) { + prev_input[i] = in_data[i]; + } + has_prev_input = true; + + float* out_data = (float*)output->data; + float output_change = 0.0f; + if (has_prev_output && prev_output.size() == ne) { + for (size_t i = 0; i < ne; ++i) { + output_change += std::fabs(out_data[i] - prev_output[i]); + } + if (ne > 0) { + output_change /= static_cast(ne); + } + } + + prev_output.resize(ne); + for (size_t i = 0; i < ne; ++i) { + prev_output[i] = out_data[i]; + } + has_prev_output = true; + + float mean_abs = 0.0f; + for (size_t i = 0; i < ne; ++i) { + mean_abs += std::fabs(out_data[i]); + } + output_prev_norm = (ne > 0) ? (mean_abs / static_cast(ne)) : 0.0f; + has_output_prev_norm = output_prev_norm > 0.0f; + + if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) { + float rate = output_change / last_input_change; + if (std::isfinite(rate)) { + relative_transformation_rate = rate; + has_relative_transformation_rate = true; + } + } + + cumulative_change_rate = 0.0f; + has_last_input_change = false; + } +}; + +#endif // __UCACHE_HPP__ From f3470101c6b39cc9c8e68898d8ff0c546e31a62d Mon Sep 17 00:00:00 2001 From: rmatif Date: Tue, 9 Dec 2025 22:59:27 +0000 Subject: [PATCH 02/12] add cache-mode and cache-option --- examples/cli/README.md | 4 +- examples/cli/main.cpp | 188 +++++++++++++---------------------------- 2 files changed, 59 insertions(+), 133 deletions(-) diff --git a/examples/cli/README.md b/examples/cli/README.md index ba9a0b278..c0b41c033 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -126,6 +126,6 @@ Generation Options: --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) - --easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95) - --ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95) + --cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL) + --cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache) ``` diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 7810ee8fc..a95853882 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1054,10 +1054,9 @@ struct SDGenerationParams { std::vector high_noise_skip_layers = {7, 8, 9}; sd_sample_params_t high_noise_sample_params; - std::string easycache_option; + std::string cache_mode; + std::string cache_option; sd_easycache_params_t easycache_params; - - std::string ucache_option; sd_ucache_params_t ucache_params; float moe_boundary = 0.875f; @@ -1378,68 +1377,24 @@ struct SDGenerationParams { return 1; }; - auto on_easycache_arg = [&](int argc, const char** argv, int index) { - const std::string default_values = "0.2,0.15,0.95"; - auto looks_like_value = [](const std::string& token) { - if (token.empty()) { - return false; - } - if (token[0] != '-') { - return true; - } - if (token.size() == 1) { - return false; - } - unsigned char next = static_cast(token[1]); - return std::isdigit(next) || token[1] == '.'; - }; - - std::string option_value; - int consumed = 0; - if (index + 1 < argc) { - std::string next_arg = argv[index + 1]; - if (looks_like_value(next_arg)) { - option_value = argv_to_utf8(index + 1, argv); - consumed = 1; - } + auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; } - if (option_value.empty()) { - option_value = default_values; + cache_mode = argv_to_utf8(index, argv); + if (cache_mode != "easycache" && cache_mode != "ucache") { + fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n", cache_mode.c_str()); + return -1; } - easycache_option = option_value; - return consumed; + return 1; }; - auto on_ucache_arg = [&](int argc, const char** argv, int index) { - const std::string default_values = "1.0,0.15,0.95"; - auto looks_like_value = [](const std::string& token) { - if (token.empty()) { - return false; - } - if (token[0] != '-') { - return true; - } - if (token.size() == 1) { - return false; - } - unsigned char next = static_cast(token[1]); - return std::isdigit(next) || token[1] == '.'; - }; - - std::string option_value; - int consumed = 0; - if (index + 1 < argc) { - std::string next_arg = argv[index + 1]; - if (looks_like_value(next_arg)) { - option_value = argv_to_utf8(index + 1, argv); - consumed = 1; - } - } - if (option_value.empty()) { - option_value = default_values; + auto on_cache_option_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; } - ucache_option = option_value; - return consumed; + cache_option = argv_to_utf8(index, argv); + return 1; }; options.manual_options = { @@ -1474,13 +1429,13 @@ struct SDGenerationParams { "reference image for Flux Kontext models (can be used multiple times)", on_ref_image_arg}, {"", - "--easycache", - "enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)", - on_easycache_arg}, + "--cache-mode", + "caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)", + on_cache_mode_arg}, {"", - "--ucache", - "enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)", - on_ucache_arg}, + "--cache-option", + "cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)", + on_cache_option_arg}, }; @@ -1593,62 +1548,21 @@ struct SDGenerationParams { return false; } - if (!easycache_option.empty()) { - float values[3] = {0.0f, 0.0f, 0.0f}; - std::stringstream ss(easycache_option); - std::string token; - int idx = 0; - while (std::getline(ss, token, ',')) { - auto trim = [](std::string& s) { - const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { - s.clear(); - return; - } - auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); - }; - trim(token); - if (token.empty()) { - fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str()); - return false; - } - if (idx >= 3) { - fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); - return false; - } - try { - values[idx] = std::stof(token); - } catch (const std::exception&) { - fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str()); - return false; + easycache_params.enabled = false; + ucache_params.enabled = false; + + if (!cache_mode.empty()) { + std::string option_str = cache_option; + if (option_str.empty()) { + if (cache_mode == "easycache") { + option_str = "0.2,0.15,0.95"; + } else { + option_str = "1.0,0.15,0.95"; } - idx++; - } - if (idx != 3) { - fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); - return false; } - if (values[0] < 0.0f) { - fprintf(stderr, "error: easycache threshold must be non-negative\n"); - return false; - } - if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); - return false; - } - easycache_params.enabled = true; - easycache_params.reuse_threshold = values[0]; - easycache_params.start_percent = values[1]; - easycache_params.end_percent = values[2]; - } else { - easycache_params.enabled = false; - } - if (!ucache_option.empty()) { float values[3] = {0.0f, 0.0f, 0.0f}; - std::stringstream ss(ucache_option); + std::stringstream ss(option_str); std::string token; int idx = 0; while (std::getline(ss, token, ',')) { @@ -1664,39 +1578,45 @@ struct SDGenerationParams { }; trim(token); if (token.empty()) { - fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str()); + fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str()); return false; } if (idx >= 3) { - fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n"); + fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n"); return false; } try { values[idx] = std::stof(token); } catch (const std::exception&) { - fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str()); + fprintf(stderr, "error: invalid cache option value '%s'\n", token.c_str()); return false; } idx++; } if (idx != 3) { - fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n"); + fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n"); return false; } if (values[0] < 0.0f) { - fprintf(stderr, "error: ucache threshold must be non-negative\n"); + fprintf(stderr, "error: cache threshold must be non-negative\n"); return false; } if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); + fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); return false; } - ucache_params.enabled = true; - ucache_params.reuse_threshold = values[0]; - ucache_params.start_percent = values[1]; - ucache_params.end_percent = values[2]; - } else { - ucache_params.enabled = false; + + if (cache_mode == "easycache") { + easycache_params.enabled = true; + easycache_params.reuse_threshold = values[0]; + easycache_params.start_percent = values[1]; + easycache_params.end_percent = values[2]; + } else { + ucache_params.enabled = true; + ucache_params.reuse_threshold = values[0]; + ucache_params.start_percent = values[1]; + ucache_params.end_percent = values[2]; + } } sample_params.guidance.slg.layers = skip_layers.data(); @@ -1791,12 +1711,18 @@ struct SDGenerationParams { << " sample_params: " << sample_params_str << ",\n" << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" - << " easycache_option: \"" << easycache_option << "\",\n" + << " cache_mode: \"" << cache_mode << "\",\n" + << " cache_option: \"" << cache_option << "\",\n" << " easycache: " << (easycache_params.enabled ? "enabled" : "disabled") << " (threshold=" << easycache_params.reuse_threshold << ", start=" << easycache_params.start_percent << ", end=" << easycache_params.end_percent << "),\n" + << " ucache: " + << (ucache_params.enabled ? "enabled" : "disabled") + << " (threshold=" << ucache_params.reuse_threshold + << ", start=" << ucache_params.start_percent + << ", end=" << ucache_params.end_percent << "),\n" << " moe_boundary: " << moe_boundary << ",\n" << " video_frames: " << video_frames << ",\n" << " fps: " << fps << ",\n" From 148bfdf73a1d1d47c7fe7519408a215ae270a608 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 10 Dec 2025 11:06:44 +0000 Subject: [PATCH 03/12] add decay rate and relative threshold --- examples/cli/README.md | 4 +- examples/cli/main.cpp | 46 +++++++++-------- stable-diffusion.cpp | 28 ++++++---- stable-diffusion.h | 2 + ucache.hpp | 113 +++++++++++++++++++++++++++++++++++++---- 5 files changed, 151 insertions(+), 42 deletions(-) diff --git a/examples/cli/README.md b/examples/cli/README.md index c0b41c033..b9f4f70d0 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -127,5 +127,7 @@ Generation Options: --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) --cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL) - --cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache) + --cache-option cache parameters: easycache uses "threshold,start,end" (default: 0.2,0.15,0.95). + ucache uses "threshold,start,end[,decay,relative]" (default: 1.0,0.15,0.95,1.0,1). + decay: error decay rate (0.0-1.0), relative: use relative threshold (0 or 1) ``` diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index a95853882..4ecec4a5a 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1434,7 +1434,7 @@ struct SDGenerationParams { on_cache_mode_arg}, {"", "--cache-option", - "cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)", + "cache parameters \"threshold,start,end[,warmup,decay,relative]\" (ucache extended: warmup=0, decay=1.0, relative=1)", on_cache_option_arg}, }; @@ -1561,28 +1561,32 @@ struct SDGenerationParams { } } - float values[3] = {0.0f, 0.0f, 0.0f}; + // Format: threshold,start,end[,decay,relative] + // - values[0-2]: threshold, start_percent, end_percent (required) + // - values[3]: error_decay_rate (optional, default: 1.0) + // - values[4]: use_relative_threshold (optional, 0 or 1, default: 1) + float values[5] = {0.0f, 0.0f, 0.0f, 1.0f, 1.0f}; std::stringstream ss(option_str); std::string token; int idx = 0; + auto trim = [](std::string& s) { + const char* whitespace = " \t\r\n"; + auto start = s.find_first_not_of(whitespace); + if (start == std::string::npos) { + s.clear(); + return; + } + auto end = s.find_last_not_of(whitespace); + s = s.substr(start, end - start + 1); + }; while (std::getline(ss, token, ',')) { - auto trim = [](std::string& s) { - const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { - s.clear(); - return; - } - auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); - }; trim(token); if (token.empty()) { fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str()); return false; } - if (idx >= 3) { - fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n"); + if (idx >= 5) { + fprintf(stderr, "error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n"); return false; } try { @@ -1593,8 +1597,8 @@ struct SDGenerationParams { } idx++; } - if (idx != 3) { - fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n"); + if (idx < 3) { + fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n"); return false; } if (values[0] < 0.0f) { @@ -1612,10 +1616,12 @@ struct SDGenerationParams { easycache_params.start_percent = values[1]; easycache_params.end_percent = values[2]; } else { - ucache_params.enabled = true; - ucache_params.reuse_threshold = values[0]; - ucache_params.start_percent = values[1]; - ucache_params.end_percent = values[2]; + ucache_params.enabled = true; + ucache_params.reuse_threshold = values[0]; + ucache_params.start_percent = values[1]; + ucache_params.end_percent = values[2]; + ucache_params.error_decay_rate = values[3]; + ucache_params.use_relative_threshold = (values[4] != 0.0f); } } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 326ad59ad..c898ac1b1 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1548,10 +1548,12 @@ class StableDiffusionGGML { LOG_WARN("UCache requested but not supported for this model type (only UNET models)"); } else { UCacheConfig ucache_config; - ucache_config.enabled = true; - ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold); - ucache_config.start_percent = ucache_params->start_percent; - ucache_config.end_percent = ucache_params->end_percent; + ucache_config.enabled = true; + ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold); + ucache_config.start_percent = ucache_params->start_percent; + ucache_config.end_percent = ucache_params->end_percent; + ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, ucache_params->error_decay_rate)); + ucache_config.use_relative_threshold = ucache_params->use_relative_threshold; bool percent_valid = ucache_config.start_percent >= 0.0f && ucache_config.start_percent < 1.0f && ucache_config.end_percent > 0.0f && @@ -1565,10 +1567,12 @@ class StableDiffusionGGML { ucache_state.init(ucache_config, denoiser.get()); if (ucache_state.enabled()) { ucache_enabled = true; - LOG_INFO("UCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f", + LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s", ucache_config.reuse_threshold, ucache_config.start_percent, - ucache_config.end_percent); + ucache_config.end_percent, + ucache_config.error_decay_rate, + ucache_config.use_relative_threshold ? "true" : "false"); } else { LOG_WARN("UCache requested but could not be initialized for this run"); } @@ -2616,11 +2620,13 @@ void sd_easycache_params_init(sd_easycache_params_t* easycache_params) { } void sd_ucache_params_init(sd_ucache_params_t* ucache_params) { - *ucache_params = {}; - ucache_params->enabled = false; - ucache_params->reuse_threshold = 1.0f; - ucache_params->start_percent = 0.15f; - ucache_params->end_percent = 0.95f; + *ucache_params = {}; + ucache_params->enabled = false; + ucache_params->reuse_threshold = 1.0f; + ucache_params->start_percent = 0.15f; + ucache_params->end_percent = 0.95f; + ucache_params->error_decay_rate = 1.0f; + ucache_params->use_relative_threshold = true; } void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { diff --git a/stable-diffusion.h b/stable-diffusion.h index a65fd7e84..d95c73e80 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -247,6 +247,8 @@ typedef struct { float reuse_threshold; float start_percent; float end_percent; + float error_decay_rate; + bool use_relative_threshold; } sd_ucache_params_t; typedef struct { diff --git a/ucache.hpp b/ucache.hpp index ca3875a59..34bc921f0 100644 --- a/ucache.hpp +++ b/ucache.hpp @@ -10,10 +10,15 @@ #include "ggml_extend.hpp" struct UCacheConfig { - bool enabled = false; - float reuse_threshold = 1.0f; - float start_percent = 0.15f; - float end_percent = 0.95f; + bool enabled = false; + float reuse_threshold = 1.0f; + float start_percent = 0.15f; + float end_percent = 0.95f; + float error_decay_rate = 1.0f; + bool use_relative_threshold = true; + bool adaptive_threshold = true; + float early_step_multiplier = 0.5f; + float late_step_multiplier = 1.5f; }; struct UCacheCacheEntry { @@ -44,6 +49,45 @@ struct UCacheState { bool has_last_input_change = false; int total_steps_skipped = 0; int current_step_index = -1; + int steps_computed_since_active = 0; + float accumulated_error = 0.0f; + float reference_output_norm = 0.0f; + + struct BlockMetrics { + float sum_transformation_rate = 0.0f; + float sum_output_norm = 0.0f; + int sample_count = 0; + float min_change_rate = std::numeric_limits::max(); + float max_change_rate = 0.0f; + + void reset() { + sum_transformation_rate = 0.0f; + sum_output_norm = 0.0f; + sample_count = 0; + min_change_rate = std::numeric_limits::max(); + max_change_rate = 0.0f; + } + + void record(float change_rate, float output_norm) { + if (std::isfinite(change_rate) && change_rate > 0.0f) { + sum_transformation_rate += change_rate; + sum_output_norm += output_norm; + sample_count++; + if (change_rate < min_change_rate) min_change_rate = change_rate; + if (change_rate > max_change_rate) max_change_rate = change_rate; + } + } + + float avg_transformation_rate() const { + return (sample_count > 0) ? (sum_transformation_rate / sample_count) : 0.0f; + } + + float avg_output_norm() const { + return (sample_count > 0) ? (sum_output_norm / sample_count) : 0.0f; + } + }; + BlockMetrics block_metrics; + int total_active_steps = 0; void reset_runtime() { initial_step = true; @@ -64,6 +108,11 @@ struct UCacheState { has_last_input_change = false; total_steps_skipped = 0; current_step_index = -1; + steps_computed_since_active = 0; + accumulated_error = 0.0f; + reference_output_norm = 0.0f; + block_metrics.reset(); + total_active_steps = 0; } void init(const UCacheConfig& cfg, Denoiser* d) { @@ -114,6 +163,7 @@ struct UCacheState { return; } step_active = true; + total_active_steps++; } bool step_is_active() const { @@ -124,6 +174,31 @@ struct UCacheState { return enabled() && step_active && skip_current_step; } + float get_adaptive_threshold(int estimated_total_steps = 0) const { + float base_threshold = config.reuse_threshold; + + if (!config.adaptive_threshold) { + return base_threshold; + } + + int effective_total = estimated_total_steps; + if (effective_total <= 0) { + effective_total = std::max(20, steps_computed_since_active * 2); + } + + float progress = (effective_total > 0) ? + (static_cast(steps_computed_since_active) / effective_total) : 0.0f; + + float multiplier = 1.0f; + if (progress < 0.2f) { + multiplier = config.early_step_multiplier; + } else if (progress > 0.8f) { + multiplier = config.late_step_multiplier; + } + + return base_threshold * multiplier; + } + bool has_cache(const SDCondition* cond) const { auto it = cache_diffs.find(cond); return it != cache_diffs.end() && !it->second.diff.empty(); @@ -212,15 +287,18 @@ struct UCacheState { last_input_change > 0.0f && output_prev_norm > 0.0f) { float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm; - cumulative_change_rate += approx_output_change_rate; + accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate; + + float effective_threshold = get_adaptive_threshold(); + if (config.use_relative_threshold && reference_output_norm > 0.0f) { + effective_threshold = effective_threshold * reference_output_norm; + } - if (cumulative_change_rate < config.reuse_threshold) { + if (accumulated_error < effective_threshold) { skip_current_step = true; total_steps_skipped++; apply_cache(cond, input, output); return true; - } else { - cumulative_change_rate = 0.0f; } } @@ -270,16 +348,31 @@ struct UCacheState { output_prev_norm = (ne > 0) ? (mean_abs / static_cast(ne)) : 0.0f; has_output_prev_norm = output_prev_norm > 0.0f; + if (reference_output_norm == 0.0f) { + reference_output_norm = output_prev_norm; + } + if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) { float rate = output_change / last_input_change; if (std::isfinite(rate)) { relative_transformation_rate = rate; has_relative_transformation_rate = true; + block_metrics.record(rate, output_prev_norm); } } - cumulative_change_rate = 0.0f; - has_last_input_change = false; + has_last_input_change = false; + } + + void log_block_metrics() const { + if (block_metrics.sample_count > 0) { + LOG_INFO("UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f", + block_metrics.sample_count, + block_metrics.avg_transformation_rate(), + block_metrics.min_change_rate, + block_metrics.max_change_rate, + block_metrics.avg_output_norm()); + } } }; From c59f4148c7d986cfbcf2e1e4998c375c66c26666 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 10 Dec 2025 15:07:11 +0000 Subject: [PATCH 04/12] use single unified struct --- examples/cli/main.cpp | 45 +++++-------- stable-diffusion.cpp | 145 ++++++++++++++++++------------------------ stable-diffusion.h | 24 +++---- 3 files changed, 89 insertions(+), 125 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 4ecec4a5a..8f350f382 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1056,8 +1056,7 @@ struct SDGenerationParams { std::string cache_mode; std::string cache_option; - sd_easycache_params_t easycache_params; - sd_ucache_params_t ucache_params; + sd_cache_params_t cache_params; float moe_boundary = 0.875f; int video_frames = 1; @@ -1548,8 +1547,7 @@ struct SDGenerationParams { return false; } - easycache_params.enabled = false; - ucache_params.enabled = false; + cache_params.mode = SD_CACHE_DISABLED; if (!cache_mode.empty()) { std::string option_str = cache_option; @@ -1610,18 +1608,15 @@ struct SDGenerationParams { return false; } + cache_params.reuse_threshold = values[0]; + cache_params.start_percent = values[1]; + cache_params.end_percent = values[2]; + cache_params.error_decay_rate = values[3]; + cache_params.use_relative_threshold = (values[4] != 0.0f); if (cache_mode == "easycache") { - easycache_params.enabled = true; - easycache_params.reuse_threshold = values[0]; - easycache_params.start_percent = values[1]; - easycache_params.end_percent = values[2]; + cache_params.mode = SD_CACHE_EASYCACHE; } else { - ucache_params.enabled = true; - ucache_params.reuse_threshold = values[0]; - ucache_params.start_percent = values[1]; - ucache_params.end_percent = values[2]; - ucache_params.error_decay_rate = values[3]; - ucache_params.use_relative_threshold = (values[4] != 0.0f); + cache_params.mode = SD_CACHE_UCACHE; } } @@ -1719,16 +1714,12 @@ struct SDGenerationParams { << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" << " cache_mode: \"" << cache_mode << "\",\n" << " cache_option: \"" << cache_option << "\",\n" - << " easycache: " - << (easycache_params.enabled ? "enabled" : "disabled") - << " (threshold=" << easycache_params.reuse_threshold - << ", start=" << easycache_params.start_percent - << ", end=" << easycache_params.end_percent << "),\n" - << " ucache: " - << (ucache_params.enabled ? "enabled" : "disabled") - << " (threshold=" << ucache_params.reuse_threshold - << ", start=" << ucache_params.start_percent - << ", end=" << ucache_params.end_percent << "),\n" + << " cache: " + << (cache_params.mode == SD_CACHE_DISABLED ? "disabled" : + (cache_params.mode == SD_CACHE_EASYCACHE ? "easycache" : "ucache")) + << " (threshold=" << cache_params.reuse_threshold + << ", start=" << cache_params.start_percent + << ", end=" << cache_params.end_percent << "),\n" << " moe_boundary: " << moe_boundary << ",\n" << " video_frames: " << video_frames << ",\n" << " fps: " << fps << ",\n" @@ -2253,8 +2244,7 @@ int main(int argc, const char* argv[]) { gen_params.pm_style_strength, }, // pm_params ctx_params.vae_tiling_params, - gen_params.easycache_params, - gen_params.ucache_params, + gen_params.cache_params, }; results = generate_image(sd_ctx, &img_gen_params); @@ -2279,8 +2269,7 @@ int main(int argc, const char* argv[]) { gen_params.seed, gen_params.video_frames, gen_params.vace_strength, - gen_params.easycache_params, - gen_params.ucache_params, + gen_params.cache_params, }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index c898ac1b1..7b76bc2d3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1487,8 +1487,7 @@ class StableDiffusionGGML { ggml_tensor* denoise_mask = nullptr, ggml_tensor* vace_context = nullptr, float vace_strength = 1.f, - const sd_easycache_params_t* easycache_params = nullptr, - const sd_ucache_params_t* ucache_params = nullptr) { + const sd_cache_params_t* cache_params = nullptr) { if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { LOG_WARN("timestep shifting is only supported for SDXL models!"); shifted_timestep = 0; @@ -1505,31 +1504,35 @@ class StableDiffusionGGML { } EasyCacheState easycache_state; + UCacheState ucache_state; bool easycache_enabled = false; - if (easycache_params != nullptr && easycache_params->enabled) { - bool easycache_supported = sd_version_is_dit(version); - if (!easycache_supported) { - LOG_WARN("EasyCache requested but not supported for this model type"); - } else { - EasyCacheConfig easycache_config; - easycache_config.enabled = true; - easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold); - easycache_config.start_percent = easycache_params->start_percent; - easycache_config.end_percent = easycache_params->end_percent; - bool percent_valid = easycache_config.start_percent >= 0.0f && - easycache_config.start_percent < 1.0f && - easycache_config.end_percent > 0.0f && - easycache_config.end_percent <= 1.0f && - easycache_config.start_percent < easycache_config.end_percent; - if (!percent_valid) { - LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)", - easycache_config.start_percent, - easycache_config.end_percent); + bool ucache_enabled = false; + + if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) { + bool percent_valid = cache_params->start_percent >= 0.0f && + cache_params->start_percent < 1.0f && + cache_params->end_percent > 0.0f && + cache_params->end_percent <= 1.0f && + cache_params->start_percent < cache_params->end_percent; + + if (!percent_valid) { + LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)", + cache_params->start_percent, + cache_params->end_percent); + } else if (cache_params->mode == SD_CACHE_EASYCACHE) { + bool easycache_supported = sd_version_is_dit(version); + if (!easycache_supported) { + LOG_WARN("EasyCache requested but not supported for this model type"); } else { + EasyCacheConfig easycache_config; + easycache_config.enabled = true; + easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); + easycache_config.start_percent = cache_params->start_percent; + easycache_config.end_percent = cache_params->end_percent; easycache_state.init(easycache_config, denoiser.get()); if (easycache_state.enabled()) { easycache_enabled = true; - LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f", + LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f", easycache_config.reuse_threshold, easycache_config.start_percent, easycache_config.end_percent); @@ -1537,33 +1540,18 @@ class StableDiffusionGGML { LOG_WARN("EasyCache requested but could not be initialized for this run"); } } - } - } - - UCacheState ucache_state; - bool ucache_enabled = false; - if (ucache_params != nullptr && ucache_params->enabled) { - bool ucache_supported = sd_version_is_unet(version); - if (!ucache_supported) { - LOG_WARN("UCache requested but not supported for this model type (only UNET models)"); - } else { - UCacheConfig ucache_config; - ucache_config.enabled = true; - ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold); - ucache_config.start_percent = ucache_params->start_percent; - ucache_config.end_percent = ucache_params->end_percent; - ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, ucache_params->error_decay_rate)); - ucache_config.use_relative_threshold = ucache_params->use_relative_threshold; - bool percent_valid = ucache_config.start_percent >= 0.0f && - ucache_config.start_percent < 1.0f && - ucache_config.end_percent > 0.0f && - ucache_config.end_percent <= 1.0f && - ucache_config.start_percent < ucache_config.end_percent; - if (!percent_valid) { - LOG_WARN("UCache disabled due to invalid percent range (start=%.3f, end=%.3f)", - ucache_config.start_percent, - ucache_config.end_percent); + } else if (cache_params->mode == SD_CACHE_UCACHE) { + bool ucache_supported = sd_version_is_unet(version); + if (!ucache_supported) { + LOG_WARN("UCache requested but not supported for this model type (only UNET models)"); } else { + UCacheConfig ucache_config; + ucache_config.enabled = true; + ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); + ucache_config.start_percent = cache_params->start_percent; + ucache_config.end_percent = cache_params->end_percent; + ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate)); + ucache_config.use_relative_threshold = cache_params->use_relative_threshold; ucache_state.init(ucache_config, denoiser.get()); if (ucache_state.enabled()) { ucache_enabled = true; @@ -2611,22 +2599,14 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) { return LORA_APPLY_MODE_COUNT; } -void sd_easycache_params_init(sd_easycache_params_t* easycache_params) { - *easycache_params = {}; - easycache_params->enabled = false; - easycache_params->reuse_threshold = 0.2f; - easycache_params->start_percent = 0.15f; - easycache_params->end_percent = 0.95f; -} - -void sd_ucache_params_init(sd_ucache_params_t* ucache_params) { - *ucache_params = {}; - ucache_params->enabled = false; - ucache_params->reuse_threshold = 1.0f; - ucache_params->start_percent = 0.15f; - ucache_params->end_percent = 0.95f; - ucache_params->error_decay_rate = 1.0f; - ucache_params->use_relative_threshold = true; +void sd_cache_params_init(sd_cache_params_t* cache_params) { + *cache_params = {}; + cache_params->mode = SD_CACHE_DISABLED; + cache_params->reuse_threshold = 1.0f; + cache_params->start_percent = 0.15f; + cache_params->end_percent = 0.95f; + cache_params->error_decay_rate = 1.0f; + cache_params->use_relative_threshold = true; } void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { @@ -2785,8 +2765,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; - sd_easycache_params_init(&sd_img_gen_params->easycache); - sd_ucache_params_init(&sd_img_gen_params->ucache); + sd_cache_params_init(&sd_img_gen_params->cache); } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -2830,12 +2809,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params.id_images_count, SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); + const char* cache_mode_str = "disabled"; + if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) { + cache_mode_str = "easycache"; + } else if (sd_img_gen_params->cache.mode == SD_CACHE_UCACHE) { + cache_mode_str = "ucache"; + } snprintf(buf + strlen(buf), 4096 - strlen(buf), - "easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n", - sd_img_gen_params->easycache.enabled ? "enabled" : "disabled", - sd_img_gen_params->easycache.reuse_threshold, - sd_img_gen_params->easycache.start_percent, - sd_img_gen_params->easycache.end_percent); + "cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n", + cache_mode_str, + sd_img_gen_params->cache.reuse_threshold, + sd_img_gen_params->cache.start_percent, + sd_img_gen_params->cache.end_percent); free(sample_params_str); return buf; } @@ -2852,8 +2837,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; - sd_easycache_params_init(&sd_vid_gen_params->easycache); - sd_ucache_params_init(&sd_vid_gen_params->ucache); + sd_cache_params_init(&sd_vid_gen_params->cache); } struct sd_ctx_t { @@ -2931,8 +2915,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, bool increase_ref_index, ggml_tensor* concat_latent = nullptr, ggml_tensor* denoise_mask = nullptr, - const sd_easycache_params_t* easycache_params = nullptr, - const sd_ucache_params_t* ucache_params = nullptr) { + const sd_cache_params_t* cache_params = nullptr) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -3221,8 +3204,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, denoise_mask, nullptr, 1.0f, - easycache_params, - ucache_params); + cache_params); int64_t sampling_end = ggml_time_ms(); if (x_0 != nullptr) { // print_ggml_tensor(x_0); @@ -3556,8 +3538,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->increase_ref_index, concat_latent, denoise_mask, - &sd_img_gen_params->easycache, - &sd_img_gen_params->ucache); + &sd_img_gen_params->cache); size_t t2 = ggml_time_ms(); @@ -3924,8 +3905,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache, - &sd_vid_gen_params->ucache); + &sd_vid_gen_params->cache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); @@ -3962,8 +3942,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache, - &sd_vid_gen_params->ucache); + &sd_vid_gen_params->cache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); diff --git a/stable-diffusion.h b/stable-diffusion.h index d95c73e80..7d560d788 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -235,21 +235,20 @@ typedef struct { float style_strength; } sd_pm_params_t; // photo maker -typedef struct { - bool enabled; - float reuse_threshold; - float start_percent; - float end_percent; -} sd_easycache_params_t; +enum sd_cache_mode_t { + SD_CACHE_DISABLED = 0, + SD_CACHE_EASYCACHE, + SD_CACHE_UCACHE, +}; typedef struct { - bool enabled; + enum sd_cache_mode_t mode; float reuse_threshold; float start_percent; float end_percent; float error_decay_rate; bool use_relative_threshold; -} sd_ucache_params_t; +} sd_cache_params_t; typedef struct { bool is_high_noise; @@ -279,8 +278,7 @@ typedef struct { float control_strength; sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; - sd_easycache_params_t easycache; - sd_ucache_params_t ucache; + sd_cache_params_t cache; } sd_img_gen_params_t; typedef struct { @@ -302,8 +300,7 @@ typedef struct { int64_t seed; int video_frames; float vace_strength; - sd_easycache_params_t easycache; - sd_ucache_params_t ucache; + sd_cache_params_t cache; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t; @@ -333,8 +330,7 @@ SD_API enum preview_t str_to_preview(const char* str); SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); -SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params); -SD_API void sd_ucache_params_init(sd_ucache_params_t* ucache_params); +SD_API void sd_cache_params_init(sd_cache_params_t* cache_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); From 78230da0cfb210e995e2f21563df28c5e3dd97e0 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 10 Dec 2025 20:07:04 +0000 Subject: [PATCH 05/12] use actual scheduler sigmas for ucache bounds --- stable-diffusion.cpp | 4 ++++ ucache.hpp | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 7b76bc2d3..0f91555b3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1568,6 +1568,10 @@ class StableDiffusionGGML { } } + if (ucache_enabled) { + ucache_state.set_sigmas(sigmas); + } + size_t steps = sigmas.size() - 1; struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); copy_ggml_tensor(x, init_latent); diff --git a/ucache.hpp b/ucache.hpp index 34bc921f0..7eea592bf 100644 --- a/ucache.hpp +++ b/ucache.hpp @@ -126,6 +126,26 @@ struct UCacheState { } } + void set_sigmas(const std::vector& sigmas) { + if (!initialized || sigmas.size() < 2) { + return; + } + size_t n_steps = sigmas.size() - 1; + + size_t start_step = static_cast(config.start_percent * n_steps); + size_t end_step = static_cast(config.end_percent * n_steps); + + if (start_step >= n_steps) start_step = n_steps - 1; + if (end_step >= n_steps) end_step = n_steps - 1; + + start_sigma = sigmas[start_step]; + end_sigma = sigmas[end_step]; + + if (start_sigma < end_sigma) { + std::swap(start_sigma, end_sigma); + } + } + bool enabled() const { return initialized && config.enabled; } From 186038ec3add50731232a6c49360dcee3b58e87a Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 12 Dec 2025 16:15:29 +0000 Subject: [PATCH 06/12] add cache-dit --- cache_dit.hpp | 898 ++++++++++++++++++++++++++++++++++++++++++ examples/cli/main.cpp | 152 +++++-- stable-diffusion.cpp | 129 +++++- stable-diffusion.h | 13 + 4 files changed, 1158 insertions(+), 34 deletions(-) create mode 100644 cache_dit.hpp diff --git a/cache_dit.hpp b/cache_dit.hpp new file mode 100644 index 000000000..f222fb3ab --- /dev/null +++ b/cache_dit.hpp @@ -0,0 +1,898 @@ +#ifndef __CACHE_DIT_HPP__ +#define __CACHE_DIT_HPP__ + +#include +#include +#include +#include +#include +#include + +#include "ggml_extend.hpp" + +struct DBCacheConfig { + bool enabled = false; + int Fn_compute_blocks = 8; + int Bn_compute_blocks = 0; + float residual_diff_threshold = 0.08f; + int max_warmup_steps = 8; + int max_cached_steps = -1; + int max_continuous_cached_steps = -1; + float max_accumulated_residual_diff = -1.0f; + std::vector steps_computation_mask; + bool scm_policy_dynamic = true; +}; + +struct TaylorSeerConfig { + bool enabled = false; + int n_derivatives = 1; + int max_warmup_steps = 2; + int skip_interval_steps = 1; +}; + +struct CacheDitConfig { + DBCacheConfig dbcache; + TaylorSeerConfig taylorseer; + int double_Fn_blocks = -1; + int double_Bn_blocks = -1; + int single_Fn_blocks = -1; + int single_Bn_blocks = -1; +}; + +struct TaylorSeerState { + int n_derivatives = 1; + int current_step = -1; + int last_computed_step = -1; + std::vector> dY_prev; + std::vector> dY_current; + + void init(int n_deriv, size_t hidden_size) { + n_derivatives = n_deriv; + int order = n_derivatives + 1; + dY_prev.resize(order); + dY_current.resize(order); + for (int i = 0; i < order; i++) { + dY_prev[i].clear(); + dY_current[i].clear(); + } + current_step = -1; + last_computed_step = -1; + } + + void reset() { + for (auto& v : dY_prev) v.clear(); + for (auto& v : dY_current) v.clear(); + current_step = -1; + last_computed_step = -1; + } + + bool can_approximate() const { + return last_computed_step >= n_derivatives && !dY_prev.empty() && !dY_prev[0].empty(); + } + + void update_derivatives(const float* Y, size_t size, int step) { + int order = n_derivatives + 1; + dY_prev = dY_current; + dY_current[0].resize(size); + for (size_t i = 0; i < size; i++) { + dY_current[0][i] = Y[i]; + } + + int window = step - last_computed_step; + if (window <= 0) window = 1; + + for (int d = 0; d < n_derivatives; d++) { + if (!dY_prev[d].empty() && dY_prev[d].size() == size) { + dY_current[d + 1].resize(size); + for (size_t i = 0; i < size; i++) { + dY_current[d + 1][i] = (dY_current[d][i] - dY_prev[d][i]) / static_cast(window); + } + } else { + dY_current[d + 1].clear(); + } + } + + current_step = step; + last_computed_step = step; + } + + void approximate(float* output, size_t size, int target_step) const { + if (!can_approximate() || dY_prev[0].size() != size) { + return; + } + + int elapsed = target_step - last_computed_step; + if (elapsed <= 0) elapsed = 1; + + std::fill(output, output + size, 0.0f); + float factorial = 1.0f; + int order = static_cast(dY_prev.size()); + + for (int o = 0; o < order; o++) { + if (dY_prev[o].empty() || dY_prev[o].size() != size) continue; + if (o > 0) factorial *= static_cast(o); + float coeff = std::pow(static_cast(elapsed), o) / factorial; + for (size_t i = 0; i < size; i++) { + output[i] += coeff * dY_prev[o][i]; + } + } + } +}; + +struct BlockCacheEntry { + std::vector residual_img; + std::vector residual_txt; + std::vector residual; + std::vector prev_img; + std::vector prev_txt; + std::vector prev_output; + bool has_prev = false; +}; + +struct CacheDitState { + CacheDitConfig config; + bool initialized = false; + + int total_double_blocks = 0; + int total_single_blocks = 0; + size_t hidden_size = 0; + + int current_step = -1; + int total_steps = 0; + int warmup_remaining = 0; + std::vector cached_steps; + int continuous_cached_steps = 0; + float accumulated_residual_diff = 0.0f; + + std::vector double_block_cache; + std::vector single_block_cache; + + std::vector Fn_residual_img; + std::vector Fn_residual_txt; + std::vector prev_Fn_residual_img; + std::vector prev_Fn_residual_txt; + bool has_prev_Fn_residual = false; + + std::vector Bn_buffer_img; + std::vector Bn_buffer_txt; + std::vector Bn_buffer; + bool has_Bn_buffer = false; + + TaylorSeerState taylor_state; + + bool can_cache_this_step = false; + bool is_caching_this_step = false; + + int total_blocks_computed = 0; + int total_blocks_cached = 0; + + void init(const CacheDitConfig& cfg, int num_double_blocks, int num_single_blocks, size_t h_size) { + config = cfg; + total_double_blocks = num_double_blocks; + total_single_blocks = num_single_blocks; + hidden_size = h_size; + + initialized = cfg.dbcache.enabled || cfg.taylorseer.enabled; + + if (!initialized) return; + + warmup_remaining = cfg.dbcache.max_warmup_steps; + double_block_cache.resize(total_double_blocks); + single_block_cache.resize(total_single_blocks); + + if (cfg.taylorseer.enabled) { + taylor_state.init(cfg.taylorseer.n_derivatives, h_size); + } + + reset_runtime(); + } + + void reset_runtime() { + current_step = -1; + total_steps = 0; + warmup_remaining = config.dbcache.max_warmup_steps; + cached_steps.clear(); + continuous_cached_steps = 0; + accumulated_residual_diff = 0.0f; + + for (auto& entry : double_block_cache) { + entry.residual_img.clear(); + entry.residual_txt.clear(); + entry.prev_img.clear(); + entry.prev_txt.clear(); + entry.has_prev = false; + } + + for (auto& entry : single_block_cache) { + entry.residual.clear(); + entry.prev_output.clear(); + entry.has_prev = false; + } + + Fn_residual_img.clear(); + Fn_residual_txt.clear(); + prev_Fn_residual_img.clear(); + prev_Fn_residual_txt.clear(); + has_prev_Fn_residual = false; + + Bn_buffer_img.clear(); + Bn_buffer_txt.clear(); + Bn_buffer.clear(); + has_Bn_buffer = false; + + taylor_state.reset(); + + can_cache_this_step = false; + is_caching_this_step = false; + + total_blocks_computed = 0; + total_blocks_cached = 0; + } + + bool enabled() const { + return initialized && (config.dbcache.enabled || config.taylorseer.enabled); + } + + void begin_step(int step_index, float sigma = 0.0f) { + if (!enabled()) return; + if (step_index == current_step) return; + + current_step = step_index; + total_steps++; + + bool in_warmup = warmup_remaining > 0; + if (in_warmup) { + warmup_remaining--; + } + + bool scm_allows_cache = true; + if (!config.dbcache.steps_computation_mask.empty()) { + if (step_index < static_cast(config.dbcache.steps_computation_mask.size())) { + scm_allows_cache = (config.dbcache.steps_computation_mask[step_index] == 0); + if (!config.dbcache.scm_policy_dynamic && scm_allows_cache) { + can_cache_this_step = true; + is_caching_this_step = false; + return; + } + } + } + + bool max_cached_ok = (config.dbcache.max_cached_steps < 0) || + (static_cast(cached_steps.size()) < config.dbcache.max_cached_steps); + + bool max_cont_ok = (config.dbcache.max_continuous_cached_steps < 0) || + (continuous_cached_steps < config.dbcache.max_continuous_cached_steps); + + bool accum_ok = (config.dbcache.max_accumulated_residual_diff < 0.0f) || + (accumulated_residual_diff < config.dbcache.max_accumulated_residual_diff); + + can_cache_this_step = !in_warmup && scm_allows_cache && max_cached_ok && max_cont_ok && accum_ok && has_prev_Fn_residual; + is_caching_this_step = false; + } + + void end_step(bool was_cached) { + if (was_cached) { + cached_steps.push_back(current_step); + continuous_cached_steps++; + } else { + continuous_cached_steps = 0; + } + } + + static float calculate_residual_diff(const float* prev, const float* curr, size_t size) { + if (size == 0) return 0.0f; + + float sum_diff = 0.0f; + float sum_abs = 0.0f; + + for (size_t i = 0; i < size; i++) { + sum_diff += std::fabs(prev[i] - curr[i]); + sum_abs += std::fabs(prev[i]); + } + + return sum_diff / (sum_abs + 1e-6f); + } + + static float calculate_residual_diff(const std::vector& prev, const std::vector& curr) { + if (prev.size() != curr.size() || prev.empty()) return 1.0f; + return calculate_residual_diff(prev.data(), curr.data(), prev.size()); + } + + int get_double_Fn_blocks() const { + return (config.double_Fn_blocks >= 0) ? config.double_Fn_blocks : config.dbcache.Fn_compute_blocks; + } + + int get_double_Bn_blocks() const { + return (config.double_Bn_blocks >= 0) ? config.double_Bn_blocks : config.dbcache.Bn_compute_blocks; + } + + int get_single_Fn_blocks() const { + return (config.single_Fn_blocks >= 0) ? config.single_Fn_blocks : config.dbcache.Fn_compute_blocks; + } + + int get_single_Bn_blocks() const { + return (config.single_Bn_blocks >= 0) ? config.single_Bn_blocks : config.dbcache.Bn_compute_blocks; + } + + bool is_Fn_double_block(int block_idx) const { + return block_idx < get_double_Fn_blocks(); + } + + bool is_Bn_double_block(int block_idx) const { + int Bn = get_double_Bn_blocks(); + return Bn > 0 && block_idx >= (total_double_blocks - Bn); + } + + bool is_Mn_double_block(int block_idx) const { + return !is_Fn_double_block(block_idx) && !is_Bn_double_block(block_idx); + } + + bool is_Fn_single_block(int block_idx) const { + return block_idx < get_single_Fn_blocks(); + } + + bool is_Bn_single_block(int block_idx) const { + int Bn = get_single_Bn_blocks(); + return Bn > 0 && block_idx >= (total_single_blocks - Bn); + } + + bool is_Mn_single_block(int block_idx) const { + return !is_Fn_single_block(block_idx) && !is_Bn_single_block(block_idx); + } + + void store_Fn_residual(const float* img, const float* txt, size_t img_size, size_t txt_size, + const float* input_img, const float* input_txt) { + Fn_residual_img.resize(img_size); + Fn_residual_txt.resize(txt_size); + + for (size_t i = 0; i < img_size; i++) { + Fn_residual_img[i] = img[i] - input_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + Fn_residual_txt[i] = txt[i] - input_txt[i]; + } + } + + bool check_cache_decision() { + if (!can_cache_this_step) { + is_caching_this_step = false; + return false; + } + + if (!has_prev_Fn_residual || prev_Fn_residual_img.empty()) { + is_caching_this_step = false; + return false; + } + + float diff_img = calculate_residual_diff(prev_Fn_residual_img, Fn_residual_img); + float diff_txt = calculate_residual_diff(prev_Fn_residual_txt, Fn_residual_txt); + float diff = (diff_img + diff_txt) / 2.0f; + + if (diff < config.dbcache.residual_diff_threshold) { + is_caching_this_step = true; + accumulated_residual_diff += diff; + return true; + } + + is_caching_this_step = false; + return false; + } + + void update_prev_Fn_residual() { + prev_Fn_residual_img = Fn_residual_img; + prev_Fn_residual_txt = Fn_residual_txt; + has_prev_Fn_residual = !prev_Fn_residual_img.empty(); + } + + void store_double_block_residual(int block_idx, const float* img, const float* txt, + size_t img_size, size_t txt_size, + const float* prev_img, const float* prev_txt) { + if (block_idx < 0 || block_idx >= static_cast(double_block_cache.size())) return; + + BlockCacheEntry& entry = double_block_cache[block_idx]; + + entry.residual_img.resize(img_size); + entry.residual_txt.resize(txt_size); + for (size_t i = 0; i < img_size; i++) { + entry.residual_img[i] = img[i] - prev_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + entry.residual_txt[i] = txt[i] - prev_txt[i]; + } + + entry.prev_img.resize(img_size); + entry.prev_txt.resize(txt_size); + for (size_t i = 0; i < img_size; i++) { + entry.prev_img[i] = img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + entry.prev_txt[i] = txt[i]; + } + entry.has_prev = true; + } + + void apply_double_block_cache(int block_idx, float* img, float* txt, + size_t img_size, size_t txt_size) { + if (block_idx < 0 || block_idx >= static_cast(double_block_cache.size())) return; + + const BlockCacheEntry& entry = double_block_cache[block_idx]; + if (entry.residual_img.size() != img_size || entry.residual_txt.size() != txt_size) return; + + for (size_t i = 0; i < img_size; i++) { + img[i] += entry.residual_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + txt[i] += entry.residual_txt[i]; + } + + total_blocks_cached++; + } + + void store_single_block_residual(int block_idx, const float* output, size_t size, const float* input) { + if (block_idx < 0 || block_idx >= static_cast(single_block_cache.size())) return; + + BlockCacheEntry& entry = single_block_cache[block_idx]; + + entry.residual.resize(size); + for (size_t i = 0; i < size; i++) { + entry.residual[i] = output[i] - input[i]; + } + + entry.prev_output.resize(size); + for (size_t i = 0; i < size; i++) { + entry.prev_output[i] = output[i]; + } + entry.has_prev = true; + } + + void apply_single_block_cache(int block_idx, float* output, size_t size) { + if (block_idx < 0 || block_idx >= static_cast(single_block_cache.size())) return; + + const BlockCacheEntry& entry = single_block_cache[block_idx]; + if (entry.residual.size() != size) return; + + for (size_t i = 0; i < size; i++) { + output[i] += entry.residual[i]; + } + + total_blocks_cached++; + } + + void store_Bn_buffer(const float* img, const float* txt, size_t img_size, size_t txt_size, + const float* Bn_start_img, const float* Bn_start_txt) { + Bn_buffer_img.resize(img_size); + Bn_buffer_txt.resize(txt_size); + + for (size_t i = 0; i < img_size; i++) { + Bn_buffer_img[i] = img[i] - Bn_start_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + Bn_buffer_txt[i] = txt[i] - Bn_start_txt[i]; + } + has_Bn_buffer = true; + } + + void apply_Bn_buffer(float* img, float* txt, size_t img_size, size_t txt_size) { + if (!has_Bn_buffer) return; + if (Bn_buffer_img.size() != img_size || Bn_buffer_txt.size() != txt_size) return; + + for (size_t i = 0; i < img_size; i++) { + img[i] += Bn_buffer_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + txt[i] += Bn_buffer_txt[i]; + } + } + + void taylor_update(const float* hidden_state, size_t size) { + if (!config.taylorseer.enabled) return; + taylor_state.update_derivatives(hidden_state, size, current_step); + } + + bool taylor_can_approximate() const { + return config.taylorseer.enabled && taylor_state.can_approximate(); + } + + void taylor_approximate(float* output, size_t size) { + if (!config.taylorseer.enabled) return; + taylor_state.approximate(output, size, current_step); + } + + bool should_use_taylor_this_step() const { + if (!config.taylorseer.enabled) return false; + if (current_step < config.taylorseer.max_warmup_steps) return false; + + int interval = config.taylorseer.skip_interval_steps; + if (interval <= 0) interval = 1; + + return (current_step % (interval + 1)) != 0; + } + + void log_metrics() const { + if (!enabled()) return; + + int total_blocks = total_blocks_computed + total_blocks_cached; + float cache_ratio = (total_blocks > 0) ? + (static_cast(total_blocks_cached) / total_blocks * 100.0f) : 0.0f; + + float step_cache_ratio = (total_steps > 0) ? + (static_cast(cached_steps.size()) / total_steps * 100.0f) : 0.0f; + + LOG_INFO("CacheDIT: steps_cached=%zu/%d (%.1f%%), blocks_cached=%d/%d (%.1f%%), accum_diff=%.4f", + cached_steps.size(), total_steps, step_cache_ratio, + total_blocks_cached, total_blocks, cache_ratio, + accumulated_residual_diff); + } + + std::string get_summary() const { + char buf[256]; + snprintf(buf, sizeof(buf), + "CacheDIT[Fn=%d,Bn=%d,thresh=%.2f]: cached %zu/%d steps, %d/%d blocks", + get_double_Fn_blocks(), get_double_Bn_blocks(), + config.dbcache.residual_diff_threshold, + cached_steps.size(), total_steps, + total_blocks_cached, total_blocks_computed + total_blocks_cached); + return std::string(buf); + } +}; + +inline std::vector parse_scm_mask(const std::string& mask_str) { + std::vector mask; + if (mask_str.empty()) return mask; + + size_t pos = 0; + size_t start = 0; + while ((pos = mask_str.find(',', start)) != std::string::npos) { + std::string token = mask_str.substr(start, pos - start); + mask.push_back(std::stoi(token)); + start = pos + 1; + } + if (start < mask_str.length()) { + mask.push_back(std::stoi(mask_str.substr(start))); + } + + return mask; +} + +inline std::vector generate_scm_mask( + const std::vector& compute_bins, + const std::vector& cache_bins, + int total_steps +) { + std::vector mask; + size_t c_idx = 0, cache_idx = 0; + + while (static_cast(mask.size()) < total_steps) { + if (c_idx < compute_bins.size()) { + for (int i = 0; i < compute_bins[c_idx] && static_cast(mask.size()) < total_steps; i++) { + mask.push_back(1); + } + c_idx++; + } + if (cache_idx < cache_bins.size()) { + for (int i = 0; i < cache_bins[cache_idx] && static_cast(mask.size()) < total_steps; i++) { + mask.push_back(0); + } + cache_idx++; + } + if (c_idx >= compute_bins.size() && cache_idx >= cache_bins.size()) break; + } + + if (!mask.empty()) { + mask.back() = 1; + } + + return mask; +} + +inline std::vector get_scm_preset(const std::string& preset, int total_steps) { + struct Preset { + std::vector compute_bins; + std::vector cache_bins; + }; + + Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}}; + Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}}; + Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}}; + Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}}; + + Preset* p = nullptr; + if (preset == "slow" || preset == "s" || preset == "S") p = &slow; + else if (preset == "medium" || preset == "m" || preset == "M") p = &medium; + else if (preset == "fast" || preset == "f" || preset == "F") p = &fast; + else if (preset == "ultra" || preset == "u" || preset == "U") p = &ultra; + else return {}; + + if (total_steps != 28 && total_steps > 0) { + float scale = static_cast(total_steps) / 28.0f; + std::vector scaled_compute, scaled_cache; + + for (int v : p->compute_bins) { + scaled_compute.push_back(std::max(1, static_cast(v * scale + 0.5f))); + } + for (int v : p->cache_bins) { + scaled_cache.push_back(std::max(1, static_cast(v * scale + 0.5f))); + } + + return generate_scm_mask(scaled_compute, scaled_cache, total_steps); + } + + return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps); +} + +inline float get_preset_threshold(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 0.20f; + if (preset == "medium" || preset == "m" || preset == "M") return 0.25f; + if (preset == "fast" || preset == "f" || preset == "F") return 0.30f; + if (preset == "ultra" || preset == "u" || preset == "U") return 0.34f; + return 0.08f; +} + +inline int get_preset_warmup(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 8; + if (preset == "medium" || preset == "m" || preset == "M") return 6; + if (preset == "fast" || preset == "f" || preset == "F") return 6; + if (preset == "ultra" || preset == "u" || preset == "U") return 4; + return 8; +} + +inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) { + if (opts.empty()) return; + + int Fn = 8, Bn = 0, warmup = 8, max_cached = -1, max_cont = -1; + float thresh = 0.08f; + + sscanf(opts.c_str(), "%d,%d,%f,%d,%d,%d", + &Fn, &Bn, &thresh, &warmup, &max_cached, &max_cont); + + cfg.Fn_compute_blocks = Fn; + cfg.Bn_compute_blocks = Bn; + cfg.residual_diff_threshold = thresh; + cfg.max_warmup_steps = warmup; + cfg.max_cached_steps = max_cached; + cfg.max_continuous_cached_steps = max_cont; +} + +inline void parse_taylorseer_options(const std::string& opts, TaylorSeerConfig& cfg) { + if (opts.empty()) return; + + int n_deriv = 1, warmup = 2, interval = 1; + sscanf(opts.c_str(), "%d,%d,%d", &n_deriv, &warmup, &interval); + + cfg.n_derivatives = n_deriv; + cfg.max_warmup_steps = warmup; + cfg.skip_interval_steps = interval; +} + +struct CacheDitConditionState { + DBCacheConfig config; + TaylorSeerConfig taylor_config; + bool initialized = false; + + int current_step_index = -1; + bool step_active = false; + bool skip_current_step = false; + bool initial_step = true; + int warmup_remaining = 0; + std::vector cached_steps; + int continuous_cached_steps = 0; + float accumulated_residual_diff = 0.0f; + int total_steps_skipped = 0; + + const void* anchor_condition = nullptr; + + struct CacheEntry { + std::vector diff; + std::vector prev_input; + std::vector prev_output; + bool has_prev = false; + }; + std::unordered_map cache_diffs; + + TaylorSeerState taylor_state; + + float start_sigma = std::numeric_limits::max(); + float end_sigma = 0.0f; + + void reset_runtime() { + current_step_index = -1; + step_active = false; + skip_current_step = false; + initial_step = true; + warmup_remaining = config.max_warmup_steps; + cached_steps.clear(); + continuous_cached_steps = 0; + accumulated_residual_diff = 0.0f; + total_steps_skipped = 0; + anchor_condition = nullptr; + cache_diffs.clear(); + taylor_state.reset(); + } + + void init(const DBCacheConfig& dbcfg, const TaylorSeerConfig& tcfg) { + config = dbcfg; + taylor_config = tcfg; + initialized = dbcfg.enabled || tcfg.enabled; + reset_runtime(); + + if (taylor_config.enabled) { + taylor_state.init(taylor_config.n_derivatives, 0); + } + } + + void set_sigmas(const std::vector& sigmas) { + if (!initialized || sigmas.size() < 2) return; + + float start_percent = 0.15f; + float end_percent = 0.95f; + + size_t n_steps = sigmas.size() - 1; + size_t start_step = static_cast(start_percent * n_steps); + size_t end_step = static_cast(end_percent * n_steps); + + if (start_step >= n_steps) start_step = n_steps - 1; + if (end_step >= n_steps) end_step = n_steps - 1; + + start_sigma = sigmas[start_step]; + end_sigma = sigmas[end_step]; + + if (start_sigma < end_sigma) { + std::swap(start_sigma, end_sigma); + } + } + + bool enabled() const { + return initialized && (config.enabled || taylor_config.enabled); + } + + void begin_step(int step_index, float sigma) { + if (!enabled()) return; + if (step_index == current_step_index) return; + + current_step_index = step_index; + skip_current_step = false; + step_active = false; + + if (sigma > start_sigma) return; + if (!(sigma > end_sigma)) return; + + step_active = true; + + if (warmup_remaining > 0) { + warmup_remaining--; + return; + } + + if (!config.steps_computation_mask.empty()) { + if (step_index < static_cast(config.steps_computation_mask.size())) { + if (config.steps_computation_mask[step_index] == 1) { + return; + } + } + } + + if (config.max_cached_steps >= 0 && + static_cast(cached_steps.size()) >= config.max_cached_steps) { + return; + } + + if (config.max_continuous_cached_steps >= 0 && + continuous_cached_steps >= config.max_continuous_cached_steps) { + return; + } + } + + bool step_is_active() const { + return enabled() && step_active; + } + + bool is_step_skipped() const { + return enabled() && step_active && skip_current_step; + } + + bool has_cache(const void* cond) const { + auto it = cache_diffs.find(cond); + return it != cache_diffs.end() && !it->second.diff.empty(); + } + + void update_cache(const void* cond, const float* input, const float* output, size_t size) { + CacheEntry& entry = cache_diffs[cond]; + entry.diff.resize(size); + for (size_t i = 0; i < size; i++) { + entry.diff[i] = output[i] - input[i]; + } + + entry.prev_input.resize(size); + entry.prev_output.resize(size); + for (size_t i = 0; i < size; i++) { + entry.prev_input[i] = input[i]; + entry.prev_output[i] = output[i]; + } + entry.has_prev = true; + } + + void apply_cache(const void* cond, const float* input, float* output, size_t size) { + auto it = cache_diffs.find(cond); + if (it == cache_diffs.end() || it->second.diff.empty()) return; + if (it->second.diff.size() != size) return; + + for (size_t i = 0; i < size; i++) { + output[i] = input[i] + it->second.diff[i]; + } + } + + bool before_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output, float sigma, int step_index) { + if (!enabled() || step_index < 0) return false; + + if (step_index != current_step_index) { + begin_step(step_index, sigma); + } + + if (!step_active) return false; + + if (initial_step) { + anchor_condition = cond; + initial_step = false; + } + + bool is_anchor = (cond == anchor_condition); + + if (skip_current_step) { + if (has_cache(cond)) { + apply_cache(cond, (float*)input->data, (float*)output->data, + static_cast(ggml_nelements(output))); + return true; + } + return false; + } + + if (!is_anchor) return false; + + auto it = cache_diffs.find(cond); + if (it == cache_diffs.end() || !it->second.has_prev) return false; + + size_t ne = static_cast(ggml_nelements(input)); + if (it->second.prev_input.size() != ne) return false; + + float* input_data = (float*)input->data; + float diff = CacheDitState::calculate_residual_diff( + it->second.prev_input.data(), input_data, ne); + + if (diff < config.residual_diff_threshold) { + skip_current_step = true; + total_steps_skipped++; + cached_steps.push_back(current_step_index); + continuous_cached_steps++; + accumulated_residual_diff += diff; + apply_cache(cond, input_data, (float*)output->data, ne); + return true; + } + + continuous_cached_steps = 0; + return false; + } + + void after_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output) { + if (!step_is_active()) return; + + size_t ne = static_cast(ggml_nelements(output)); + update_cache(cond, (float*)input->data, (float*)output->data, ne); + + if (cond == anchor_condition && taylor_config.enabled) { + taylor_state.update_derivatives((float*)output->data, ne, current_step_index); + } + } + + void log_metrics() const { + if (!enabled()) return; + + LOG_INFO("CacheDIT: steps_skipped=%d/%d (%.1f%%), accum_residual_diff=%.4f", + total_steps_skipped, + current_step_index + 1, + (current_step_index > 0) ? + (100.0f * total_steps_skipped / (current_step_index + 1)) : 0.0f, + accumulated_residual_diff); + } +}; + +#endif diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 8f350f382..862880182 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -14,6 +14,7 @@ // #include "preprocessing.hpp" #include "stable-diffusion.h" +#include "cache_dit.hpp" #define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_STATIC @@ -1056,7 +1057,10 @@ struct SDGenerationParams { std::string cache_mode; std::string cache_option; - sd_cache_params_t cache_params; + std::string cache_preset; + std::string scm_mask; + bool scm_policy_dynamic = true; + sd_cache_params_t cache_params{}; float moe_boundary = 0.875f; int video_frames = 1; @@ -1381,8 +1385,9 @@ struct SDGenerationParams { return -1; } cache_mode = argv_to_utf8(index, argv); - if (cache_mode != "easycache" && cache_mode != "ucache") { - fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n", cache_mode.c_str()); + if (cache_mode != "easycache" && cache_mode != "ucache" && + cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") { + fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str()); return -1; } return 1; @@ -1396,6 +1401,45 @@ struct SDGenerationParams { return 1; }; + auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + scm_mask = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string policy = argv_to_utf8(index, argv); + if (policy == "dynamic") { + scm_policy_dynamic = true; + } else if (policy == "static") { + scm_policy_dynamic = false; + } else { + fprintf(stderr, "error: invalid SCM policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); + return -1; + } + return 1; + }; + + auto on_cache_preset_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_preset = argv_to_utf8(index, argv); + if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" && + cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" && + cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" && + cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") { + fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str()); + return -1; + } + return 1; + }; + options.manual_options = { {"-s", "--seed", @@ -1429,12 +1473,24 @@ struct SDGenerationParams { on_ref_image_arg}, {"", "--cache-mode", - "caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)", + "caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", on_cache_mode_arg}, {"", "--cache-option", - "cache parameters \"threshold,start,end[,warmup,decay,relative]\" (ucache extended: warmup=0, decay=1.0, relative=1)", + "cache params - legacy: \"threshold,start,end[,decay,relative]\", cache-dit: \"Fn,Bn,threshold,warmup\" (default: 8,0,0.08,8)", on_cache_option_arg}, + {"", + "--scm-mask", + "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", + on_scm_mask_arg}, + {"", + "--scm-policy", + "SCM policy: 'dynamic' (check threshold, default) or 'static' (use cache without checking)", + on_scm_policy_arg}, + {"", + "--cache-preset", + "Cache-DIT preset: 'slow'/'s' (~2.7x), 'medium'/'m' (~3.2x), 'fast'/'f' (~5.7x), 'ultra'/'u' (~7.4x). Sets SCM mask + threshold + warmup automatically", + on_cache_preset_arg}, }; @@ -1554,15 +1610,13 @@ struct SDGenerationParams { if (option_str.empty()) { if (cache_mode == "easycache") { option_str = "0.2,0.15,0.95"; - } else { + } else if (cache_mode == "ucache") { option_str = "1.0,0.15,0.95"; + } else if (cache_mode == "dbcache" || cache_mode == "taylorseer" || cache_mode == "cache-dit") { + option_str = "8,0,0.08,8"; } } - // Format: threshold,start,end[,decay,relative] - // - values[0-2]: threshold, start_percent, end_percent (required) - // - values[3]: error_decay_rate (optional, default: 1.0) - // - values[4]: use_relative_threshold (optional, 0 or 1, default: 1) float values[5] = {0.0f, 0.0f, 0.0f, 1.0f, 1.0f}; std::stringstream ss(option_str); std::string token; @@ -1595,29 +1649,71 @@ struct SDGenerationParams { } idx++; } - if (idx < 3) { - fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n"); - return false; + if (cache_mode == "easycache" || cache_mode == "ucache") { + if (idx < 3) { + fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n"); + return false; + } + if (values[0] < 0.0f) { + fprintf(stderr, "error: cache threshold must be non-negative\n"); + return false; + } + if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { + fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); + return false; + } } - if (values[0] < 0.0f) { - fprintf(stderr, "error: cache threshold must be non-negative\n"); - return false; + + if (cache_mode == "easycache" || cache_mode == "ucache") { + cache_params.reuse_threshold = values[0]; + cache_params.start_percent = values[1]; + cache_params.end_percent = values[2]; + cache_params.error_decay_rate = values[3]; + cache_params.use_relative_threshold = (values[4] != 0.0f); + if (cache_mode == "easycache") { + cache_params.mode = SD_CACHE_EASYCACHE; + } else { + cache_params.mode = SD_CACHE_UCACHE; + } + } else { + cache_params.Fn_compute_blocks = (idx >= 1) ? static_cast(values[0]) : 8; + cache_params.Bn_compute_blocks = (idx >= 2) ? static_cast(values[1]) : 0; + cache_params.residual_diff_threshold = (idx >= 3) ? values[2] : 0.08f; + cache_params.max_warmup_steps = (idx >= 4) ? static_cast(values[3]) : 8; + if (cache_mode == "dbcache") { + cache_params.mode = SD_CACHE_DBCACHE; + } else if (cache_mode == "taylorseer") { + cache_params.mode = SD_CACHE_TAYLORSEER; + } else { + cache_params.mode = SD_CACHE_CACHE_DIT; + } } - if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); - return false; + } + + if (cache_params.mode == SD_CACHE_DBCACHE || + cache_params.mode == SD_CACHE_TAYLORSEER || + cache_params.mode == SD_CACHE_CACHE_DIT) { + + if (!cache_preset.empty()) { + cache_params.residual_diff_threshold = get_preset_threshold(cache_preset); + cache_params.max_warmup_steps = get_preset_warmup(cache_preset); + + if (scm_mask.empty()) { + int total_steps = sample_params.sample_steps; + std::vector mask = get_scm_preset(cache_preset, total_steps); + std::ostringstream oss; + for (size_t i = 0; i < mask.size(); i++) { + if (i > 0) oss << ","; + oss << mask[i]; + } + scm_mask = oss.str(); + } } - cache_params.reuse_threshold = values[0]; - cache_params.start_percent = values[1]; - cache_params.end_percent = values[2]; - cache_params.error_decay_rate = values[3]; - cache_params.use_relative_threshold = (values[4] != 0.0f); - if (cache_mode == "easycache") { - cache_params.mode = SD_CACHE_EASYCACHE; - } else { - cache_params.mode = SD_CACHE_UCACHE; + if (!scm_mask.empty()) { + cache_params.scm_mask = scm_mask.c_str(); } + cache_params.scm_policy_dynamic = scm_policy_dynamic; } sample_params.guidance.slg.layers = skip_layers.data(); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 0f91555b3..b9526e2b0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -14,6 +14,7 @@ #include "easycache.hpp" #include "esrgan.hpp" #include "ucache.hpp" +#include "cache_dit.hpp" #include "lora.hpp" #include "pmid.hpp" #include "tae.hpp" @@ -1505,15 +1506,20 @@ class StableDiffusionGGML { EasyCacheState easycache_state; UCacheState ucache_state; + CacheDitConditionState cachedit_state; bool easycache_enabled = false; bool ucache_enabled = false; + bool cachedit_enabled = false; if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) { - bool percent_valid = cache_params->start_percent >= 0.0f && - cache_params->start_percent < 1.0f && - cache_params->end_percent > 0.0f && - cache_params->end_percent <= 1.0f && - cache_params->start_percent < cache_params->end_percent; + bool percent_valid = true; + if (cache_params->mode == SD_CACHE_EASYCACHE || cache_params->mode == SD_CACHE_UCACHE) { + percent_valid = cache_params->start_percent >= 0.0f && + cache_params->start_percent < 1.0f && + cache_params->end_percent > 0.0f && + cache_params->end_percent <= 1.0f && + cache_params->start_percent < cache_params->end_percent; + } if (!percent_valid) { LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)", @@ -1565,6 +1571,47 @@ class StableDiffusionGGML { LOG_WARN("UCache requested but could not be initialized for this run"); } } + } else if (cache_params->mode == SD_CACHE_DBCACHE || + cache_params->mode == SD_CACHE_TAYLORSEER || + cache_params->mode == SD_CACHE_CACHE_DIT) { + bool cachedit_supported = sd_version_is_dit(version); + if (!cachedit_supported) { + LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)"); + } else { + DBCacheConfig dbcfg; + dbcfg.enabled = (cache_params->mode == SD_CACHE_DBCACHE || + cache_params->mode == SD_CACHE_CACHE_DIT); + dbcfg.Fn_compute_blocks = cache_params->Fn_compute_blocks; + dbcfg.Bn_compute_blocks = cache_params->Bn_compute_blocks; + dbcfg.residual_diff_threshold = cache_params->residual_diff_threshold; + dbcfg.max_warmup_steps = cache_params->max_warmup_steps; + dbcfg.max_cached_steps = cache_params->max_cached_steps; + dbcfg.max_continuous_cached_steps = cache_params->max_continuous_cached_steps; + if (cache_params->scm_mask != nullptr && strlen(cache_params->scm_mask) > 0) { + dbcfg.steps_computation_mask = parse_scm_mask(cache_params->scm_mask); + } + dbcfg.scm_policy_dynamic = cache_params->scm_policy_dynamic; + + TaylorSeerConfig tcfg; + tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER || + cache_params->mode == SD_CACHE_CACHE_DIT); + tcfg.n_derivatives = cache_params->taylorseer_n_derivatives; + tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval; + + cachedit_state.init(dbcfg, tcfg); + if (cachedit_state.enabled()) { + cachedit_enabled = true; + LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d", + cache_params->mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : + (cache_params->mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"), + dbcfg.Fn_compute_blocks, + dbcfg.Bn_compute_blocks, + dbcfg.residual_diff_threshold, + dbcfg.max_warmup_steps); + } else { + LOG_WARN("CacheDIT requested but could not be initialized for this run"); + } + } } } @@ -1572,6 +1619,10 @@ class StableDiffusionGGML { ucache_state.set_sigmas(sigmas); } + if (cachedit_enabled) { + cachedit_state.set_sigmas(sigmas); + } + size_t steps = sigmas.size() - 1; struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); copy_ggml_tensor(x, init_latent); @@ -1705,11 +1756,43 @@ class StableDiffusionGGML { return ucache_step_active && ucache_state.is_step_skipped(); }; + const bool cachedit_step_active = cachedit_enabled && step > 0; + int cachedit_step_index = cachedit_step_active ? (step - 1) : -1; + if (cachedit_step_active) { + cachedit_state.begin_step(cachedit_step_index, sigma); + } + + auto cachedit_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool { + if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) { + return false; + } + return cachedit_state.before_condition(condition, + diffusion_params.x, + output_tensor, + sigma, + cachedit_step_index); + }; + + auto cachedit_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) { + if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) { + return; + } + cachedit_state.after_condition(condition, + diffusion_params.x, + output_tensor); + }; + + auto cachedit_step_is_skipped = [&]() { + return cachedit_step_active && cachedit_state.is_step_skipped(); + }; + auto cache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool { if (easycache_step_active) { return easycache_before_condition(condition, output_tensor); } else if (ucache_step_active) { return ucache_before_condition(condition, output_tensor); + } else if (cachedit_step_active) { + return cachedit_before_condition(condition, output_tensor); } return false; }; @@ -1719,11 +1802,13 @@ class StableDiffusionGGML { easycache_after_condition(condition, output_tensor); } else if (ucache_step_active) { ucache_after_condition(condition, output_tensor); + } else if (cachedit_step_active) { + cachedit_after_condition(condition, output_tensor); } }; auto cache_step_is_skipped = [&]() { - return easycache_step_is_skipped() || ucache_step_is_skipped(); + return easycache_step_is_skipped() || ucache_step_is_skipped() || cachedit_step_is_skipped(); }; std::vector scaling = denoiser->get_scalings(sigma); @@ -1989,6 +2074,28 @@ class StableDiffusionGGML { } } + if (cachedit_enabled) { + size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0; + if (cachedit_state.total_steps_skipped > 0 && total_steps > 0) { + if (cachedit_state.total_steps_skipped < static_cast(total_steps)) { + double speedup = static_cast(total_steps) / + static_cast(total_steps - cachedit_state.total_steps_skipped); + LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f", + cachedit_state.total_steps_skipped, + total_steps, + speedup, + cachedit_state.accumulated_residual_diff); + } else { + LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f", + cachedit_state.total_steps_skipped, + total_steps, + cachedit_state.accumulated_residual_diff); + } + } else if (total_steps > 0) { + LOG_INFO("CacheDIT completed without skipping steps"); + } + } + if (inverse_noise_scaling) { x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); } @@ -2611,6 +2718,16 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) { cache_params->end_percent = 0.95f; cache_params->error_decay_rate = 1.0f; cache_params->use_relative_threshold = true; + cache_params->Fn_compute_blocks = 8; + cache_params->Bn_compute_blocks = 0; + cache_params->residual_diff_threshold = 0.08f; + cache_params->max_warmup_steps = 8; + cache_params->max_cached_steps = -1; + cache_params->max_continuous_cached_steps = -1; + cache_params->taylorseer_n_derivatives = 1; + cache_params->taylorseer_skip_interval = 1; + cache_params->scm_mask = nullptr; + cache_params->scm_policy_dynamic = true; } void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { diff --git a/stable-diffusion.h b/stable-diffusion.h index 7d560d788..366670901 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -239,6 +239,9 @@ enum sd_cache_mode_t { SD_CACHE_DISABLED = 0, SD_CACHE_EASYCACHE, SD_CACHE_UCACHE, + SD_CACHE_DBCACHE, + SD_CACHE_TAYLORSEER, + SD_CACHE_CACHE_DIT, }; typedef struct { @@ -248,6 +251,16 @@ typedef struct { float end_percent; float error_decay_rate; bool use_relative_threshold; + int Fn_compute_blocks; + int Bn_compute_blocks; + float residual_diff_threshold; + int max_warmup_steps; + int max_cached_steps; + int max_continuous_cached_steps; + int taylorseer_n_derivatives; + int taylorseer_skip_interval; + const char* scm_mask; + bool scm_policy_dynamic; } sd_cache_params_t; typedef struct { From b176dfdfdeb5431b3ec3412d1bc6698a64faf1a6 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 12 Dec 2025 16:56:14 +0000 Subject: [PATCH 07/12] fix Fn/Bn handling --- cache_dit.hpp | 30 +++++++++++++++++++++++++++--- examples/cli/main.cpp | 2 ++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/cache_dit.hpp b/cache_dit.hpp index f222fb3ab..84002da1c 100644 --- a/cache_dit.hpp +++ b/cache_dit.hpp @@ -527,8 +527,7 @@ struct CacheDitState { std::string get_summary() const { char buf[256]; snprintf(buf, sizeof(buf), - "CacheDIT[Fn=%d,Bn=%d,thresh=%.2f]: cached %zu/%d steps, %d/%d blocks", - get_double_Fn_blocks(), get_double_Bn_blocks(), + "CacheDIT[thresh=%.2f]: cached %zu/%d steps, %d/%d blocks", config.dbcache.residual_diff_threshold, cached_steps.size(), total_steps, total_blocks_cached, total_blocks_computed + total_blocks_cached); @@ -636,6 +635,19 @@ inline int get_preset_warmup(const std::string& preset) { return 8; } +inline int get_preset_Fn(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 8; + if (preset == "medium" || preset == "m" || preset == "M") return 8; + if (preset == "fast" || preset == "f" || preset == "F") return 6; + if (preset == "ultra" || preset == "u" || preset == "U") return 4; + return 8; +} + +inline int get_preset_Bn(const std::string& preset) { + (void)preset; + return 0; +} + inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) { if (opts.empty()) return; @@ -858,7 +870,19 @@ struct CacheDitConditionState { float diff = CacheDitState::calculate_residual_diff( it->second.prev_input.data(), input_data, ne); - if (diff < config.residual_diff_threshold) { + float effective_threshold = config.residual_diff_threshold; + if (config.Fn_compute_blocks > 0) { + float fn_confidence = 1.0f + 0.02f * (config.Fn_compute_blocks - 8); + fn_confidence = std::max(0.5f, std::min(2.0f, fn_confidence)); + effective_threshold *= fn_confidence; + } + if (config.Bn_compute_blocks > 0) { + float bn_quality = 1.0f - 0.03f * config.Bn_compute_blocks; + bn_quality = std::max(0.5f, std::min(1.0f, bn_quality)); + effective_threshold *= bn_quality; + } + + if (diff < effective_threshold) { skip_current_step = true; total_steps_skipped++; cached_steps.push_back(current_step_index); diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 862880182..5d2d30967 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1695,6 +1695,8 @@ struct SDGenerationParams { cache_params.mode == SD_CACHE_CACHE_DIT) { if (!cache_preset.empty()) { + cache_params.Fn_compute_blocks = get_preset_Fn(cache_preset); + cache_params.Bn_compute_blocks = get_preset_Bn(cache_preset); cache_params.residual_diff_threshold = get_preset_threshold(cache_preset); cache_params.max_warmup_steps = get_preset_warmup(cache_preset); From f04166f5175a7f88032afec2a995be123a914017 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 12 Dec 2025 17:27:55 +0000 Subject: [PATCH 08/12] named parameter --- examples/cli/README.md | 12 ++- examples/cli/main.cpp | 171 +++++++++++++++++++++++++---------------- 2 files changed, 112 insertions(+), 71 deletions(-) diff --git a/examples/cli/README.md b/examples/cli/README.md index b9f4f70d0..e2561615e 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -126,8 +126,12 @@ Generation Options: --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) - --cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL) - --cache-option cache parameters: easycache uses "threshold,start,end" (default: 0.2,0.15,0.95). - ucache uses "threshold,start,end[,decay,relative]" (default: 1.0,0.15,0.95,1.0,1). - decay: error decay rate (0.0-1.0), relative: use relative threshold (0 or 1) + --cache-mode caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level) + --cache-option named cache params (key=value format, comma-separated): + - easycache/ucache: threshold=,start=,end=,decay=,relative= + - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup= + Examples: "threshold=0.25" or "Fn=12,threshold=0.30,warmup=4" + --cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u' + --scm-mask SCM steps mask: comma-separated 0/1 (1=compute, 0=can cache) + --scm-policy SCM policy: 'dynamic' (default) or 'static' ``` diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 5d2d30967..2f21dae1e 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1477,7 +1477,7 @@ struct SDGenerationParams { on_cache_mode_arg}, {"", "--cache-option", - "cache params - legacy: \"threshold,start,end[,decay,relative]\", cache-dit: \"Fn,Bn,threshold,warmup\" (default: 8,0,0.08,8)", + "named cache params: easycache/ucache: threshold=,start=,end=,decay=,relative= | cache-dit: Fn=,Bn=,threshold=,warmup=", on_cache_option_arg}, {"", "--scm-mask", @@ -1606,88 +1606,125 @@ struct SDGenerationParams { cache_params.mode = SD_CACHE_DISABLED; if (!cache_mode.empty()) { - std::string option_str = cache_option; - if (option_str.empty()) { - if (cache_mode == "easycache") { - option_str = "0.2,0.15,0.95"; - } else if (cache_mode == "ucache") { - option_str = "1.0,0.15,0.95"; - } else if (cache_mode == "dbcache" || cache_mode == "taylorseer" || cache_mode == "cache-dit") { - option_str = "8,0,0.08,8"; - } - } - - float values[5] = {0.0f, 0.0f, 0.0f, 1.0f, 1.0f}; - std::stringstream ss(option_str); - std::string token; - int idx = 0; auto trim = [](std::string& s) { const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { - s.clear(); - return; - } + auto start = s.find_first_not_of(whitespace); + if (start == std::string::npos) { s.clear(); return; } auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); + s = s.substr(start, end - start + 1); }; - while (std::getline(ss, token, ',')) { - trim(token); - if (token.empty()) { - fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str()); - return false; - } - if (idx >= 5) { - fprintf(stderr, "error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n"); - return false; + + auto parse_named_params = [&](const std::string& opt_str) -> bool { + std::stringstream ss(opt_str); + std::string token; + while (std::getline(ss, token, ',')) { + trim(token); + if (token.empty()) continue; + + size_t eq_pos = token.find('='); + if (eq_pos == std::string::npos) { + fprintf(stderr, "error: invalid named parameter '%s', expected key=value\n", token.c_str()); + return false; + } + + std::string key = token.substr(0, eq_pos); + std::string val = token.substr(eq_pos + 1); + trim(key); + trim(val); + + if (key.empty() || val.empty()) { + fprintf(stderr, "error: invalid named parameter '%s'\n", token.c_str()); + return false; + } + + try { + if (key == "threshold") { + if (cache_mode == "easycache" || cache_mode == "ucache") { + cache_params.reuse_threshold = std::stof(val); + } else { + cache_params.residual_diff_threshold = std::stof(val); + } + } else if (key == "start") { + cache_params.start_percent = std::stof(val); + } else if (key == "end") { + cache_params.end_percent = std::stof(val); + } else if (key == "decay") { + cache_params.error_decay_rate = std::stof(val); + } else if (key == "relative") { + cache_params.use_relative_threshold = (std::stof(val) != 0.0f); + } else if (key == "Fn" || key == "fn") { + cache_params.Fn_compute_blocks = std::stoi(val); + } else if (key == "Bn" || key == "bn") { + cache_params.Bn_compute_blocks = std::stoi(val); + } else if (key == "warmup") { + cache_params.max_warmup_steps = std::stoi(val); + } else { + fprintf(stderr, "error: unknown cache parameter '%s'\n", key.c_str()); + return false; + } + } catch (const std::exception&) { + fprintf(stderr, "error: invalid value '%s' for parameter '%s'\n", val.c_str(), key.c_str()); + return false; + } } - try { - values[idx] = std::stof(token); - } catch (const std::exception&) { - fprintf(stderr, "error: invalid cache option value '%s'\n", token.c_str()); + return true; + }; + + if (cache_mode == "easycache") { + cache_params.mode = SD_CACHE_EASYCACHE; + cache_params.reuse_threshold = 0.2f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + } else if (cache_mode == "ucache") { + cache_params.mode = SD_CACHE_UCACHE; + cache_params.reuse_threshold = 1.0f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + } else if (cache_mode == "dbcache") { + cache_params.mode = SD_CACHE_DBCACHE; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "taylorseer") { + cache_params.mode = SD_CACHE_TAYLORSEER; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "cache-dit") { + cache_params.mode = SD_CACHE_CACHE_DIT; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else { + fprintf(stderr, "error: unknown cache mode '%s'\n", cache_mode.c_str()); + return false; + } + + if (!cache_option.empty()) { + if (!parse_named_params(cache_option)) { return false; } - idx++; } + if (cache_mode == "easycache" || cache_mode == "ucache") { - if (idx < 3) { - fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n"); - return false; - } - if (values[0] < 0.0f) { + if (cache_params.reuse_threshold < 0.0f) { fprintf(stderr, "error: cache threshold must be non-negative\n"); return false; } - if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { + if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || + cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || + cache_params.start_percent >= cache_params.end_percent) { fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); return false; } } - - if (cache_mode == "easycache" || cache_mode == "ucache") { - cache_params.reuse_threshold = values[0]; - cache_params.start_percent = values[1]; - cache_params.end_percent = values[2]; - cache_params.error_decay_rate = values[3]; - cache_params.use_relative_threshold = (values[4] != 0.0f); - if (cache_mode == "easycache") { - cache_params.mode = SD_CACHE_EASYCACHE; - } else { - cache_params.mode = SD_CACHE_UCACHE; - } - } else { - cache_params.Fn_compute_blocks = (idx >= 1) ? static_cast(values[0]) : 8; - cache_params.Bn_compute_blocks = (idx >= 2) ? static_cast(values[1]) : 0; - cache_params.residual_diff_threshold = (idx >= 3) ? values[2] : 0.08f; - cache_params.max_warmup_steps = (idx >= 4) ? static_cast(values[3]) : 8; - if (cache_mode == "dbcache") { - cache_params.mode = SD_CACHE_DBCACHE; - } else if (cache_mode == "taylorseer") { - cache_params.mode = SD_CACHE_TAYLORSEER; - } else { - cache_params.mode = SD_CACHE_CACHE_DIT; - } - } } if (cache_params.mode == SD_CACHE_DBCACHE || From c6f0e22215199d8a5e4095360d7cc0970f930565 Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 15 Dec 2025 16:34:47 +0000 Subject: [PATCH 09/12] add reset param to ucache --- examples/cli/main.cpp | 6 +++++- stable-diffusion.cpp | 7 +++++-- stable-diffusion.h | 1 + ucache.hpp | 3 +++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 2f21dae1e..b129063f6 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1477,7 +1477,7 @@ struct SDGenerationParams { on_cache_mode_arg}, {"", "--cache-option", - "named cache params: easycache/ucache: threshold=,start=,end=,decay=,relative= | cache-dit: Fn=,Bn=,threshold=,warmup=", + "named cache params (key=value format, comma-separated):\n - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=\n - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=\n Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", on_cache_option_arg}, {"", "--scm-mask", @@ -1652,6 +1652,8 @@ struct SDGenerationParams { cache_params.error_decay_rate = std::stof(val); } else if (key == "relative") { cache_params.use_relative_threshold = (std::stof(val) != 0.0f); + } else if (key == "reset") { + cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); } else if (key == "Fn" || key == "fn") { cache_params.Fn_compute_blocks = std::stoi(val); } else if (key == "Bn" || key == "bn") { @@ -1677,6 +1679,7 @@ struct SDGenerationParams { cache_params.end_percent = 0.95f; cache_params.error_decay_rate = 1.0f; cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; } else if (cache_mode == "ucache") { cache_params.mode = SD_CACHE_UCACHE; cache_params.reuse_threshold = 1.0f; @@ -1684,6 +1687,7 @@ struct SDGenerationParams { cache_params.end_percent = 0.95f; cache_params.error_decay_rate = 1.0f; cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; } else if (cache_mode == "dbcache") { cache_params.mode = SD_CACHE_DBCACHE; cache_params.Fn_compute_blocks = 8; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b9526e2b0..d7b4fe72b 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1558,15 +1558,17 @@ class StableDiffusionGGML { ucache_config.end_percent = cache_params->end_percent; ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate)); ucache_config.use_relative_threshold = cache_params->use_relative_threshold; + ucache_config.reset_error_on_compute = cache_params->reset_error_on_compute; ucache_state.init(ucache_config, denoiser.get()); if (ucache_state.enabled()) { ucache_enabled = true; - LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s", + LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s", ucache_config.reuse_threshold, ucache_config.start_percent, ucache_config.end_percent, ucache_config.error_decay_rate, - ucache_config.use_relative_threshold ? "true" : "false"); + ucache_config.use_relative_threshold ? "true" : "false", + ucache_config.reset_error_on_compute ? "true" : "false"); } else { LOG_WARN("UCache requested but could not be initialized for this run"); } @@ -2718,6 +2720,7 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) { cache_params->end_percent = 0.95f; cache_params->error_decay_rate = 1.0f; cache_params->use_relative_threshold = true; + cache_params->reset_error_on_compute = true; cache_params->Fn_compute_blocks = 8; cache_params->Bn_compute_blocks = 0; cache_params->residual_diff_threshold = 0.08f; diff --git a/stable-diffusion.h b/stable-diffusion.h index 366670901..de6485f5e 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -251,6 +251,7 @@ typedef struct { float end_percent; float error_decay_rate; bool use_relative_threshold; + bool reset_error_on_compute; int Fn_compute_blocks; int Bn_compute_blocks; float residual_diff_threshold; diff --git a/ucache.hpp b/ucache.hpp index 7eea592bf..9a39557d1 100644 --- a/ucache.hpp +++ b/ucache.hpp @@ -19,6 +19,7 @@ struct UCacheConfig { bool adaptive_threshold = true; float early_step_multiplier = 0.5f; float late_step_multiplier = 1.5f; + bool reset_error_on_compute = true; }; struct UCacheCacheEntry { @@ -319,6 +320,8 @@ struct UCacheState { total_steps_skipped++; apply_cache(cond, input, output); return true; + } else if (config.reset_error_on_compute) { + accumulated_error = 0.0f; } } From ffbe00aefe800d7c769a31065424bdb253c8b33d Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 15 Dec 2025 17:00:22 +0000 Subject: [PATCH 10/12] adapt to upstream refactor (common.hpp) --- examples/cli/main.cpp | 2014 +++----------------------------------- examples/server/main.cpp | 4 +- 2 files changed, 123 insertions(+), 1895 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index b129063f6..889cabc5d 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -14,40 +14,11 @@ // #include "preprocessing.hpp" #include "stable-diffusion.h" -#include "cache_dit.hpp" -#define STB_IMAGE_IMPLEMENTATION -#define STB_IMAGE_STATIC -#include "stb_image.h" - -#define STB_IMAGE_WRITE_IMPLEMENTATION -#define STB_IMAGE_WRITE_STATIC -#include "stb_image_write.h" - -#define STB_IMAGE_RESIZE_IMPLEMENTATION -#define STB_IMAGE_RESIZE_STATIC -#include "stb_image_resize.h" +#include "common/common.hpp" #include "avi_writer.h" -#if defined(_WIN32) -#define NOMINMAX -#include -#endif // _WIN32 - -#define SAFE_STR(s) ((s) ? (s) : "") -#define BOOL_STR(b) ((b) ? "true" : "false") - -namespace fs = std::filesystem; - -const char* modes_str[] = { - "img_gen", - "vid_gen", - "convert", - "upscale", -}; -#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale" - const char* previews_str[] = { "none", "proj", @@ -55,271 +26,6 @@ const char* previews_str[] = { "vae", }; -enum SDMode { - IMG_GEN, - VID_GEN, - CONVERT, - UPSCALE, - MODE_COUNT -}; - -#if defined(_WIN32) -static std::string utf16_to_utf8(const std::wstring& wstr) { - if (wstr.empty()) - return {}; - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), - nullptr, 0, nullptr, nullptr); - if (size_needed <= 0) - throw std::runtime_error("UTF-16 to UTF-8 conversion failed"); - - std::string utf8(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), - (char*)utf8.data(), size_needed, nullptr, nullptr); - return utf8; -} - -static std::string argv_to_utf8(int index, const char** argv) { - int argc; - wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc); - if (!argv_w) - throw std::runtime_error("Failed to parse command line"); - - std::string result; - if (index < argc) { - result = utf16_to_utf8(argv_w[index]); - } - LocalFree(argv_w); - return result; -} - -#else // Linux / macOS -static std::string argv_to_utf8(int index, const char** argv) { - return std::string(argv[index]); -} - -#endif - -struct StringOption { - std::string short_name; - std::string long_name; - std::string desc; - std::string* target; -}; - -struct IntOption { - std::string short_name; - std::string long_name; - std::string desc; - int* target; -}; - -struct FloatOption { - std::string short_name; - std::string long_name; - std::string desc; - float* target; -}; - -struct BoolOption { - std::string short_name; - std::string long_name; - std::string desc; - bool keep_true; - bool* target; -}; - -struct ManualOption { - std::string short_name; - std::string long_name; - std::string desc; - std::function cb; -}; - -struct ArgOptions { - std::vector string_options; - std::vector int_options; - std::vector float_options; - std::vector bool_options; - std::vector manual_options; - - static std::string wrap_text(const std::string& text, size_t width, size_t indent) { - std::ostringstream oss; - size_t line_len = 0; - size_t pos = 0; - - while (pos < text.size()) { - // Preserve manual newlines - if (text[pos] == '\n') { - oss << '\n' - << std::string(indent, ' '); - line_len = indent; - ++pos; - continue; - } - - // Add the character - oss << text[pos]; - ++line_len; - ++pos; - - // If the current line exceeds width, try to break at the last space - if (line_len >= width) { - std::string current = oss.str(); - size_t back = current.size(); - - // Find the last space (for a clean break) - while (back > 0 && current[back - 1] != ' ' && current[back - 1] != '\n') - --back; - - // If found a space to break on - if (back > 0 && current[back - 1] != '\n') { - std::string before = current.substr(0, back - 1); - std::string after = current.substr(back); - oss.str(""); - oss.clear(); - oss << before << "\n" - << std::string(indent, ' ') << after; - } else { - // If no space found, just break at width - oss << "\n" - << std::string(indent, ' '); - } - line_len = indent; - } - } - - return oss.str(); - } - - void print() const { - constexpr size_t max_line_width = 120; - - struct Entry { - std::string names; - std::string desc; - }; - std::vector entries; - - auto add_entry = [&](const std::string& s, const std::string& l, - const std::string& desc, const std::string& hint = "") { - std::ostringstream ss; - if (!s.empty()) - ss << s; - if (!s.empty() && !l.empty()) - ss << ", "; - if (!l.empty()) - ss << l; - if (!hint.empty()) - ss << " " << hint; - entries.push_back({ss.str(), desc}); - }; - - for (auto& o : string_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : int_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : float_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : bool_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : manual_options) - add_entry(o.short_name, o.long_name, o.desc); - - size_t max_name_width = 0; - for (auto& e : entries) - max_name_width = std::max(max_name_width, e.names.size()); - - for (auto& e : entries) { - size_t indent = 2 + max_name_width + 4; - size_t desc_width = (max_line_width > indent ? max_line_width - indent : 40); - std::string wrapped_desc = wrap_text(e.desc, max_line_width, indent); - std::cout << " " << std::left << std::setw(static_cast(max_name_width) + 4) - << e.names << wrapped_desc << "\n"; - } - } -}; - -bool parse_options(int argc, const char** argv, const std::vector& options_list) { - bool invalid_arg = false; - std::string arg; - - auto match_and_apply = [&](auto& opts, auto&& apply_fn) -> bool { - for (auto& option : opts) { - if ((option.short_name.size() > 0 && arg == option.short_name) || - (option.long_name.size() > 0 && arg == option.long_name)) { - apply_fn(option); - return true; - } - } - return false; - }; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - bool found_arg = false; - - for (auto& options : options_list) { - if (match_and_apply(options.string_options, [&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = argv_to_utf8(i, argv); - found_arg = true; - })) - break; - - if (match_and_apply(options.int_options, [&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = std::stoi(argv[i]); - found_arg = true; - })) - break; - - if (match_and_apply(options.float_options, [&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = std::stof(argv[i]); - found_arg = true; - })) - break; - - if (match_and_apply(options.bool_options, [&](auto& option) { - *option.target = option.keep_true ? true : false; - found_arg = true; - })) - break; - - if (match_and_apply(options.manual_options, [&](auto& option) { - int ret = option.cb(argc, argv, i); - if (ret < 0) { - invalid_arg = true; - return; - } - i += ret; - found_arg = true; - })) - break; - } - - if (invalid_arg) { - fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); - return false; - } - if (!found_arg) { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - return false; - } - } - - return true; -} - struct SDCliParams { SDMode mode = IMG_GEN; std::string output_path = "output.png"; @@ -332,1552 +38,153 @@ struct SDCliParams { std::string preview_path = "preview.png"; int preview_fps = 16; bool taesd_preview = false; - bool preview_noisy = false; - bool color = false; - - bool normal_exit = false; - - ArgOptions get_options() { - ArgOptions options; - - options.string_options = { - {"-o", - "--output", - "path to write result image to (default: ./output.png)", - &output_path}, - {"", - "--preview-path", - "path to write preview image to (default: ./preview.png)", - &preview_path}, - }; - - options.int_options = { - {"", - "--preview-interval", - "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)", - &preview_interval}, - }; - - options.bool_options = { - {"", - "--canny", - "apply canny preprocessor (edge detection)", - true, &canny_preprocess}, - {"-v", - "--verbose", - "print extra info", - true, &verbose}, - {"", - "--color", - "colors the logging tags according to level", - true, &color}, - {"", - "--taesd-preview-only", - std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")", - true, &taesd_preview}, - {"", - "--preview-noisy", - "enables previewing noisy inputs of the models rather than the denoised outputs", - true, &preview_noisy}, - - }; - - auto on_mode_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* mode_c_str = argv[index]; - if (mode_c_str != nullptr) { - int mode_found = -1; - for (int i = 0; i < MODE_COUNT; i++) { - if (!strcmp(mode_c_str, modes_str[i])) { - mode_found = i; - } - } - if (mode_found == -1) { - LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", - mode_c_str, SD_ALL_MODES_STR); - exit(1); - } - mode = (SDMode)mode_found; - } - return 1; - }; - - auto on_preview_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* preview = argv[index]; - int preview_found = -1; - for (int m = 0; m < PREVIEW_COUNT; m++) { - if (!strcmp(preview, previews_str[m])) { - preview_found = m; - } - } - if (preview_found == -1) { - LOG_ERROR("error: preview method %s", preview); - return -1; - } - preview_method = (preview_t)preview_found; - return 1; - }; - - auto on_help_arg = [&](int argc, const char** argv, int index) { - normal_exit = true; - return -1; - }; - - options.manual_options = { - {"-M", - "--mode", - "run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen", - on_mode_arg}, - {"", - "--preview", - std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")", - on_preview_arg}, - {"-h", - "--help", - "show this help message and exit", - on_help_arg}, - }; - - return options; - }; - - bool process_and_check() { - if (output_path.length() == 0) { - LOG_ERROR("error: the following arguments are required: output_path"); - return false; - } - - if (mode == CONVERT) { - if (output_path == "output.png") { - output_path = "output.gguf"; - } - } - return true; - } - - std::string to_string() const { - std::ostringstream oss; - oss << "SDCliParams {\n" - << " mode: " << modes_str[mode] << ",\n" - << " output_path: \"" << output_path << "\",\n" - << " verbose: " << (verbose ? "true" : "false") << ",\n" - << " color: " << (color ? "true" : "false") << ",\n" - << " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n" - << " preview_method: " << previews_str[preview_method] << ",\n" - << " preview_interval: " << preview_interval << ",\n" - << " preview_path: \"" << preview_path << "\",\n" - << " preview_fps: " << preview_fps << ",\n" - << " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n" - << " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n" - << "}"; - return oss.str(); - } -}; - -struct SDContextParams { - int n_threads = -1; - std::string model_path; - std::string clip_l_path; - std::string clip_g_path; - std::string clip_vision_path; - std::string t5xxl_path; - std::string llm_path; - std::string llm_vision_path; - std::string diffusion_model_path; - std::string high_noise_diffusion_model_path; - std::string vae_path; - std::string taesd_path; - std::string esrgan_path; - std::string control_net_path; - std::string embedding_dir; - std::string photo_maker_path; - sd_type_t wtype = SD_TYPE_COUNT; - std::string tensor_type_rules; - std::string lora_model_dir; - - std::map embedding_map; - std::vector embedding_vec; - - rng_type_t rng_type = CUDA_RNG; - rng_type_t sampler_rng_type = RNG_TYPE_COUNT; - bool offload_params_to_cpu = false; - bool control_net_cpu = false; - bool clip_on_cpu = false; - bool vae_on_cpu = false; - bool diffusion_flash_attn = false; - bool diffusion_conv_direct = false; - bool vae_conv_direct = false; - - bool chroma_use_dit_mask = true; - bool chroma_use_t5_mask = false; - int chroma_t5_mask_pad = 1; - - prediction_t prediction = PREDICTION_COUNT; - lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; - - sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; - bool force_sdxl_vae_conv_scale = false; - - float flow_shift = INFINITY; - - ArgOptions get_options() { - ArgOptions options; - options.string_options = { - {"-m", - "--model", - "path to full model", - &model_path}, - {"", - "--clip_l", - "path to the clip-l text encoder", &clip_l_path}, - {"", "--clip_g", - "path to the clip-g text encoder", - &clip_g_path}, - {"", - "--clip_vision", - "path to the clip-vision encoder", - &clip_vision_path}, - {"", - "--t5xxl", - "path to the t5xxl text encoder", - &t5xxl_path}, - {"", - "--llm", - "path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)", - &llm_path}, - {"", - "--llm_vision", - "path to the llm vit", - &llm_vision_path}, - {"", - "--qwen2vl", - "alias of --llm. Deprecated.", - &llm_path}, - {"", - "--qwen2vl_vision", - "alias of --llm_vision. Deprecated.", - &llm_vision_path}, - {"", - "--diffusion-model", - "path to the standalone diffusion model", - &diffusion_model_path}, - {"", - "--high-noise-diffusion-model", - "path to the standalone high noise diffusion model", - &high_noise_diffusion_model_path}, - {"", - "--vae", - "path to standalone vae model", - &vae_path}, - {"", - "--taesd", - "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", - &taesd_path}, - {"", - "--control-net", - "path to control net model", - &control_net_path}, - {"", - "--embd-dir", - "embeddings directory", - &embedding_dir}, - {"", - "--lora-model-dir", - "lora model directory", - &lora_model_dir}, - - {"", - "--tensor-type-rules", - "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")", - &tensor_type_rules}, - {"", - "--photo-maker", - "path to PHOTOMAKER model", - &photo_maker_path}, - {"", - "--upscale-model", - "path to esrgan model.", - &esrgan_path}, - }; - - options.int_options = { - {"-t", - "--threads", - "number of threads to use during computation (default: -1). " - "If threads <= 0, then threads will be set to the number of CPU physical cores", - &n_threads}, - {"", - "--chroma-t5-mask-pad", - "t5 mask pad size of chroma", - &chroma_t5_mask_pad}, - }; - - options.float_options = { - {"", - "--vae-tile-overlap", - "tile overlap for vae tiling, in fraction of tile size (default: 0.5)", - &vae_tiling_params.target_overlap}, - {"", - "--flow-shift", - "shift value for Flow models like SD3.x or WAN (default: auto)", - &flow_shift}, - }; - - options.bool_options = { - {"", - "--vae-tiling", - "process vae in tiles to reduce memory usage", - true, &vae_tiling_params.enabled}, - {"", - "--force-sdxl-vae-conv-scale", - "force use of conv scale on sdxl vae", - true, &force_sdxl_vae_conv_scale}, - {"", - "--offload-to-cpu", - "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed", - true, &offload_params_to_cpu}, - {"", - "--control-net-cpu", - "keep controlnet in cpu (for low vram)", - true, &control_net_cpu}, - {"", - "--clip-on-cpu", - "keep clip in cpu (for low vram)", - true, &clip_on_cpu}, - {"", - "--vae-on-cpu", - "keep vae in cpu (for low vram)", - true, &vae_on_cpu}, - {"", - "--diffusion-fa", - "use flash attention in the diffusion model", - true, &diffusion_flash_attn}, - {"", - "--diffusion-conv-direct", - "use ggml_conv2d_direct in the diffusion model", - true, &diffusion_conv_direct}, - {"", - "--vae-conv-direct", - "use ggml_conv2d_direct in the vae model", - true, &vae_conv_direct}, - {"", - "--chroma-disable-dit-mask", - "disable dit mask for chroma", - false, &chroma_use_dit_mask}, - {"", - "--chroma-enable-t5-mask", - "enable t5 mask for chroma", - true, &chroma_use_t5_mask}, - }; - - auto on_type_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - wtype = str_to_sd_type(arg); - if (wtype == SD_TYPE_COUNT) { - fprintf(stderr, "error: invalid weight format %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_rng_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - rng_type = str_to_rng_type(arg); - if (rng_type == RNG_TYPE_COUNT) { - fprintf(stderr, "error: invalid rng type %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sampler_rng_type = str_to_rng_type(arg); - if (sampler_rng_type == RNG_TYPE_COUNT) { - fprintf(stderr, "error: invalid sampler rng type %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_prediction_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - prediction = str_to_prediction(arg); - if (prediction == PREDICTION_COUNT) { - fprintf(stderr, "error: invalid prediction type %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - lora_apply_mode = str_to_lora_apply_mode(arg); - if (lora_apply_mode == LORA_APPLY_MODE_COUNT) { - fprintf(stderr, "error: invalid lora apply model %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_tile_size_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string tile_size_str = argv[index]; - size_t x_pos = tile_size_str.find('x'); - try { - if (x_pos != std::string::npos) { - std::string tile_x_str = tile_size_str.substr(0, x_pos); - std::string tile_y_str = tile_size_str.substr(x_pos + 1); - vae_tiling_params.tile_size_x = std::stoi(tile_x_str); - vae_tiling_params.tile_size_y = std::stoi(tile_y_str); - } else { - vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str); - } - } catch (const std::invalid_argument&) { - return -1; - } catch (const std::out_of_range&) { - return -1; - } - return 1; - }; - - auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string rel_size_str = argv[index]; - size_t x_pos = rel_size_str.find('x'); - try { - if (x_pos != std::string::npos) { - std::string rel_x_str = rel_size_str.substr(0, x_pos); - std::string rel_y_str = rel_size_str.substr(x_pos + 1); - vae_tiling_params.rel_size_x = std::stof(rel_x_str); - vae_tiling_params.rel_size_y = std::stof(rel_y_str); - } else { - vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str); - } - } catch (const std::invalid_argument&) { - return -1; - } catch (const std::out_of_range&) { - return -1; - } - return 1; - }; - - options.manual_options = { - {"", - "--type", - "weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). " - "If not specified, the default is the type of the weight file", - on_type_arg}, - {"", - "--rng", - "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)", - on_rng_arg}, - {"", - "--sampler-rng", - "sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng", - on_sampler_rng_arg}, - {"", - "--prediction", - "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]", - on_prediction_arg}, - {"", - "--lora-apply-mode", - "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. " - "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used." - "The immediately mode may have precision and compatibility issues with quantized parameters, " - "but it usually offers faster inference speed and, in some cases, lower memory usage. " - "The at_runtime mode, on the other hand, is exactly the opposite.", - on_lora_apply_mode_arg}, - {"", - "--vae-tile-size", - "tile size for vae tiling, format [X]x[Y] (default: 32x32)", - on_tile_size_arg}, - {"", - "--vae-relative-tile-size", - "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)", - on_relative_tile_size_arg}, - }; - - return options; - } - - void build_embedding_map() { - static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; - - if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) { - return; - } - - for (auto& p : fs::directory_iterator(embedding_dir)) { - if (!p.is_regular_file()) - continue; - - auto path = p.path(); - std::string ext = path.extension().string(); - - bool valid = false; - for (auto& e : valid_ext) { - if (ext == e) { - valid = true; - break; - } - } - if (!valid) - continue; - - std::string key = path.stem().string(); - std::string value = path.string(); - - embedding_map[key] = value; - } - } - - bool process_and_check(SDMode mode) { - if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); - return false; - } - - if (mode == UPSCALE) { - if (esrgan_path.length() == 0) { - fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n"); - return false; - } - } - - if (n_threads <= 0) { - n_threads = sd_get_num_physical_cores(); - } - - build_embedding_map(); - - return true; - } - - std::string to_string() const { - std::ostringstream emb_ss; - emb_ss << "{\n"; - for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) { - emb_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != embedding_map.end()) { - emb_ss << ","; - } - emb_ss << "\n"; - } - emb_ss << " }"; - - std::string embeddings_str = emb_ss.str(); - std::ostringstream oss; - oss << "SDContextParams {\n" - << " n_threads: " << n_threads << ",\n" - << " model_path: \"" << model_path << "\",\n" - << " clip_l_path: \"" << clip_l_path << "\",\n" - << " clip_g_path: \"" << clip_g_path << "\",\n" - << " clip_vision_path: \"" << clip_vision_path << "\",\n" - << " t5xxl_path: \"" << t5xxl_path << "\",\n" - << " llm_path: \"" << llm_path << "\",\n" - << " llm_vision_path: \"" << llm_vision_path << "\",\n" - << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" - << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" - << " vae_path: \"" << vae_path << "\",\n" - << " taesd_path: \"" << taesd_path << "\",\n" - << " esrgan_path: \"" << esrgan_path << "\",\n" - << " control_net_path: \"" << control_net_path << "\",\n" - << " embedding_dir: \"" << embedding_dir << "\",\n" - << " embeddings: " << embeddings_str << "\n" - << " wtype: " << sd_type_name(wtype) << ",\n" - << " tensor_type_rules: \"" << tensor_type_rules << "\",\n" - << " lora_model_dir: \"" << lora_model_dir << "\",\n" - << " photo_maker_path: \"" << photo_maker_path << "\",\n" - << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" - << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" - << " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n" - << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n" - << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" - << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" - << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" - << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" - << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" - << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" - << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" - << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n" - << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" - << " prediction: " << sd_prediction_name(prediction) << ",\n" - << " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n" - << " vae_tiling_params: { " - << vae_tiling_params.enabled << ", " - << vae_tiling_params.tile_size_x << ", " - << vae_tiling_params.tile_size_y << ", " - << vae_tiling_params.target_overlap << ", " - << vae_tiling_params.rel_size_x << ", " - << vae_tiling_params.rel_size_y << " },\n" - << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n" - << "}"; - return oss.str(); - } - - sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { - embedding_vec.clear(); - embedding_vec.reserve(embedding_map.size()); - for (const auto& kv : embedding_map) { - sd_embedding_t item; - item.name = kv.first.c_str(); - item.path = kv.second.c_str(); - embedding_vec.emplace_back(item); - } - - sd_ctx_params_t sd_ctx_params = { - model_path.c_str(), - clip_l_path.c_str(), - clip_g_path.c_str(), - clip_vision_path.c_str(), - t5xxl_path.c_str(), - llm_path.c_str(), - llm_vision_path.c_str(), - diffusion_model_path.c_str(), - high_noise_diffusion_model_path.c_str(), - vae_path.c_str(), - taesd_path.c_str(), - control_net_path.c_str(), - lora_model_dir.c_str(), - embedding_vec.data(), - static_cast(embedding_vec.size()), - photo_maker_path.c_str(), - tensor_type_rules.c_str(), - vae_decode_only, - free_params_immediately, - n_threads, - wtype, - rng_type, - sampler_rng_type, - prediction, - lora_apply_mode, - offload_params_to_cpu, - clip_on_cpu, - control_net_cpu, - vae_on_cpu, - diffusion_flash_attn, - taesd_preview, - diffusion_conv_direct, - vae_conv_direct, - force_sdxl_vae_conv_scale, - chroma_use_dit_mask, - chroma_use_t5_mask, - chroma_t5_mask_pad, - flow_shift, - }; - return sd_ctx_params; - } -}; - -template -static std::string vec_to_string(const std::vector& v) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < v.size(); i++) { - oss << v[i]; - if (i + 1 < v.size()) - oss << ", "; - } - oss << "]"; - return oss.str(); -} - -static std::string vec_str_to_string(const std::vector& v) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < v.size(); i++) { - oss << "\"" << v[i] << "\""; - if (i + 1 < v.size()) - oss << ", "; - } - oss << "]"; - return oss.str(); -} - -static bool is_absolute_path(const std::string& p) { -#ifdef _WIN32 - // Windows: C:/path or C:\path - return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':'; -#else - return !p.empty() && p[0] == '/'; -#endif -} - -struct SDGenerationParams { - std::string prompt; - std::string negative_prompt; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; - std::string init_image_path; - std::string end_image_path; - std::string mask_image_path; - std::string control_image_path; - std::vector ref_image_paths; - std::string control_video_path; - bool auto_resize_ref_image = true; - bool increase_ref_index = false; - - std::vector skip_layers = {7, 8, 9}; - sd_sample_params_t sample_params; - - std::vector high_noise_skip_layers = {7, 8, 9}; - sd_sample_params_t high_noise_sample_params; - - std::string cache_mode; - std::string cache_option; - std::string cache_preset; - std::string scm_mask; - bool scm_policy_dynamic = true; - sd_cache_params_t cache_params{}; - - float moe_boundary = 0.875f; - int video_frames = 1; - int fps = 16; - float vace_strength = 1.f; - - float strength = 0.75f; - float control_strength = 0.9f; - - int64_t seed = 42; - - // Photo Maker - std::string pm_id_images_dir; - std::string pm_id_embed_path; - float pm_style_strength = 20.f; - - int upscale_repeats = 1; - - std::map lora_map; - std::map high_noise_lora_map; - std::vector lora_vec; - - SDGenerationParams() { - sd_sample_params_init(&sample_params); - sd_sample_params_init(&high_noise_sample_params); - } - - ArgOptions get_options() { - ArgOptions options; - options.string_options = { - {"-p", - "--prompt", - "the prompt to render", - &prompt}, - {"-n", - "--negative-prompt", - "the negative prompt (default: \"\")", - &negative_prompt}, - {"-i", - "--init-img", - "path to the init image", - &init_image_path}, - {"", - "--end-img", - "path to the end image, required by flf2v", - &end_image_path}, - {"", - "--mask", - "path to the mask image", - &mask_image_path}, - {"", - "--control-image", - "path to control image, control net", - &control_image_path}, - {"", - "--control-video", - "path to control video frames, It must be a directory path. The video frames inside should be stored as images in " - "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images " - "such as 00.png, 01.png, ... etc.", - &control_video_path}, - {"", - "--pm-id-images-dir", - "path to PHOTOMAKER input id images dir", - &pm_id_images_dir}, - {"", - "--pm-id-embed-path", - "path to PHOTOMAKER v2 id embed", - &pm_id_embed_path}, - }; - - options.int_options = { - {"-H", - "--height", - "image height, in pixel space (default: 512)", - &height}, - {"-W", - "--width", - "image width, in pixel space (default: 512)", - &width}, - {"", - "--steps", - "number of sample steps (default: 20)", - &sample_params.sample_steps}, - {"", - "--high-noise-steps", - "(high noise) number of sample steps (default: -1 = auto)", - &high_noise_sample_params.sample_steps}, - {"", - "--clip-skip", - "ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). " - "<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x", - &clip_skip}, - {"-b", - "--batch-count", - "batch count", - &batch_count}, - {"", - "--video-frames", - "video frames (default: 1)", - &video_frames}, - {"", - "--fps", - "fps (default: 24)", - &fps}, - {"", - "--timestep-shift", - "shift timestep for NitroFusion models (default: 0). " - "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant", - &sample_params.shifted_timestep}, - {"", - "--upscale-repeats", - "Run the ESRGAN upscaler this many times (default: 1)", - &upscale_repeats}, - }; - - options.float_options = { - {"", - "--cfg-scale", - "unconditional guidance scale: (default: 7.0)", - &sample_params.guidance.txt_cfg}, - {"", - "--img-cfg-scale", - "image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)", - &sample_params.guidance.img_cfg}, - {"", - "--guidance", - "distilled guidance scale for models with guidance input (default: 3.5)", - &sample_params.guidance.distilled_guidance}, - {"", - "--slg-scale", - "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 medium", - &sample_params.guidance.slg.scale}, - {"", - "--skip-layer-start", - "SLG enabling point (default: 0.01)", - &sample_params.guidance.slg.layer_start}, - {"", - "--skip-layer-end", - "SLG disabling point (default: 0.2)", - &sample_params.guidance.slg.layer_end}, - {"", - "--eta", - "eta in DDIM, only for DDIM and TCD (default: 0)", - &sample_params.eta}, - {"", - "--high-noise-cfg-scale", - "(high noise) unconditional guidance scale: (default: 7.0)", - &high_noise_sample_params.guidance.txt_cfg}, - {"", - "--high-noise-img-cfg-scale", - "(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)", - &high_noise_sample_params.guidance.img_cfg}, - {"", - "--high-noise-guidance", - "(high noise) distilled guidance scale for models with guidance input (default: 3.5)", - &high_noise_sample_params.guidance.distilled_guidance}, - {"", - "--high-noise-slg-scale", - "(high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)", - &high_noise_sample_params.guidance.slg.scale}, - {"", - "--high-noise-skip-layer-start", - "(high noise) SLG enabling point (default: 0.01)", - &high_noise_sample_params.guidance.slg.layer_start}, - {"", - "--high-noise-skip-layer-end", - "(high noise) SLG disabling point (default: 0.2)", - &high_noise_sample_params.guidance.slg.layer_end}, - {"", - "--high-noise-eta", - "(high noise) eta in DDIM, only for DDIM and TCD (default: 0)", - &high_noise_sample_params.eta}, - {"", - "--strength", - "strength for noising/unnoising (default: 0.75)", - &strength}, - {"", - "--pm-style-strength", - "", - &pm_style_strength}, - {"", - "--control-strength", - "strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image", - &control_strength}, - {"", - "--moe-boundary", - "timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1", - &moe_boundary}, - {"", - "--vace-strength", - "wan vace strength", - &vace_strength}, - }; - - options.bool_options = { - {"", - "--increase-ref-index", - "automatically increase the indices of references images based on the order they are listed (starting with 1).", - true, - &increase_ref_index}, - {"", - "--disable-auto-resize-ref-image", - "disable auto resize of ref images", - false, - &auto_resize_ref_image}, - }; - - auto on_seed_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - seed = std::stoll(argv[index]); - return 1; - }; - - auto on_sample_method_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sample_params.sample_method = str_to_sample_method(arg); - if (sample_params.sample_method == SAMPLE_METHOD_COUNT) { - fprintf(stderr, "error: invalid sample method %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - high_noise_sample_params.sample_method = str_to_sample_method(arg); - if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { - fprintf(stderr, "error: invalid high noise sample method %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_scheduler_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sample_params.scheduler = str_to_scheduler(arg); - if (sample_params.scheduler == SCHEDULER_COUNT) { - fprintf(stderr, "error: invalid scheduler %s\n", - arg); - return -1; - } - return 1; - }; - - auto on_skip_layers_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string layers_str = argv[index]; - if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { - return -1; - } - - layers_str = layers_str.substr(1, layers_str.size() - 2); - - std::regex regex("[, ]+"); - std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); - std::sregex_token_iterator end; - std::vector tokens(iter, end); - std::vector layers; - for (const auto& token : tokens) { - try { - layers.push_back(std::stoi(token)); - } catch (const std::invalid_argument&) { - return -1; - } - } - skip_layers = layers; - return 1; - }; + bool preview_noisy = false; + bool color = false; - auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string layers_str = argv[index]; - if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { - return -1; - } + bool normal_exit = false; - layers_str = layers_str.substr(1, layers_str.size() - 2); - - std::regex regex("[, ]+"); - std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); - std::sregex_token_iterator end; - std::vector tokens(iter, end); - std::vector layers; - for (const auto& token : tokens) { - try { - layers.push_back(std::stoi(token)); - } catch (const std::invalid_argument&) { - return -1; - } - } - high_noise_skip_layers = layers; - return 1; - }; + ArgOptions get_options() { + ArgOptions options; - auto on_ref_image_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - ref_image_paths.push_back(argv[index]); - return 1; + options.string_options = { + {"-o", + "--output", + "path to write result image to (default: ./output.png)", + &output_path}, + {"", + "--preview-path", + "path to write preview image to (default: ./preview.png)", + &preview_path}, }; - auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - cache_mode = argv_to_utf8(index, argv); - if (cache_mode != "easycache" && cache_mode != "ucache" && - cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") { - fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str()); - return -1; - } - return 1; + options.int_options = { + {"", + "--preview-interval", + "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)", + &preview_interval}, }; - auto on_cache_option_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - cache_option = argv_to_utf8(index, argv); - return 1; - }; + options.bool_options = { + {"", + "--canny", + "apply canny preprocessor (edge detection)", + true, &canny_preprocess}, + {"-v", + "--verbose", + "print extra info", + true, &verbose}, + {"", + "--color", + "colors the logging tags according to level", + true, &color}, + {"", + "--taesd-preview-only", + std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")", + true, &taesd_preview}, + {"", + "--preview-noisy", + "enables previewing noisy inputs of the models rather than the denoised outputs", + true, &preview_noisy}, - auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - scm_mask = argv_to_utf8(index, argv); - return 1; }; - auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { + auto on_mode_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; } - std::string policy = argv_to_utf8(index, argv); - if (policy == "dynamic") { - scm_policy_dynamic = true; - } else if (policy == "static") { - scm_policy_dynamic = false; - } else { - fprintf(stderr, "error: invalid SCM policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); - return -1; + const char* mode_c_str = argv[index]; + if (mode_c_str != nullptr) { + int mode_found = -1; + for (int i = 0; i < MODE_COUNT; i++) { + if (!strcmp(mode_c_str, modes_str[i])) { + mode_found = i; + } + } + if (mode_found == -1) { + LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", + mode_c_str, SD_ALL_MODES_STR); + exit(1); + } + mode = (SDMode)mode_found; } return 1; }; - auto on_cache_preset_arg = [&](int argc, const char** argv, int index) { + auto on_preview_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; } - cache_preset = argv_to_utf8(index, argv); - if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" && - cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" && - cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" && - cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") { - fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str()); + const char* preview = argv[index]; + int preview_found = -1; + for (int m = 0; m < PREVIEW_COUNT; m++) { + if (!strcmp(preview, previews_str[m])) { + preview_found = m; + } + } + if (preview_found == -1) { + LOG_ERROR("error: preview method %s", preview); return -1; } + preview_method = (preview_t)preview_found; return 1; }; + auto on_help_arg = [&](int argc, const char** argv, int index) { + normal_exit = true; + return -1; + }; + options.manual_options = { - {"-s", - "--seed", - "RNG seed (default: 42, use random seed for < 0)", - on_seed_arg}, - {"", - "--sampling-method", - "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] " - "(default: euler for Flux/SD3/Wan, euler_a otherwise)", - on_sample_method_arg}, - {"", - "--high-noise-sampling-method", - "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" - " default: euler for Flux/SD3/Wan, euler_a otherwise", - on_high_noise_sample_method_arg}, - {"", - "--scheduler", - "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", - on_scheduler_arg}, - {"", - "--skip-layers", - "layers to skip for SLG steps (default: [7,8,9])", - on_skip_layers_arg}, - {"", - "--high-noise-skip-layers", - "(high noise) layers to skip for SLG steps (default: [7,8,9])", - on_high_noise_skip_layers_arg}, - {"-r", - "--ref-image", - "reference image for Flux Kontext models (can be used multiple times)", - on_ref_image_arg}, - {"", - "--cache-mode", - "caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", - on_cache_mode_arg}, - {"", - "--cache-option", - "named cache params (key=value format, comma-separated):\n - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=\n - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=\n Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", - on_cache_option_arg}, - {"", - "--scm-mask", - "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", - on_scm_mask_arg}, - {"", - "--scm-policy", - "SCM policy: 'dynamic' (check threshold, default) or 'static' (use cache without checking)", - on_scm_policy_arg}, + {"-M", + "--mode", + "run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen", + on_mode_arg}, {"", - "--cache-preset", - "Cache-DIT preset: 'slow'/'s' (~2.7x), 'medium'/'m' (~3.2x), 'fast'/'f' (~5.7x), 'ultra'/'u' (~7.4x). Sets SCM mask + threshold + warmup automatically", - on_cache_preset_arg}, - + "--preview", + std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")", + on_preview_arg}, + {"-h", + "--help", + "show this help message and exit", + on_help_arg}, }; return options; - } - - void extract_and_remove_lora(const std::string& lora_model_dir) { - static const std::regex re(R"(]+):([^>]+)>)"); - static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; - std::smatch m; - - std::string tmp = prompt; - - while (std::regex_search(tmp, m, re)) { - std::string raw_path = m[1].str(); - const std::string raw_mul = m[2].str(); - - float mul = 0.f; - try { - mul = std::stof(raw_mul); - } catch (...) { - tmp = m.suffix().str(); - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - continue; - } - - bool is_high_noise = false; - static const std::string prefix = "|high_noise|"; - if (raw_path.rfind(prefix, 0) == 0) { - raw_path.erase(0, prefix.size()); - is_high_noise = true; - } - - fs::path final_path; - if (is_absolute_path(raw_path)) { - final_path = raw_path; - } else { - final_path = fs::path(lora_model_dir) / raw_path; - } - if (!fs::exists(final_path)) { - bool found = false; - for (const auto& ext : valid_ext) { - fs::path try_path = final_path; - try_path += ext; - if (fs::exists(try_path)) { - final_path = try_path; - found = true; - break; - } - } - if (!found) { - printf("can not found lora %s\n", final_path.lexically_normal().string().c_str()); - tmp = m.suffix().str(); - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - continue; - } - } - - const std::string key = final_path.lexically_normal().string(); - - if (is_high_noise) - high_noise_lora_map[key] += mul; - else - lora_map[key] += mul; - - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - - tmp = m.suffix().str(); - } - - for (const auto& kv : lora_map) { - sd_lora_t item; - item.is_high_noise = false; - item.path = kv.first.c_str(); - item.multiplier = kv.second; - lora_vec.emplace_back(item); - } - - for (const auto& kv : high_noise_lora_map) { - sd_lora_t item; - item.is_high_noise = true; - item.path = kv.first.c_str(); - item.multiplier = kv.second; - lora_vec.emplace_back(item); - } - } - - bool process_and_check(SDMode mode, const std::string& lora_model_dir) { - if (width <= 0) { - fprintf(stderr, "error: the width must be greater than 0\n"); - return false; - } - - if (height <= 0) { - fprintf(stderr, "error: the height must be greater than 0\n"); - return false; - } - - if (sample_params.sample_steps <= 0) { - fprintf(stderr, "error: the sample_steps must be greater than 0\n"); - return false; - } - - if (high_noise_sample_params.sample_steps <= 0) { - high_noise_sample_params.sample_steps = -1; - } - - if (strength < 0.f || strength > 1.f) { - fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n"); - return false; - } - - cache_params.mode = SD_CACHE_DISABLED; - - if (!cache_mode.empty()) { - auto trim = [](std::string& s) { - const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { s.clear(); return; } - auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); - }; - - auto parse_named_params = [&](const std::string& opt_str) -> bool { - std::stringstream ss(opt_str); - std::string token; - while (std::getline(ss, token, ',')) { - trim(token); - if (token.empty()) continue; - - size_t eq_pos = token.find('='); - if (eq_pos == std::string::npos) { - fprintf(stderr, "error: invalid named parameter '%s', expected key=value\n", token.c_str()); - return false; - } - - std::string key = token.substr(0, eq_pos); - std::string val = token.substr(eq_pos + 1); - trim(key); - trim(val); - - if (key.empty() || val.empty()) { - fprintf(stderr, "error: invalid named parameter '%s'\n", token.c_str()); - return false; - } - - try { - if (key == "threshold") { - if (cache_mode == "easycache" || cache_mode == "ucache") { - cache_params.reuse_threshold = std::stof(val); - } else { - cache_params.residual_diff_threshold = std::stof(val); - } - } else if (key == "start") { - cache_params.start_percent = std::stof(val); - } else if (key == "end") { - cache_params.end_percent = std::stof(val); - } else if (key == "decay") { - cache_params.error_decay_rate = std::stof(val); - } else if (key == "relative") { - cache_params.use_relative_threshold = (std::stof(val) != 0.0f); - } else if (key == "reset") { - cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); - } else if (key == "Fn" || key == "fn") { - cache_params.Fn_compute_blocks = std::stoi(val); - } else if (key == "Bn" || key == "bn") { - cache_params.Bn_compute_blocks = std::stoi(val); - } else if (key == "warmup") { - cache_params.max_warmup_steps = std::stoi(val); - } else { - fprintf(stderr, "error: unknown cache parameter '%s'\n", key.c_str()); - return false; - } - } catch (const std::exception&) { - fprintf(stderr, "error: invalid value '%s' for parameter '%s'\n", val.c_str(), key.c_str()); - return false; - } - } - return true; - }; - - if (cache_mode == "easycache") { - cache_params.mode = SD_CACHE_EASYCACHE; - cache_params.reuse_threshold = 0.2f; - cache_params.start_percent = 0.15f; - cache_params.end_percent = 0.95f; - cache_params.error_decay_rate = 1.0f; - cache_params.use_relative_threshold = true; - cache_params.reset_error_on_compute = true; - } else if (cache_mode == "ucache") { - cache_params.mode = SD_CACHE_UCACHE; - cache_params.reuse_threshold = 1.0f; - cache_params.start_percent = 0.15f; - cache_params.end_percent = 0.95f; - cache_params.error_decay_rate = 1.0f; - cache_params.use_relative_threshold = true; - cache_params.reset_error_on_compute = true; - } else if (cache_mode == "dbcache") { - cache_params.mode = SD_CACHE_DBCACHE; - cache_params.Fn_compute_blocks = 8; - cache_params.Bn_compute_blocks = 0; - cache_params.residual_diff_threshold = 0.08f; - cache_params.max_warmup_steps = 8; - } else if (cache_mode == "taylorseer") { - cache_params.mode = SD_CACHE_TAYLORSEER; - cache_params.Fn_compute_blocks = 8; - cache_params.Bn_compute_blocks = 0; - cache_params.residual_diff_threshold = 0.08f; - cache_params.max_warmup_steps = 8; - } else if (cache_mode == "cache-dit") { - cache_params.mode = SD_CACHE_CACHE_DIT; - cache_params.Fn_compute_blocks = 8; - cache_params.Bn_compute_blocks = 0; - cache_params.residual_diff_threshold = 0.08f; - cache_params.max_warmup_steps = 8; - } else { - fprintf(stderr, "error: unknown cache mode '%s'\n", cache_mode.c_str()); - return false; - } - - if (!cache_option.empty()) { - if (!parse_named_params(cache_option)) { - return false; - } - } - - if (cache_mode == "easycache" || cache_mode == "ucache") { - if (cache_params.reuse_threshold < 0.0f) { - fprintf(stderr, "error: cache threshold must be non-negative\n"); - return false; - } - if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || - cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || - cache_params.start_percent >= cache_params.end_percent) { - fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); - return false; - } - } - } - - if (cache_params.mode == SD_CACHE_DBCACHE || - cache_params.mode == SD_CACHE_TAYLORSEER || - cache_params.mode == SD_CACHE_CACHE_DIT) { - - if (!cache_preset.empty()) { - cache_params.Fn_compute_blocks = get_preset_Fn(cache_preset); - cache_params.Bn_compute_blocks = get_preset_Bn(cache_preset); - cache_params.residual_diff_threshold = get_preset_threshold(cache_preset); - cache_params.max_warmup_steps = get_preset_warmup(cache_preset); - - if (scm_mask.empty()) { - int total_steps = sample_params.sample_steps; - std::vector mask = get_scm_preset(cache_preset, total_steps); - std::ostringstream oss; - for (size_t i = 0; i < mask.size(); i++) { - if (i > 0) oss << ","; - oss << mask[i]; - } - scm_mask = oss.str(); - } - } - - if (!scm_mask.empty()) { - cache_params.scm_mask = scm_mask.c_str(); - } - cache_params.scm_policy_dynamic = scm_policy_dynamic; - } - - sample_params.guidance.slg.layers = skip_layers.data(); - sample_params.guidance.slg.layer_count = skip_layers.size(); - high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data(); - high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); - - if (mode == VID_GEN && video_frames <= 0) { - return false; - } - - if (mode == VID_GEN && fps <= 0) { - return false; - } - - if (sample_params.shifted_timestep < 0 || sample_params.shifted_timestep > 1000) { - return false; - } + }; - if (upscale_repeats < 1) { + bool process_and_check() { + if (output_path.length() == 0) { + LOG_ERROR("error: the following arguments are required: output_path"); return false; } - if (mode == UPSCALE) { - if (init_image_path.length() == 0) { - fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n"); - return false; + if (mode == CONVERT) { + if (output_path == "output.png") { + output_path = "output.gguf"; } } - - if (seed < 0) { - srand((int)time(nullptr)); - seed = rand(); - } - - extract_and_remove_lora(lora_model_dir); - return true; } std::string to_string() const { - char* sample_params_str = sd_sample_params_to_str(&sample_params); - char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params); - - std::ostringstream lora_ss; - lora_ss << "{\n"; - for (auto it = lora_map.begin(); it != lora_map.end(); ++it) { - lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != lora_map.end()) { - lora_ss << ","; - } - lora_ss << "\n"; - } - lora_ss << " }"; - std::string loras_str = lora_ss.str(); - - lora_ss = std::ostringstream(); - ; - lora_ss << "{\n"; - for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) { - lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != high_noise_lora_map.end()) { - lora_ss << ","; - } - lora_ss << "\n"; - } - lora_ss << " }"; - std::string high_noise_loras_str = lora_ss.str(); - std::ostringstream oss; - oss << "SDGenerationParams {\n" - << " loras: \"" << loras_str << "\",\n" - << " high_noise_loras: \"" << high_noise_loras_str << "\",\n" - << " prompt: \"" << prompt << "\",\n" - << " negative_prompt: \"" << negative_prompt << "\",\n" - << " clip_skip: " << clip_skip << ",\n" - << " width: " << width << ",\n" - << " height: " << height << ",\n" - << " batch_count: " << batch_count << ",\n" - << " init_image_path: \"" << init_image_path << "\",\n" - << " end_image_path: \"" << end_image_path << "\",\n" - << " mask_image_path: \"" << mask_image_path << "\",\n" - << " control_image_path: \"" << control_image_path << "\",\n" - << " ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n" - << " control_video_path: \"" << control_video_path << "\",\n" - << " auto_resize_ref_image: " << (auto_resize_ref_image ? "true" : "false") << ",\n" - << " increase_ref_index: " << (increase_ref_index ? "true" : "false") << ",\n" - << " pm_id_images_dir: \"" << pm_id_images_dir << "\",\n" - << " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n" - << " pm_style_strength: " << pm_style_strength << ",\n" - << " skip_layers: " << vec_to_string(skip_layers) << ",\n" - << " sample_params: " << sample_params_str << ",\n" - << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" - << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" - << " cache_mode: \"" << cache_mode << "\",\n" - << " cache_option: \"" << cache_option << "\",\n" - << " cache: " - << (cache_params.mode == SD_CACHE_DISABLED ? "disabled" : - (cache_params.mode == SD_CACHE_EASYCACHE ? "easycache" : "ucache")) - << " (threshold=" << cache_params.reuse_threshold - << ", start=" << cache_params.start_percent - << ", end=" << cache_params.end_percent << "),\n" - << " moe_boundary: " << moe_boundary << ",\n" - << " video_frames: " << video_frames << ",\n" - << " fps: " << fps << ",\n" - << " vace_strength: " << vace_strength << ",\n" - << " strength: " << strength << ",\n" - << " control_strength: " << control_strength << ",\n" - << " seed: " << seed << ",\n" - << " upscale_repeats: " << upscale_repeats << ",\n" + oss << "SDCliParams {\n" + << " mode: " << modes_str[mode] << ",\n" + << " output_path: \"" << output_path << "\",\n" + << " verbose: " << (verbose ? "true" : "false") << ",\n" + << " color: " << (color ? "true" : "false") << ",\n" + << " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n" + << " preview_method: " << previews_str[preview_method] << ",\n" + << " preview_interval: " << preview_interval << ",\n" + << " preview_path: \"" << preview_path << "\",\n" + << " preview_fps: " << preview_fps << ",\n" + << " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n" + << " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n" << "}"; - free(sample_params_str); - free(high_noise_sample_params_str); return oss.str(); } }; -static std::string version_string() { - return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); -} - void print_usage(int argc, const char* argv[], const std::vector& options_list) { std::cout << version_string() << "\n"; std::cout << "Usage: " << argv[0] << " [options]\n\n"; @@ -1906,7 +213,7 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP } std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { - std::string parameter_string = gen_params.prompt + "\n"; + std::string parameter_string = gen_params.prompt_with_lora + "\n"; if (gen_params.negative_prompt.size() != 0) { parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n"; } @@ -1932,7 +239,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; } parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); - if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { + if (!gen_params.custom_sigmas.empty()) { + parameter_string += ", Custom Sigmas: ["; + for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i]; + parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", "); + } + parameter_string += "]"; + } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); } parameter_string += ", "; @@ -1959,94 +274,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { log_print(level, log, cli_params->verbose, cli_params->color); } -uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { - int c = 0; - uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel); - if (image_buffer == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", image_path); - return nullptr; - } - if (c < expected_channel) { - fprintf(stderr, - "the number of channels for the input image must be >= %d," - "but got %d channels, image_path = %s\n", - expected_channel, - c, - image_path); - free(image_buffer); - return nullptr; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); - free(image_buffer); - return nullptr; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); - free(image_buffer); - return nullptr; - } - - // Resize input image ... - if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) { - float dst_aspect = (float)expected_width / (float)expected_height; - float src_aspect = (float)width / (float)height; - - int crop_x = 0, crop_y = 0; - int crop_w = width, crop_h = height; - - if (src_aspect > dst_aspect) { - crop_w = (int)(height * dst_aspect); - crop_x = (width - crop_w) / 2; - } else if (src_aspect < dst_aspect) { - crop_h = (int)(width / dst_aspect); - crop_y = (height - crop_h) / 2; - } - - if (crop_x != 0 || crop_y != 0) { - printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); - uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); - if (cropped_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for crop\n"); - free(image_buffer); - return nullptr; - } - for (int row = 0; row < crop_h; row++) { - uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; - uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; - memcpy(dst, src, crop_w * expected_channel); - } - - width = crop_w; - height = crop_h; - free(image_buffer); - image_buffer = cropped_image_buffer; - } - - printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); - int resized_height = expected_height; - int resized_width = expected_width; - - uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); - if (resized_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for resize input image\n"); - free(image_buffer); - return nullptr; - } - stbir_resize(image_buffer, width, height, 0, - resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, - expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, - STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, - STBIR_FILTER_BOX, STBIR_FILTER_BOX, - STBIR_COLORSPACE_SRGB, nullptr); - width = resized_width; - height = resized_height; - free(image_buffer); - image_buffer = resized_image_buffer; - } - return image_buffer; -} - bool load_images_from_dir(const std::string dir, std::vector& images, int expected_width = 0, @@ -2079,7 +306,7 @@ bool load_images_from_dir(const std::string dir, LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str()); int width = 0; int height = 0; - uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height); + uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height); if (image_buffer == nullptr) { LOG_ERROR("load image from '%s' failed", path.c_str()); return false; @@ -2212,7 +439,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); + init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (init_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str()); release_all_resources(); @@ -2225,7 +452,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); + end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (end_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str()); release_all_resources(); @@ -2237,7 +464,7 @@ int main(int argc, const char* argv[]) { int c = 0; int width = 0; int height = 0; - mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); + mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); if (mask_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); release_all_resources(); @@ -2256,7 +483,7 @@ int main(int argc, const char* argv[]) { if (gen_params.control_image_path.size() > 0) { int width = 0; int height = 0; - control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); + control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (control_image.data == nullptr) { LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); release_all_resources(); @@ -2277,7 +504,7 @@ int main(int argc, const char* argv[]) { for (auto& path : gen_params.ref_image_paths) { int width = 0; int height = 0; - uint8_t* image_buffer = load_image(path.c_str(), width, height); + uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height); if (image_buffer == nullptr) { LOG_ERROR("load image from '%s' failed", path.c_str()); release_all_resources(); @@ -2428,7 +655,8 @@ int main(int argc, const char* argv[]) { upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(), ctx_params.offload_params_to_cpu, ctx_params.diffusion_conv_direct, - ctx_params.n_threads); + ctx_params.n_threads, + gen_params.upscale_tile_size); if (upscaler_ctx == nullptr) { LOG_ERROR("new_upscaler_ctx failed"); @@ -2524,4 +752,4 @@ int main(int argc, const char* argv[]) { release_all_resources(); return 0; -} +} \ No newline at end of file diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 39359fbbe..5c951c075 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -432,7 +432,7 @@ int main(int argc, const char** argv) { gen_params.pm_style_strength, }, // pm_params ctx_params.vae_tiling_params, - gen_params.easycache_params, + gen_params.cache_params, }; sd_image_t* results = nullptr; @@ -645,7 +645,7 @@ int main(int argc, const char** argv) { gen_params.pm_style_strength, }, // pm_params ctx_params.vae_tiling_params, - gen_params.easycache_params, + gen_params.cache_params, }; sd_image_t* results = nullptr; From e49d1ab20d272452834180c58a1498f57aa31342 Mon Sep 17 00:00:00 2001 From: rmatif Date: Tue, 16 Dec 2025 17:20:12 +0000 Subject: [PATCH 11/12] add documentation --- docs/caching.md | 126 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 docs/caching.md diff --git a/docs/caching.md b/docs/caching.md new file mode 100644 index 000000000..7b4be3ce0 --- /dev/null +++ b/docs/caching.md @@ -0,0 +1,126 @@ +## Caching + +Caching methods accelerate diffusion inference by reusing intermediate computations when changes between steps are small. + +### Cache Modes + +| Mode | Target | Description | +|------|--------|-------------| +| `ucache` | UNET models | Condition-level caching with error tracking | +| `easycache` | DiT models | Condition-level cache | +| `dbcache` | DiT models | Block-level L1 residual threshold | +| `taylorseer` | DiT models | Taylor series approximation | +| `cache-dit` | DiT models | Combined DBCache + TaylorSeer | + +### UCache (UNET Models) + +UCache caches the residual difference (output - input) and reuses it when input changes are below threshold. + +```bash +sd-cli -m model.safetensors -p "a cat" --cache-mode ucache --cache-option "threshold=1.5" +``` + +#### Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `threshold` | Error threshold for reuse decision | 1.0 | +| `start` | Start caching at this percent of steps | 0.15 | +| `end` | Stop caching at this percent of steps | 0.95 | +| `decay` | Error decay rate (0-1) | 1.0 | +| `relative` | Scale threshold by output norm (0/1) | 1 | +| `reset` | Reset error after computing (0/1) | 1 | + +#### Reset Parameter + +The `reset` parameter controls error accumulation behavior: + +- `reset=1` (default): Resets accumulated error after each computed step. More aggressive caching, works well with most samplers. +- `reset=0`: Keeps error accumulated. More conservative, recommended for `euler_a` sampler. + +### EasyCache (DiT Models) + +Condition-level caching for DiT models. Caches and reuses outputs when input changes are below threshold. + +```bash +--cache-mode easycache --cache-option "threshold=0.3" +``` + +#### Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `threshold` | Input change threshold for reuse | 0.2 | +| `start` | Start caching at this percent of steps | 0.15 | +| `end` | Stop caching at this percent of steps | 0.95 | + +### Cache-DIT (DiT Models) + +For DiT models like FLUX and QWEN, use block-level caching modes. + +#### DBCache + +Caches blocks based on L1 residual difference threshold: + +```bash +--cache-mode dbcache --cache-option "threshold=0.25,warmup=4" +``` + +#### TaylorSeer + +Uses Taylor series approximation to predict block outputs: + +```bash +--cache-mode taylorseer +``` + +#### Cache-DIT (Combined) + +Combines DBCache and TaylorSeer: + +```bash +--cache-mode cache-dit --cache-preset fast +``` + +#### Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `Fn` | Front blocks to always compute | 8 | +| `Bn` | Back blocks to always compute | 0 | +| `threshold` | L1 residual difference threshold | 0.08 | +| `warmup` | Steps before caching starts | 8 | + +#### Presets + +Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`). + +```bash +--cache-mode cache-dit --cache-preset fast +``` + +#### SCM Options + +Steps Computation Mask controls which steps can be cached: + +```bash +--scm-mask "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1" +``` + +Mask values: `1` = compute, `0` = can cache. + +| Policy | Description | +|--------|-------------| +| `dynamic` | Check threshold before caching | +| `static` | Always cache on cacheable steps | + +```bash +--scm-policy dynamic +``` + +### Performance Tips + +- Start with default thresholds and adjust based on output quality +- Lower threshold = better quality, less speedup +- Higher threshold = more speedup, potential quality loss +- More steps generally means more caching opportunities From d26f06afcc377494dfa93428ca48ba78795d14a8 Mon Sep 17 00:00:00 2001 From: rmatif Date: Tue, 16 Dec 2025 17:27:43 +0000 Subject: [PATCH 12/12] cleanup --- examples/cli/README.md | 6 +++--- examples/common/common.hpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/cli/README.md b/examples/cli/README.md index e2561615e..f73dfb20a 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -126,11 +126,11 @@ Generation Options: --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) - --cache-mode caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level) + --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level) --cache-option named cache params (key=value format, comma-separated): - - easycache/ucache: threshold=,start=,end=,decay=,relative= + - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset= - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup= - Examples: "threshold=0.25" or "Fn=12,threshold=0.30,warmup=4" + Examples: "threshold=0.25" or "threshold=1.5,reset=0" --cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u' --scm-mask SCM steps mask: comma-separated 0/1 (1=compute, 0=can cache) --scm-policy SCM policy: 'dynamic' (default) or 'static' diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 9ea14153b..50d000208 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1461,7 +1461,7 @@ struct SDGenerationParams { on_ref_image_arg}, {"", "--cache-mode", - "caching method: 'easycache'/'ucache' (legacy), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", + "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", on_cache_mode_arg}, {"", "--cache-option",