diff --git a/cache_dit.hpp b/cache_dit.hpp new file mode 100644 index 000000000..84002da1c --- /dev/null +++ b/cache_dit.hpp @@ -0,0 +1,922 @@ +#ifndef __CACHE_DIT_HPP__ +#define __CACHE_DIT_HPP__ + +#include +#include +#include +#include +#include +#include + +#include "ggml_extend.hpp" + +struct DBCacheConfig { + bool enabled = false; + int Fn_compute_blocks = 8; + int Bn_compute_blocks = 0; + float residual_diff_threshold = 0.08f; + int max_warmup_steps = 8; + int max_cached_steps = -1; + int max_continuous_cached_steps = -1; + float max_accumulated_residual_diff = -1.0f; + std::vector steps_computation_mask; + bool scm_policy_dynamic = true; +}; + +struct TaylorSeerConfig { + bool enabled = false; + int n_derivatives = 1; + int max_warmup_steps = 2; + int skip_interval_steps = 1; +}; + +struct CacheDitConfig { + DBCacheConfig dbcache; + TaylorSeerConfig taylorseer; + int double_Fn_blocks = -1; + int double_Bn_blocks = -1; + int single_Fn_blocks = -1; + int single_Bn_blocks = -1; +}; + +struct TaylorSeerState { + int n_derivatives = 1; + int current_step = -1; + int last_computed_step = -1; + std::vector> dY_prev; + std::vector> dY_current; + + void init(int n_deriv, size_t hidden_size) { + n_derivatives = n_deriv; + int order = n_derivatives + 1; + dY_prev.resize(order); + dY_current.resize(order); + for (int i = 0; i < order; i++) { + dY_prev[i].clear(); + dY_current[i].clear(); + } + current_step = -1; + last_computed_step = -1; + } + + void reset() { + for (auto& v : dY_prev) v.clear(); + for (auto& v : dY_current) v.clear(); + current_step = -1; + last_computed_step = -1; + } + + bool can_approximate() const { + return last_computed_step >= n_derivatives && !dY_prev.empty() && !dY_prev[0].empty(); + } + + void update_derivatives(const float* Y, size_t size, int step) { + int order = n_derivatives + 1; + dY_prev = dY_current; + dY_current[0].resize(size); + for (size_t i = 0; i < size; i++) { + dY_current[0][i] = Y[i]; + } + + int window = step - last_computed_step; + if (window <= 0) window = 1; + + for (int d = 0; d < n_derivatives; d++) { + if (!dY_prev[d].empty() && dY_prev[d].size() == size) { + dY_current[d + 1].resize(size); + for (size_t i = 0; i < size; i++) { + dY_current[d + 1][i] = (dY_current[d][i] - dY_prev[d][i]) / static_cast(window); + } + } else { + dY_current[d + 1].clear(); + } + } + + current_step = step; + last_computed_step = step; + } + + void approximate(float* output, size_t size, int target_step) const { + if (!can_approximate() || dY_prev[0].size() != size) { + return; + } + + int elapsed = target_step - last_computed_step; + if (elapsed <= 0) elapsed = 1; + + std::fill(output, output + size, 0.0f); + float factorial = 1.0f; + int order = static_cast(dY_prev.size()); + + for (int o = 0; o < order; o++) { + if (dY_prev[o].empty() || dY_prev[o].size() != size) continue; + if (o > 0) factorial *= static_cast(o); + float coeff = std::pow(static_cast(elapsed), o) / factorial; + for (size_t i = 0; i < size; i++) { + output[i] += coeff * dY_prev[o][i]; + } + } + } +}; + +struct BlockCacheEntry { + std::vector residual_img; + std::vector residual_txt; + std::vector residual; + std::vector prev_img; + std::vector prev_txt; + std::vector prev_output; + bool has_prev = false; +}; + +struct CacheDitState { + CacheDitConfig config; + bool initialized = false; + + int total_double_blocks = 0; + int total_single_blocks = 0; + size_t hidden_size = 0; + + int current_step = -1; + int total_steps = 0; + int warmup_remaining = 0; + std::vector cached_steps; + int 
continuous_cached_steps = 0; + float accumulated_residual_diff = 0.0f; + + std::vector double_block_cache; + std::vector single_block_cache; + + std::vector Fn_residual_img; + std::vector Fn_residual_txt; + std::vector prev_Fn_residual_img; + std::vector prev_Fn_residual_txt; + bool has_prev_Fn_residual = false; + + std::vector Bn_buffer_img; + std::vector Bn_buffer_txt; + std::vector Bn_buffer; + bool has_Bn_buffer = false; + + TaylorSeerState taylor_state; + + bool can_cache_this_step = false; + bool is_caching_this_step = false; + + int total_blocks_computed = 0; + int total_blocks_cached = 0; + + void init(const CacheDitConfig& cfg, int num_double_blocks, int num_single_blocks, size_t h_size) { + config = cfg; + total_double_blocks = num_double_blocks; + total_single_blocks = num_single_blocks; + hidden_size = h_size; + + initialized = cfg.dbcache.enabled || cfg.taylorseer.enabled; + + if (!initialized) return; + + warmup_remaining = cfg.dbcache.max_warmup_steps; + double_block_cache.resize(total_double_blocks); + single_block_cache.resize(total_single_blocks); + + if (cfg.taylorseer.enabled) { + taylor_state.init(cfg.taylorseer.n_derivatives, h_size); + } + + reset_runtime(); + } + + void reset_runtime() { + current_step = -1; + total_steps = 0; + warmup_remaining = config.dbcache.max_warmup_steps; + cached_steps.clear(); + continuous_cached_steps = 0; + accumulated_residual_diff = 0.0f; + + for (auto& entry : double_block_cache) { + entry.residual_img.clear(); + entry.residual_txt.clear(); + entry.prev_img.clear(); + entry.prev_txt.clear(); + entry.has_prev = false; + } + + for (auto& entry : single_block_cache) { + entry.residual.clear(); + entry.prev_output.clear(); + entry.has_prev = false; + } + + Fn_residual_img.clear(); + Fn_residual_txt.clear(); + prev_Fn_residual_img.clear(); + prev_Fn_residual_txt.clear(); + has_prev_Fn_residual = false; + + Bn_buffer_img.clear(); + Bn_buffer_txt.clear(); + Bn_buffer.clear(); + has_Bn_buffer = false; + + taylor_state.reset(); + + can_cache_this_step = false; + is_caching_this_step = false; + + total_blocks_computed = 0; + total_blocks_cached = 0; + } + + bool enabled() const { + return initialized && (config.dbcache.enabled || config.taylorseer.enabled); + } + + void begin_step(int step_index, float sigma = 0.0f) { + if (!enabled()) return; + if (step_index == current_step) return; + + current_step = step_index; + total_steps++; + + bool in_warmup = warmup_remaining > 0; + if (in_warmup) { + warmup_remaining--; + } + + bool scm_allows_cache = true; + if (!config.dbcache.steps_computation_mask.empty()) { + if (step_index < static_cast(config.dbcache.steps_computation_mask.size())) { + scm_allows_cache = (config.dbcache.steps_computation_mask[step_index] == 0); + if (!config.dbcache.scm_policy_dynamic && scm_allows_cache) { + can_cache_this_step = true; + is_caching_this_step = false; + return; + } + } + } + + bool max_cached_ok = (config.dbcache.max_cached_steps < 0) || + (static_cast(cached_steps.size()) < config.dbcache.max_cached_steps); + + bool max_cont_ok = (config.dbcache.max_continuous_cached_steps < 0) || + (continuous_cached_steps < config.dbcache.max_continuous_cached_steps); + + bool accum_ok = (config.dbcache.max_accumulated_residual_diff < 0.0f) || + (accumulated_residual_diff < config.dbcache.max_accumulated_residual_diff); + + can_cache_this_step = !in_warmup && scm_allows_cache && max_cached_ok && max_cont_ok && accum_ok && has_prev_Fn_residual; + is_caching_this_step = false; + } + + void end_step(bool was_cached) { + if 
(was_cached) { + cached_steps.push_back(current_step); + continuous_cached_steps++; + } else { + continuous_cached_steps = 0; + } + } + + static float calculate_residual_diff(const float* prev, const float* curr, size_t size) { + if (size == 0) return 0.0f; + + float sum_diff = 0.0f; + float sum_abs = 0.0f; + + for (size_t i = 0; i < size; i++) { + sum_diff += std::fabs(prev[i] - curr[i]); + sum_abs += std::fabs(prev[i]); + } + + return sum_diff / (sum_abs + 1e-6f); + } + + static float calculate_residual_diff(const std::vector& prev, const std::vector& curr) { + if (prev.size() != curr.size() || prev.empty()) return 1.0f; + return calculate_residual_diff(prev.data(), curr.data(), prev.size()); + } + + int get_double_Fn_blocks() const { + return (config.double_Fn_blocks >= 0) ? config.double_Fn_blocks : config.dbcache.Fn_compute_blocks; + } + + int get_double_Bn_blocks() const { + return (config.double_Bn_blocks >= 0) ? config.double_Bn_blocks : config.dbcache.Bn_compute_blocks; + } + + int get_single_Fn_blocks() const { + return (config.single_Fn_blocks >= 0) ? config.single_Fn_blocks : config.dbcache.Fn_compute_blocks; + } + + int get_single_Bn_blocks() const { + return (config.single_Bn_blocks >= 0) ? config.single_Bn_blocks : config.dbcache.Bn_compute_blocks; + } + + bool is_Fn_double_block(int block_idx) const { + return block_idx < get_double_Fn_blocks(); + } + + bool is_Bn_double_block(int block_idx) const { + int Bn = get_double_Bn_blocks(); + return Bn > 0 && block_idx >= (total_double_blocks - Bn); + } + + bool is_Mn_double_block(int block_idx) const { + return !is_Fn_double_block(block_idx) && !is_Bn_double_block(block_idx); + } + + bool is_Fn_single_block(int block_idx) const { + return block_idx < get_single_Fn_blocks(); + } + + bool is_Bn_single_block(int block_idx) const { + int Bn = get_single_Bn_blocks(); + return Bn > 0 && block_idx >= (total_single_blocks - Bn); + } + + bool is_Mn_single_block(int block_idx) const { + return !is_Fn_single_block(block_idx) && !is_Bn_single_block(block_idx); + } + + void store_Fn_residual(const float* img, const float* txt, size_t img_size, size_t txt_size, + const float* input_img, const float* input_txt) { + Fn_residual_img.resize(img_size); + Fn_residual_txt.resize(txt_size); + + for (size_t i = 0; i < img_size; i++) { + Fn_residual_img[i] = img[i] - input_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + Fn_residual_txt[i] = txt[i] - input_txt[i]; + } + } + + bool check_cache_decision() { + if (!can_cache_this_step) { + is_caching_this_step = false; + return false; + } + + if (!has_prev_Fn_residual || prev_Fn_residual_img.empty()) { + is_caching_this_step = false; + return false; + } + + float diff_img = calculate_residual_diff(prev_Fn_residual_img, Fn_residual_img); + float diff_txt = calculate_residual_diff(prev_Fn_residual_txt, Fn_residual_txt); + float diff = (diff_img + diff_txt) / 2.0f; + + if (diff < config.dbcache.residual_diff_threshold) { + is_caching_this_step = true; + accumulated_residual_diff += diff; + return true; + } + + is_caching_this_step = false; + return false; + } + + void update_prev_Fn_residual() { + prev_Fn_residual_img = Fn_residual_img; + prev_Fn_residual_txt = Fn_residual_txt; + has_prev_Fn_residual = !prev_Fn_residual_img.empty(); + } + + void store_double_block_residual(int block_idx, const float* img, const float* txt, + size_t img_size, size_t txt_size, + const float* prev_img, const float* prev_txt) { + if (block_idx < 0 || block_idx >= static_cast(double_block_cache.size())) return; + + 
BlockCacheEntry& entry = double_block_cache[block_idx]; + + entry.residual_img.resize(img_size); + entry.residual_txt.resize(txt_size); + for (size_t i = 0; i < img_size; i++) { + entry.residual_img[i] = img[i] - prev_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + entry.residual_txt[i] = txt[i] - prev_txt[i]; + } + + entry.prev_img.resize(img_size); + entry.prev_txt.resize(txt_size); + for (size_t i = 0; i < img_size; i++) { + entry.prev_img[i] = img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + entry.prev_txt[i] = txt[i]; + } + entry.has_prev = true; + } + + void apply_double_block_cache(int block_idx, float* img, float* txt, + size_t img_size, size_t txt_size) { + if (block_idx < 0 || block_idx >= static_cast(double_block_cache.size())) return; + + const BlockCacheEntry& entry = double_block_cache[block_idx]; + if (entry.residual_img.size() != img_size || entry.residual_txt.size() != txt_size) return; + + for (size_t i = 0; i < img_size; i++) { + img[i] += entry.residual_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + txt[i] += entry.residual_txt[i]; + } + + total_blocks_cached++; + } + + void store_single_block_residual(int block_idx, const float* output, size_t size, const float* input) { + if (block_idx < 0 || block_idx >= static_cast(single_block_cache.size())) return; + + BlockCacheEntry& entry = single_block_cache[block_idx]; + + entry.residual.resize(size); + for (size_t i = 0; i < size; i++) { + entry.residual[i] = output[i] - input[i]; + } + + entry.prev_output.resize(size); + for (size_t i = 0; i < size; i++) { + entry.prev_output[i] = output[i]; + } + entry.has_prev = true; + } + + void apply_single_block_cache(int block_idx, float* output, size_t size) { + if (block_idx < 0 || block_idx >= static_cast(single_block_cache.size())) return; + + const BlockCacheEntry& entry = single_block_cache[block_idx]; + if (entry.residual.size() != size) return; + + for (size_t i = 0; i < size; i++) { + output[i] += entry.residual[i]; + } + + total_blocks_cached++; + } + + void store_Bn_buffer(const float* img, const float* txt, size_t img_size, size_t txt_size, + const float* Bn_start_img, const float* Bn_start_txt) { + Bn_buffer_img.resize(img_size); + Bn_buffer_txt.resize(txt_size); + + for (size_t i = 0; i < img_size; i++) { + Bn_buffer_img[i] = img[i] - Bn_start_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + Bn_buffer_txt[i] = txt[i] - Bn_start_txt[i]; + } + has_Bn_buffer = true; + } + + void apply_Bn_buffer(float* img, float* txt, size_t img_size, size_t txt_size) { + if (!has_Bn_buffer) return; + if (Bn_buffer_img.size() != img_size || Bn_buffer_txt.size() != txt_size) return; + + for (size_t i = 0; i < img_size; i++) { + img[i] += Bn_buffer_img[i]; + } + for (size_t i = 0; i < txt_size; i++) { + txt[i] += Bn_buffer_txt[i]; + } + } + + void taylor_update(const float* hidden_state, size_t size) { + if (!config.taylorseer.enabled) return; + taylor_state.update_derivatives(hidden_state, size, current_step); + } + + bool taylor_can_approximate() const { + return config.taylorseer.enabled && taylor_state.can_approximate(); + } + + void taylor_approximate(float* output, size_t size) { + if (!config.taylorseer.enabled) return; + taylor_state.approximate(output, size, current_step); + } + + bool should_use_taylor_this_step() const { + if (!config.taylorseer.enabled) return false; + if (current_step < config.taylorseer.max_warmup_steps) return false; + + int interval = config.taylorseer.skip_interval_steps; + if (interval <= 0) interval = 1; + + return 
(current_step % (interval + 1)) != 0; + } + + void log_metrics() const { + if (!enabled()) return; + + int total_blocks = total_blocks_computed + total_blocks_cached; + float cache_ratio = (total_blocks > 0) ? + (static_cast(total_blocks_cached) / total_blocks * 100.0f) : 0.0f; + + float step_cache_ratio = (total_steps > 0) ? + (static_cast(cached_steps.size()) / total_steps * 100.0f) : 0.0f; + + LOG_INFO("CacheDIT: steps_cached=%zu/%d (%.1f%%), blocks_cached=%d/%d (%.1f%%), accum_diff=%.4f", + cached_steps.size(), total_steps, step_cache_ratio, + total_blocks_cached, total_blocks, cache_ratio, + accumulated_residual_diff); + } + + std::string get_summary() const { + char buf[256]; + snprintf(buf, sizeof(buf), + "CacheDIT[thresh=%.2f]: cached %zu/%d steps, %d/%d blocks", + config.dbcache.residual_diff_threshold, + cached_steps.size(), total_steps, + total_blocks_cached, total_blocks_computed + total_blocks_cached); + return std::string(buf); + } +}; + +inline std::vector parse_scm_mask(const std::string& mask_str) { + std::vector mask; + if (mask_str.empty()) return mask; + + size_t pos = 0; + size_t start = 0; + while ((pos = mask_str.find(',', start)) != std::string::npos) { + std::string token = mask_str.substr(start, pos - start); + mask.push_back(std::stoi(token)); + start = pos + 1; + } + if (start < mask_str.length()) { + mask.push_back(std::stoi(mask_str.substr(start))); + } + + return mask; +} + +inline std::vector generate_scm_mask( + const std::vector& compute_bins, + const std::vector& cache_bins, + int total_steps +) { + std::vector mask; + size_t c_idx = 0, cache_idx = 0; + + while (static_cast(mask.size()) < total_steps) { + if (c_idx < compute_bins.size()) { + for (int i = 0; i < compute_bins[c_idx] && static_cast(mask.size()) < total_steps; i++) { + mask.push_back(1); + } + c_idx++; + } + if (cache_idx < cache_bins.size()) { + for (int i = 0; i < cache_bins[cache_idx] && static_cast(mask.size()) < total_steps; i++) { + mask.push_back(0); + } + cache_idx++; + } + if (c_idx >= compute_bins.size() && cache_idx >= cache_bins.size()) break; + } + + if (!mask.empty()) { + mask.back() = 1; + } + + return mask; +} + +inline std::vector get_scm_preset(const std::string& preset, int total_steps) { + struct Preset { + std::vector compute_bins; + std::vector cache_bins; + }; + + Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}}; + Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}}; + Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}}; + Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}}; + + Preset* p = nullptr; + if (preset == "slow" || preset == "s" || preset == "S") p = &slow; + else if (preset == "medium" || preset == "m" || preset == "M") p = &medium; + else if (preset == "fast" || preset == "f" || preset == "F") p = &fast; + else if (preset == "ultra" || preset == "u" || preset == "U") p = &ultra; + else return {}; + + if (total_steps != 28 && total_steps > 0) { + float scale = static_cast(total_steps) / 28.0f; + std::vector scaled_compute, scaled_cache; + + for (int v : p->compute_bins) { + scaled_compute.push_back(std::max(1, static_cast(v * scale + 0.5f))); + } + for (int v : p->cache_bins) { + scaled_cache.push_back(std::max(1, static_cast(v * scale + 0.5f))); + } + + return generate_scm_mask(scaled_compute, scaled_cache, total_steps); + } + + return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps); +} + +inline float get_preset_threshold(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 0.20f; + if (preset == 
"medium" || preset == "m" || preset == "M") return 0.25f; + if (preset == "fast" || preset == "f" || preset == "F") return 0.30f; + if (preset == "ultra" || preset == "u" || preset == "U") return 0.34f; + return 0.08f; +} + +inline int get_preset_warmup(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 8; + if (preset == "medium" || preset == "m" || preset == "M") return 6; + if (preset == "fast" || preset == "f" || preset == "F") return 6; + if (preset == "ultra" || preset == "u" || preset == "U") return 4; + return 8; +} + +inline int get_preset_Fn(const std::string& preset) { + if (preset == "slow" || preset == "s" || preset == "S") return 8; + if (preset == "medium" || preset == "m" || preset == "M") return 8; + if (preset == "fast" || preset == "f" || preset == "F") return 6; + if (preset == "ultra" || preset == "u" || preset == "U") return 4; + return 8; +} + +inline int get_preset_Bn(const std::string& preset) { + (void)preset; + return 0; +} + +inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) { + if (opts.empty()) return; + + int Fn = 8, Bn = 0, warmup = 8, max_cached = -1, max_cont = -1; + float thresh = 0.08f; + + sscanf(opts.c_str(), "%d,%d,%f,%d,%d,%d", + &Fn, &Bn, &thresh, &warmup, &max_cached, &max_cont); + + cfg.Fn_compute_blocks = Fn; + cfg.Bn_compute_blocks = Bn; + cfg.residual_diff_threshold = thresh; + cfg.max_warmup_steps = warmup; + cfg.max_cached_steps = max_cached; + cfg.max_continuous_cached_steps = max_cont; +} + +inline void parse_taylorseer_options(const std::string& opts, TaylorSeerConfig& cfg) { + if (opts.empty()) return; + + int n_deriv = 1, warmup = 2, interval = 1; + sscanf(opts.c_str(), "%d,%d,%d", &n_deriv, &warmup, &interval); + + cfg.n_derivatives = n_deriv; + cfg.max_warmup_steps = warmup; + cfg.skip_interval_steps = interval; +} + +struct CacheDitConditionState { + DBCacheConfig config; + TaylorSeerConfig taylor_config; + bool initialized = false; + + int current_step_index = -1; + bool step_active = false; + bool skip_current_step = false; + bool initial_step = true; + int warmup_remaining = 0; + std::vector cached_steps; + int continuous_cached_steps = 0; + float accumulated_residual_diff = 0.0f; + int total_steps_skipped = 0; + + const void* anchor_condition = nullptr; + + struct CacheEntry { + std::vector diff; + std::vector prev_input; + std::vector prev_output; + bool has_prev = false; + }; + std::unordered_map cache_diffs; + + TaylorSeerState taylor_state; + + float start_sigma = std::numeric_limits::max(); + float end_sigma = 0.0f; + + void reset_runtime() { + current_step_index = -1; + step_active = false; + skip_current_step = false; + initial_step = true; + warmup_remaining = config.max_warmup_steps; + cached_steps.clear(); + continuous_cached_steps = 0; + accumulated_residual_diff = 0.0f; + total_steps_skipped = 0; + anchor_condition = nullptr; + cache_diffs.clear(); + taylor_state.reset(); + } + + void init(const DBCacheConfig& dbcfg, const TaylorSeerConfig& tcfg) { + config = dbcfg; + taylor_config = tcfg; + initialized = dbcfg.enabled || tcfg.enabled; + reset_runtime(); + + if (taylor_config.enabled) { + taylor_state.init(taylor_config.n_derivatives, 0); + } + } + + void set_sigmas(const std::vector& sigmas) { + if (!initialized || sigmas.size() < 2) return; + + float start_percent = 0.15f; + float end_percent = 0.95f; + + size_t n_steps = sigmas.size() - 1; + size_t start_step = static_cast(start_percent * n_steps); + size_t end_step = static_cast(end_percent * 
n_steps);
+
+        if (start_step >= n_steps) start_step = n_steps - 1;
+        if (end_step >= n_steps) end_step = n_steps - 1;
+
+        start_sigma = sigmas[start_step];
+        end_sigma = sigmas[end_step];
+
+        if (start_sigma < end_sigma) {
+            std::swap(start_sigma, end_sigma);
+        }
+    }
+
+    bool enabled() const {
+        return initialized && (config.enabled || taylor_config.enabled);
+    }
+
+    void begin_step(int step_index, float sigma) {
+        if (!enabled()) return;
+        if (step_index == current_step_index) return;
+
+        current_step_index = step_index;
+        skip_current_step = false;
+        step_active = false;
+
+        if (sigma > start_sigma) return;
+        if (!(sigma > end_sigma)) return;
+
+        step_active = true;
+
+        if (warmup_remaining > 0) {
+            warmup_remaining--;
+            return;
+        }
+
+        if (!config.steps_computation_mask.empty()) {
+            if (step_index < static_cast<int>(config.steps_computation_mask.size())) {
+                if (config.steps_computation_mask[step_index] == 1) {
+                    return;
+                }
+            }
+        }
+
+        if (config.max_cached_steps >= 0 &&
+            static_cast<int>(cached_steps.size()) >= config.max_cached_steps) {
+            return;
+        }
+
+        if (config.max_continuous_cached_steps >= 0 &&
+            continuous_cached_steps >= config.max_continuous_cached_steps) {
+            return;
+        }
+    }
+
+    bool step_is_active() const {
+        return enabled() && step_active;
+    }
+
+    bool is_step_skipped() const {
+        return enabled() && step_active && skip_current_step;
+    }
+
+    bool has_cache(const void* cond) const {
+        auto it = cache_diffs.find(cond);
+        return it != cache_diffs.end() && !it->second.diff.empty();
+    }
+
+    void update_cache(const void* cond, const float* input, const float* output, size_t size) {
+        CacheEntry& entry = cache_diffs[cond];
+        entry.diff.resize(size);
+        for (size_t i = 0; i < size; i++) {
+            entry.diff[i] = output[i] - input[i];
+        }
+
+        entry.prev_input.resize(size);
+        entry.prev_output.resize(size);
+        for (size_t i = 0; i < size; i++) {
+            entry.prev_input[i] = input[i];
+            entry.prev_output[i] = output[i];
+        }
+        entry.has_prev = true;
+    }
+
+    void apply_cache(const void* cond, const float* input, float* output, size_t size) {
+        auto it = cache_diffs.find(cond);
+        if (it == cache_diffs.end() || it->second.diff.empty()) return;
+        if (it->second.diff.size() != size) return;
+
+        for (size_t i = 0; i < size; i++) {
+            output[i] = input[i] + it->second.diff[i];
+        }
+    }
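+
+    // Decide whether the model call for `cond` at this step can be skipped.
+    // Returns true when `output` has already been filled from the cache:
+    //  - on a step already marked as skipped, replay this condition's cached
+    //    (output - input) residual if one is available;
+    //  - otherwise only the anchor condition is tested: if the relative L1 change
+    //    of its input versus the previous computed step stays below the
+    //    Fn/Bn-adjusted residual_diff_threshold, the whole step is marked as
+    //    skipped and the cached residual is applied.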
+    bool before_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output, float sigma, int step_index) {
+        if (!enabled() || step_index < 0) return false;
+
+        if (step_index != current_step_index) {
+            begin_step(step_index, sigma);
+        }
+
+        if (!step_active) return false;
+
+        if (initial_step) {
+            anchor_condition = cond;
+            initial_step = false;
+        }
+
+        bool is_anchor = (cond == anchor_condition);
+
+        if (skip_current_step) {
+            if (has_cache(cond)) {
+                apply_cache(cond, (float*)input->data, (float*)output->data,
+                            static_cast<size_t>(ggml_nelements(output)));
+                return true;
+            }
+            return false;
+        }
+
+        if (!is_anchor) return false;
+
+        auto it = cache_diffs.find(cond);
+        if (it == cache_diffs.end() || !it->second.has_prev) return false;
+
+        size_t ne = static_cast<size_t>(ggml_nelements(input));
+        if (it->second.prev_input.size() != ne) return false;
+
+        float* input_data = (float*)input->data;
+        float diff = CacheDitState::calculate_residual_diff(
+            it->second.prev_input.data(), input_data, ne);
+
+        float effective_threshold = config.residual_diff_threshold;
+        if (config.Fn_compute_blocks > 0) {
+            float fn_confidence = 1.0f + 0.02f * (config.Fn_compute_blocks - 8);
+            fn_confidence = std::max(0.5f, std::min(2.0f, fn_confidence));
+            effective_threshold *= fn_confidence;
+        }
+        if (config.Bn_compute_blocks > 0) {
+            float bn_quality = 1.0f - 0.03f * config.Bn_compute_blocks;
+            bn_quality = std::max(0.5f, std::min(1.0f, bn_quality));
+            effective_threshold *= bn_quality;
+        }
+
+        if (diff < effective_threshold) {
+            skip_current_step = true;
+            total_steps_skipped++;
+            cached_steps.push_back(current_step_index);
+            continuous_cached_steps++;
+            accumulated_residual_diff += diff;
+            apply_cache(cond, input_data, (float*)output->data, ne);
+            return true;
+        }
+
+        continuous_cached_steps = 0;
+        return false;
+    }
+
+    void after_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output) {
+        if (!step_is_active()) return;
+
+        size_t ne = static_cast<size_t>(ggml_nelements(output));
+        update_cache(cond, (float*)input->data, (float*)output->data, ne);
+
+        if (cond == anchor_condition && taylor_config.enabled) {
+            taylor_state.update_derivatives((float*)output->data, ne, current_step_index);
+        }
+    }
+
+    void log_metrics() const {
+        if (!enabled()) return;
+
+        LOG_INFO("CacheDIT: steps_skipped=%d/%d (%.1f%%), accum_residual_diff=%.4f",
+                 total_steps_skipped,
+                 current_step_index + 1,
+                 (current_step_index > 0) ?
+                     (100.0f * total_steps_skipped / (current_step_index + 1)) : 0.0f,
+                 accumulated_residual_diff);
+    }
+};
+
+#endif
diff --git a/docs/caching.md b/docs/caching.md
new file mode 100644
index 000000000..7b4be3ce0
--- /dev/null
+++ b/docs/caching.md
@@ -0,0 +1,126 @@
+## Caching
+
+Caching methods accelerate diffusion inference by reusing intermediate computations when changes between steps are small.
+
+### Cache Modes
+
+| Mode | Target | Description |
+|------|--------|-------------|
+| `ucache` | UNET models | Condition-level caching with error tracking |
+| `easycache` | DiT models | Condition-level cache |
+| `dbcache` | DiT models | Block-level L1 residual threshold |
+| `taylorseer` | DiT models | Taylor series approximation |
+| `cache-dit` | DiT models | Combined DBCache + TaylorSeer |
+
+### UCache (UNET Models)
+
+UCache caches the residual difference (output - input) and reuses it when the input change is below the threshold.
+
+```bash
+sd-cli -m model.safetensors -p "a cat" --cache-mode ucache --cache-option "threshold=1.5"
+```
+
+#### Parameters
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `threshold` | Error threshold for reuse decision | 1.0 |
+| `start` | Start caching at this percent of steps | 0.15 |
+| `end` | Stop caching at this percent of steps | 0.95 |
+| `decay` | Error decay rate (0-1) | 1.0 |
+| `relative` | Scale threshold by output norm (0/1) | 1 |
+| `reset` | Reset error after computing (0/1) | 1 |
+
+#### Reset Parameter
+
+The `reset` parameter controls error accumulation behavior:
+
+- `reset=1` (default): resets the accumulated error after each computed step. More aggressive caching; works well with most samplers.
+- `reset=0`: keeps the error accumulated. More conservative; recommended for the `euler_a` sampler.
+
+### EasyCache (DiT Models)
+
+Condition-level caching for DiT models: the model output is cached and reused when the input change is below the threshold.
+
+```bash
+--cache-mode easycache --cache-option "threshold=0.3"
+```
+
+#### Parameters
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `threshold` | Input change threshold for reuse | 0.2 |
+| `start` | Start caching at this percent of steps | 0.15 |
+| `end` | Stop caching at this percent of steps | 0.95 |
+
+### Cache-DIT (DiT Models)
+
+For DiT models like FLUX and QWEN, use the block-level caching modes.
+
+#### DBCache
+
+Caches blocks based on an L1 residual difference threshold:
+
+```bash
+--cache-mode dbcache --cache-option "threshold=0.25,warmup=4"
+```
+
+#### TaylorSeer
+
+Uses a Taylor series approximation to predict block outputs instead of recomputing them:
+
+```bash
+--cache-mode taylorseer
+```
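+
+To make the approximation concrete, here is a minimal first-order sketch of the idea (the `taylor_predict` helper below is illustrative only, not part of the codebase): keep the last computed hidden state and a finite-difference estimate of its derivative, then extrapolate the next state instead of running the blocks again. `TaylorSeerState` in `cache_dit.hpp` follows the same scheme with a configurable derivative order.
+
+```cpp
+#include <cstdio>
+#include <vector>
+
+// Y_prev was computed `window` steps before Y_curr; extrapolate `elapsed` steps ahead.
+std::vector<float> taylor_predict(const std::vector<float>& Y_prev,
+                                  const std::vector<float>& Y_curr,
+                                  int window, int elapsed) {
+    std::vector<float> Y_next(Y_curr.size());
+    for (size_t i = 0; i < Y_curr.size(); i++) {
+        float dY  = (Y_curr[i] - Y_prev[i]) / window;  // first-derivative estimate
+        Y_next[i] = Y_curr[i] + dY * elapsed;          // Y(t + k) ~= Y(t) + k * dY
+    }
+    return Y_next;
+}
+
+int main() {
+    std::vector<float> y0 = {0.0f, 1.0f}, y1 = {0.1f, 0.8f};
+    std::vector<float> y2 = taylor_predict(y0, y1, /*window=*/1, /*elapsed=*/1);
+    printf("predicted: %.2f %.2f\n", y2[0], y2[1]);  // -> 0.20 0.60
+    return 0;
+}
+```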
+
+#### Cache-DIT (Combined)
+
+Combines DBCache and TaylorSeer:
+
+```bash
+--cache-mode cache-dit --cache-preset fast
+```
+
+#### Parameters
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `Fn` | Front blocks to always compute | 8 |
+| `Bn` | Back blocks to always compute | 0 |
+| `threshold` | L1 residual difference threshold | 0.08 |
+| `warmup` | Steps before caching starts | 8 |
+
+#### Presets
+
+Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
+
+```bash
+--cache-mode cache-dit --cache-preset fast
+```
+
+#### SCM Options
+
+The Steps Computation Mask (SCM) controls which steps may be cached:
+
+```bash
+--scm-mask "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
+```
+
+Mask values: `1` = compute, `0` = can cache.
+
+| Policy | Description |
+|--------|-------------|
+| `dynamic` | Check threshold before caching |
+| `static` | Always cache on cacheable steps |
+
+```bash
+--scm-policy dynamic
+```
+
+### Performance Tips
+
+- Start with the default thresholds and adjust based on output quality
+- Lower threshold = better quality, less speedup
+- Higher threshold = more speedup, potential quality loss
+- More steps generally means more caching opportunities
diff --git a/examples/cli/README.md b/examples/cli/README.md
index 8531b2aed..f73dfb20a 100644
--- a/examples/cli/README.md
+++ b/examples/cli/README.md
@@ -126,5 +126,12 @@ Generation Options:
   --skip-layers                      layers to skip for SLG steps (default: [7,8,9])
   --high-noise-skip-layers           (high noise) layers to skip for SLG steps (default: [7,8,9])
   -r, --ref-image                    reference image for Flux Kontext models (can be used multiple times)
-  --easycache                        enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
+  --cache-mode                       caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
+  --cache-option                     named cache params (key=value format, comma-separated):
+                                       - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=
+                                       - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=
+                                       Examples: "threshold=0.25" or "threshold=1.5,reset=0"
+  --cache-preset                     cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
+  --scm-mask                         SCM steps mask: comma-separated 0/1 (1=compute, 0=can cache)
+  --scm-policy                       SCM policy: 'dynamic' (default) or 'static'
 ```
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 42b909e4f..889cabc5d 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -610,7 +610,7 @@ int main(int argc, const char* argv[]) {
             gen_params.pm_style_strength,
         },  // pm_params
         ctx_params.vae_tiling_params,
-        gen_params.easycache_params,
+        gen_params.cache_params,
     };
 
     results = generate_image(sd_ctx, &img_gen_params);
@@ -635,7 +635,7 @@ int main(int argc, const char* 
argv[]) { gen_params.seed, gen_params.video_frames, gen_params.vace_strength, - gen_params.easycache_params, + gen_params.cache_params, }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); diff --git a/examples/common/common.hpp b/examples/common/common.hpp index f3a561367..50d000208 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -997,8 +997,12 @@ struct SDGenerationParams { std::vector custom_sigmas; - std::string easycache_option; - sd_easycache_params_t easycache_params; + std::string cache_mode; + std::string cache_option; + std::string cache_preset; + std::string scm_mask; + bool scm_policy_dynamic = true; + sd_cache_params_t cache_params{}; float moe_boundary = 0.875f; int video_frames = 1; @@ -1360,36 +1364,64 @@ struct SDGenerationParams { return 1; }; - auto on_easycache_arg = [&](int argc, const char** argv, int index) { - const std::string default_values = "0.2,0.15,0.95"; - auto looks_like_value = [](const std::string& token) { - if (token.empty()) { - return false; - } - if (token[0] != '-') { - return true; - } - if (token.size() == 1) { - return false; - } - unsigned char next = static_cast(token[1]); - return std::isdigit(next) || token[1] == '.'; - }; - - std::string option_value; - int consumed = 0; - if (index + 1 < argc) { - std::string next_arg = argv[index + 1]; - if (looks_like_value(next_arg)) { - option_value = argv_to_utf8(index + 1, argv); - consumed = 1; - } + auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; } - if (option_value.empty()) { - option_value = default_values; + cache_mode = argv_to_utf8(index, argv); + if (cache_mode != "easycache" && cache_mode != "ucache" && + cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") { + fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str()); + return -1; } - easycache_option = option_value; - return consumed; + return 1; + }; + + auto on_cache_option_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_option = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + scm_mask = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string policy = argv_to_utf8(index, argv); + if (policy == "dynamic") { + scm_policy_dynamic = true; + } else if (policy == "static") { + scm_policy_dynamic = false; + } else { + fprintf(stderr, "error: invalid scm policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); + return -1; + } + return 1; + }; + + auto on_cache_preset_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_preset = argv_to_utf8(index, argv); + if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" && + cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" && + cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" && + cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") { + fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str()); + return -1; + } + return 1; }; options.manual_options = { @@ -1428,9 +1460,25 @@ struct 
SDGenerationParams { "reference image for Flux Kontext models (can be used multiple times)", on_ref_image_arg}, {"", - "--easycache", - "enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)", - on_easycache_arg}, + "--cache-mode", + "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", + on_cache_mode_arg}, + {"", + "--cache-option", + "named cache params (key=value format, comma-separated):\n - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=\n - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=\n Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", + on_cache_option_arg}, + {"", + "--cache-preset", + "cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'", + on_cache_preset_arg}, + {"", + "--scm-mask", + "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", + on_scm_mask_arg}, + {"", + "--scm-policy", + "SCM policy: 'dynamic' (default) or 'static'", + on_scm_policy_arg}, }; @@ -1473,7 +1521,10 @@ struct SDGenerationParams { load_if_exists("prompt", prompt); load_if_exists("negative_prompt", negative_prompt); - load_if_exists("easycache_option", easycache_option); + load_if_exists("cache_mode", cache_mode); + load_if_exists("cache_option", cache_option); + load_if_exists("cache_preset", cache_preset); + load_if_exists("scm_mask", scm_mask); load_if_exists("clip_skip", clip_skip); load_if_exists("width", width); @@ -1613,57 +1664,118 @@ struct SDGenerationParams { return false; } - if (!easycache_option.empty()) { - float values[3] = {0.0f, 0.0f, 0.0f}; - std::stringstream ss(easycache_option); + sd_cache_params_init(&cache_params); + + auto parse_named_params = [&](const std::string& opt_str) -> bool { + std::stringstream ss(opt_str); std::string token; - int idx = 0; while (std::getline(ss, token, ',')) { - auto trim = [](std::string& s) { - const char* whitespace = " \t\r\n"; - auto start = s.find_first_not_of(whitespace); - if (start == std::string::npos) { - s.clear(); - return; - } - auto end = s.find_last_not_of(whitespace); - s = s.substr(start, end - start + 1); - }; - trim(token); - if (token.empty()) { - LOG_ERROR("error: invalid easycache option '%s'", easycache_option.c_str()); - return false; - } - if (idx >= 3) { - LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); + size_t eq_pos = token.find('='); + if (eq_pos == std::string::npos) { + LOG_ERROR("error: cache option '%s' missing '=' separator", token.c_str()); return false; } + std::string key = token.substr(0, eq_pos); + std::string val = token.substr(eq_pos + 1); try { - values[idx] = std::stof(token); + if (key == "threshold") { + if (cache_mode == "easycache" || cache_mode == "ucache") { + cache_params.reuse_threshold = std::stof(val); + } else { + cache_params.residual_diff_threshold = std::stof(val); + } + } else if (key == "start") { + cache_params.start_percent = std::stof(val); + } else if (key == "end") { + cache_params.end_percent = std::stof(val); + } else if (key == "decay") { + cache_params.error_decay_rate = std::stof(val); + } else if (key == "relative") { + cache_params.use_relative_threshold = (std::stof(val) != 0.0f); + } else if (key == "reset") { + cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); + } else if (key == "Fn" || key == "fn") { + cache_params.Fn_compute_blocks = std::stoi(val); + } else if (key == "Bn" || key == "bn") { + 
cache_params.Bn_compute_blocks = std::stoi(val); + } else if (key == "warmup") { + cache_params.max_warmup_steps = std::stoi(val); + } else { + LOG_ERROR("error: unknown cache parameter '%s'", key.c_str()); + return false; + } } catch (const std::exception&) { - LOG_ERROR("error: invalid easycache value '%s'", token.c_str()); + LOG_ERROR("error: invalid value '%s' for parameter '%s'", val.c_str(), key.c_str()); return false; } - idx++; } - if (idx != 3) { - LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n"); - return false; + return true; + }; + + if (!cache_mode.empty()) { + if (cache_mode == "easycache") { + cache_params.mode = SD_CACHE_EASYCACHE; + cache_params.reuse_threshold = 0.2f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; + } else if (cache_mode == "ucache") { + cache_params.mode = SD_CACHE_UCACHE; + cache_params.reuse_threshold = 1.0f; + cache_params.start_percent = 0.15f; + cache_params.end_percent = 0.95f; + cache_params.error_decay_rate = 1.0f; + cache_params.use_relative_threshold = true; + cache_params.reset_error_on_compute = true; + } else if (cache_mode == "dbcache") { + cache_params.mode = SD_CACHE_DBCACHE; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "taylorseer") { + cache_params.mode = SD_CACHE_TAYLORSEER; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; + } else if (cache_mode == "cache-dit") { + cache_params.mode = SD_CACHE_CACHE_DIT; + cache_params.Fn_compute_blocks = 8; + cache_params.Bn_compute_blocks = 0; + cache_params.residual_diff_threshold = 0.08f; + cache_params.max_warmup_steps = 8; } - if (values[0] < 0.0f) { - LOG_ERROR("error: easycache threshold must be non-negative\n"); - return false; + + if (!cache_option.empty()) { + if (!parse_named_params(cache_option)) { + return false; + } } - if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) { - LOG_ERROR("error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n"); - return false; + + if (cache_mode == "easycache" || cache_mode == "ucache") { + if (cache_params.reuse_threshold < 0.0f) { + LOG_ERROR("error: cache threshold must be non-negative"); + return false; + } + if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || + cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || + cache_params.start_percent >= cache_params.end_percent) { + LOG_ERROR("error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0"); + return false; + } + } + } + + if (cache_params.mode == SD_CACHE_DBCACHE || + cache_params.mode == SD_CACHE_TAYLORSEER || + cache_params.mode == SD_CACHE_CACHE_DIT) { + if (!scm_mask.empty()) { + cache_params.scm_mask = scm_mask.c_str(); } - easycache_params.enabled = true; - easycache_params.reuse_threshold = values[0]; - easycache_params.start_percent = values[1]; - easycache_params.end_percent = values[2]; - } else { - easycache_params.enabled = false; + cache_params.scm_policy_dynamic = scm_policy_dynamic; } sample_params.guidance.slg.layers = skip_layers.data(); @@ -1765,12 +1877,13 @@ struct SDGenerationParams { << " 
high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" - << " easycache_option: \"" << easycache_option << "\",\n" - << " easycache: " - << (easycache_params.enabled ? "enabled" : "disabled") - << " (threshold=" << easycache_params.reuse_threshold - << ", start=" << easycache_params.start_percent - << ", end=" << easycache_params.end_percent << "),\n" + << " cache_mode: \"" << cache_mode << "\",\n" + << " cache_option: \"" << cache_option << "\",\n" + << " cache: " + << (cache_params.mode != SD_CACHE_DISABLED ? "enabled" : "disabled") + << " (threshold=" << cache_params.reuse_threshold + << ", start=" << cache_params.start_percent + << ", end=" << cache_params.end_percent << "),\n" << " moe_boundary: " << moe_boundary << ",\n" << " video_frames: " << video_frames << ",\n" << " fps: " << fps << ",\n" diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 39359fbbe..5c951c075 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -432,7 +432,7 @@ int main(int argc, const char** argv) { gen_params.pm_style_strength, }, // pm_params ctx_params.vae_tiling_params, - gen_params.easycache_params, + gen_params.cache_params, }; sd_image_t* results = nullptr; @@ -645,7 +645,7 @@ int main(int argc, const char** argv) { gen_params.pm_style_strength, }, // pm_params ctx_params.vae_tiling_params, - gen_params.easycache_params, + gen_params.cache_params, }; sd_image_t* results = nullptr; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 44bd3ccac..d7b4fe72b 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -13,6 +13,8 @@ #include "diffusion_model.hpp" #include "easycache.hpp" #include "esrgan.hpp" +#include "ucache.hpp" +#include "cache_dit.hpp" #include "lora.hpp" #include "pmid.hpp" #include "tae.hpp" @@ -1486,7 +1488,7 @@ class StableDiffusionGGML { ggml_tensor* denoise_mask = nullptr, ggml_tensor* vace_context = nullptr, float vace_strength = 1.f, - const sd_easycache_params_t* easycache_params = nullptr) { + const sd_cache_params_t* cache_params = nullptr) { if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { LOG_WARN("timestep shifting is only supported for SDXL models!"); shifted_timestep = 0; @@ -1503,31 +1505,40 @@ class StableDiffusionGGML { } EasyCacheState easycache_state; + UCacheState ucache_state; + CacheDitConditionState cachedit_state; bool easycache_enabled = false; - if (easycache_params != nullptr && easycache_params->enabled) { - bool easycache_supported = sd_version_is_dit(version); - if (!easycache_supported) { - LOG_WARN("EasyCache requested but not supported for this model type"); - } else { - EasyCacheConfig easycache_config; - easycache_config.enabled = true; - easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold); - easycache_config.start_percent = easycache_params->start_percent; - easycache_config.end_percent = easycache_params->end_percent; - bool percent_valid = easycache_config.start_percent >= 0.0f && - easycache_config.start_percent < 1.0f && - easycache_config.end_percent > 0.0f && - easycache_config.end_percent <= 1.0f && - easycache_config.start_percent < easycache_config.end_percent; - if (!percent_valid) { - LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)", - easycache_config.start_percent, - easycache_config.end_percent); + bool ucache_enabled = false; + bool cachedit_enabled = 
false; + + if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) { + bool percent_valid = true; + if (cache_params->mode == SD_CACHE_EASYCACHE || cache_params->mode == SD_CACHE_UCACHE) { + percent_valid = cache_params->start_percent >= 0.0f && + cache_params->start_percent < 1.0f && + cache_params->end_percent > 0.0f && + cache_params->end_percent <= 1.0f && + cache_params->start_percent < cache_params->end_percent; + } + + if (!percent_valid) { + LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)", + cache_params->start_percent, + cache_params->end_percent); + } else if (cache_params->mode == SD_CACHE_EASYCACHE) { + bool easycache_supported = sd_version_is_dit(version); + if (!easycache_supported) { + LOG_WARN("EasyCache requested but not supported for this model type"); } else { + EasyCacheConfig easycache_config; + easycache_config.enabled = true; + easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); + easycache_config.start_percent = cache_params->start_percent; + easycache_config.end_percent = cache_params->end_percent; easycache_state.init(easycache_config, denoiser.get()); if (easycache_state.enabled()) { easycache_enabled = true; - LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f", + LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f", easycache_config.reuse_threshold, easycache_config.start_percent, easycache_config.end_percent); @@ -1535,9 +1546,85 @@ class StableDiffusionGGML { LOG_WARN("EasyCache requested but could not be initialized for this run"); } } + } else if (cache_params->mode == SD_CACHE_UCACHE) { + bool ucache_supported = sd_version_is_unet(version); + if (!ucache_supported) { + LOG_WARN("UCache requested but not supported for this model type (only UNET models)"); + } else { + UCacheConfig ucache_config; + ucache_config.enabled = true; + ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); + ucache_config.start_percent = cache_params->start_percent; + ucache_config.end_percent = cache_params->end_percent; + ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate)); + ucache_config.use_relative_threshold = cache_params->use_relative_threshold; + ucache_config.reset_error_on_compute = cache_params->reset_error_on_compute; + ucache_state.init(ucache_config, denoiser.get()); + if (ucache_state.enabled()) { + ucache_enabled = true; + LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s", + ucache_config.reuse_threshold, + ucache_config.start_percent, + ucache_config.end_percent, + ucache_config.error_decay_rate, + ucache_config.use_relative_threshold ? "true" : "false", + ucache_config.reset_error_on_compute ? 
"true" : "false"); + } else { + LOG_WARN("UCache requested but could not be initialized for this run"); + } + } + } else if (cache_params->mode == SD_CACHE_DBCACHE || + cache_params->mode == SD_CACHE_TAYLORSEER || + cache_params->mode == SD_CACHE_CACHE_DIT) { + bool cachedit_supported = sd_version_is_dit(version); + if (!cachedit_supported) { + LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)"); + } else { + DBCacheConfig dbcfg; + dbcfg.enabled = (cache_params->mode == SD_CACHE_DBCACHE || + cache_params->mode == SD_CACHE_CACHE_DIT); + dbcfg.Fn_compute_blocks = cache_params->Fn_compute_blocks; + dbcfg.Bn_compute_blocks = cache_params->Bn_compute_blocks; + dbcfg.residual_diff_threshold = cache_params->residual_diff_threshold; + dbcfg.max_warmup_steps = cache_params->max_warmup_steps; + dbcfg.max_cached_steps = cache_params->max_cached_steps; + dbcfg.max_continuous_cached_steps = cache_params->max_continuous_cached_steps; + if (cache_params->scm_mask != nullptr && strlen(cache_params->scm_mask) > 0) { + dbcfg.steps_computation_mask = parse_scm_mask(cache_params->scm_mask); + } + dbcfg.scm_policy_dynamic = cache_params->scm_policy_dynamic; + + TaylorSeerConfig tcfg; + tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER || + cache_params->mode == SD_CACHE_CACHE_DIT); + tcfg.n_derivatives = cache_params->taylorseer_n_derivatives; + tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval; + + cachedit_state.init(dbcfg, tcfg); + if (cachedit_state.enabled()) { + cachedit_enabled = true; + LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d", + cache_params->mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : + (cache_params->mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"), + dbcfg.Fn_compute_blocks, + dbcfg.Bn_compute_blocks, + dbcfg.residual_diff_threshold, + dbcfg.max_warmup_steps); + } else { + LOG_WARN("CacheDIT requested but could not be initialized for this run"); + } + } } } + if (ucache_enabled) { + ucache_state.set_sigmas(sigmas); + } + + if (cachedit_enabled) { + cachedit_state.set_sigmas(sigmas); + } + size_t steps = sigmas.size() - 1; struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); copy_ggml_tensor(x, init_latent); @@ -1641,6 +1728,91 @@ class StableDiffusionGGML { return easycache_step_active && easycache_state.is_step_skipped(); }; + const bool ucache_step_active = ucache_enabled && step > 0; + int ucache_step_index = ucache_step_active ? (step - 1) : -1; + if (ucache_step_active) { + ucache_state.begin_step(ucache_step_index, sigma); + } + + auto ucache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool { + if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) { + return false; + } + return ucache_state.before_condition(condition, + diffusion_params.x, + output_tensor, + sigma, + ucache_step_index); + }; + + auto ucache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) { + if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) { + return; + } + ucache_state.after_condition(condition, + diffusion_params.x, + output_tensor); + }; + + auto ucache_step_is_skipped = [&]() { + return ucache_step_active && ucache_state.is_step_skipped(); + }; + + const bool cachedit_step_active = cachedit_enabled && step > 0; + int cachedit_step_index = cachedit_step_active ? 
(step - 1) : -1;
+        if (cachedit_step_active) {
+            cachedit_state.begin_step(cachedit_step_index, sigma);
+        }
+
+        auto cachedit_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
+            if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
+                return false;
+            }
+            return cachedit_state.before_condition(condition,
+                                                   diffusion_params.x,
+                                                   output_tensor,
+                                                   sigma,
+                                                   cachedit_step_index);
+        };
+
+        auto cachedit_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
+            if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
+                return;
+            }
+            cachedit_state.after_condition(condition,
+                                           diffusion_params.x,
+                                           output_tensor);
+        };
+
+        auto cachedit_step_is_skipped = [&]() {
+            return cachedit_step_active && cachedit_state.is_step_skipped();
+        };
+
+        auto cache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
+            if (easycache_step_active) {
+                return easycache_before_condition(condition, output_tensor);
+            } else if (ucache_step_active) {
+                return ucache_before_condition(condition, output_tensor);
+            } else if (cachedit_step_active) {
+                return cachedit_before_condition(condition, output_tensor);
+            }
+            return false;
+        };
+
+        auto cache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
+            if (easycache_step_active) {
+                easycache_after_condition(condition, output_tensor);
+            } else if (ucache_step_active) {
+                ucache_after_condition(condition, output_tensor);
+            } else if (cachedit_step_active) {
+                cachedit_after_condition(condition, output_tensor);
+            }
+        };
+
+        auto cache_step_is_skipped = [&]() {
+            return easycache_step_is_skipped() || ucache_step_is_skipped() || cachedit_step_is_skipped();
+        };
         std::vector<float> scaling = denoiser->get_scalings(sigma);
         GGML_ASSERT(scaling.size() == 3);
         float c_skip = scaling[0];
@@ -1716,7 +1888,7 @@ class StableDiffusionGGML {
                 active_condition = &id_cond;
             }
-            bool skip_model = easycache_before_condition(active_condition, *active_output);
+            bool skip_model = cache_before_condition(active_condition, *active_output);
             if (!skip_model) {
                 if (!work_diffusion_model->compute(n_threads,
                                                    diffusion_params,
@@ -1724,10 +1896,10 @@ class StableDiffusionGGML {
                     LOG_ERROR("diffusion model compute failed");
                     return nullptr;
                 }
-                easycache_after_condition(active_condition, *active_output);
+                cache_after_condition(active_condition, *active_output);
             }
-            bool current_step_skipped = easycache_step_is_skipped();
+            bool current_step_skipped = cache_step_is_skipped();
             float* negative_data = nullptr;
             if (has_unconditioned) {
@@ -1739,12 +1911,12 @@ class StableDiffusionGGML {
                         LOG_ERROR("controlnet compute failed");
                     }
                 }
-                current_step_skipped = easycache_step_is_skipped();
+                current_step_skipped = cache_step_is_skipped();
                 diffusion_params.controls = controls;
                 diffusion_params.context = uncond.c_crossattn;
                 diffusion_params.c_concat = uncond.c_concat;
                 diffusion_params.y = uncond.c_vector;
-                bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
+                bool skip_uncond = cache_before_condition(&uncond, out_uncond);
                 if (!skip_uncond) {
                     if (!work_diffusion_model->compute(n_threads,
                                                        diffusion_params,
@@ -1752,7 +1924,7 @@ class StableDiffusionGGML {
                         LOG_ERROR("diffusion model compute failed");
                         return nullptr;
                     }
-                    easycache_after_condition(&uncond, out_uncond);
+                    cache_after_condition(&uncond, out_uncond);
                 }
                 negative_data = (float*)out_uncond->data;
             }
@@ -1762,7 +1934,7 @@ class StableDiffusionGGML {
                 diffusion_params.context = img_cond.c_crossattn;
                 diffusion_params.c_concat = img_cond.c_concat;
                 diffusion_params.y = img_cond.c_vector;
-                bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
+                bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
                 if (!skip_img_cond) {
                     if (!work_diffusion_model->compute(n_threads,
                                                        diffusion_params,
@@ -1770,7 +1942,7 @@ class StableDiffusionGGML {
                         LOG_ERROR("diffusion model compute failed");
                         return nullptr;
                     }
-                    easycache_after_condition(&img_cond, out_img_cond);
+                    cache_after_condition(&img_cond, out_img_cond);
                 }
                 img_cond_data = (float*)out_img_cond->data;
             }
@@ -1780,7 +1952,7 @@ class StableDiffusionGGML {
             float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
             if (is_skiplayer_step) {
                 LOG_DEBUG("Skipping layers at step %d\n", step);
-                if (!easycache_step_is_skipped()) {
+                if (!cache_step_is_skipped()) {
                     // skip layer (same as conditioned)
                     diffusion_params.context = cond.c_crossattn;
                     diffusion_params.c_concat = cond.c_concat;
@@ -1884,6 +2056,48 @@ class StableDiffusionGGML {
             }
         }
+        if (ucache_enabled) {
+            size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
+            if (ucache_state.total_steps_skipped > 0 && total_steps > 0) {
+                if (ucache_state.total_steps_skipped < static_cast<int>(total_steps)) {
+                    double speedup = static_cast<double>(total_steps) /
+                                     static_cast<double>(total_steps - ucache_state.total_steps_skipped);
+                    LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
+                             ucache_state.total_steps_skipped,
+                             total_steps,
+                             speedup);
+                } else {
+                    LOG_INFO("UCache skipped %d/%zu steps",
+                             ucache_state.total_steps_skipped,
+                             total_steps);
+                }
+            } else if (total_steps > 0) {
+                LOG_INFO("UCache completed without skipping steps");
+            }
+        }
+
+        if (cachedit_enabled) {
+            size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
+            if (cachedit_state.total_steps_skipped > 0 && total_steps > 0) {
+                if (cachedit_state.total_steps_skipped < static_cast<int>(total_steps)) {
+                    double speedup = static_cast<double>(total_steps) /
+                                     static_cast<double>(total_steps - cachedit_state.total_steps_skipped);
+                    LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
+                             cachedit_state.total_steps_skipped,
+                             total_steps,
+                             speedup,
+                             cachedit_state.accumulated_residual_diff);
+                } else {
+                    LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
+                             cachedit_state.total_steps_skipped,
+                             total_steps,
+                             cachedit_state.accumulated_residual_diff);
+                }
+            } else if (total_steps > 0) {
+                LOG_INFO("CacheDIT completed without skipping steps");
+            }
+        }
+
         if (inverse_noise_scaling) {
             x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
         }
@@ -2498,12 +2712,25 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
     return LORA_APPLY_MODE_COUNT;
 }
-void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
-    *easycache_params = {};
-    easycache_params->enabled = false;
-    easycache_params->reuse_threshold = 0.2f;
-    easycache_params->start_percent = 0.15f;
-    easycache_params->end_percent = 0.95f;
+void sd_cache_params_init(sd_cache_params_t* cache_params) {
+    *cache_params = {};
+    cache_params->mode = SD_CACHE_DISABLED;
+    cache_params->reuse_threshold = 1.0f;
+    cache_params->start_percent = 0.15f;
+    cache_params->end_percent = 0.95f;
+    cache_params->error_decay_rate = 1.0f;
+    cache_params->use_relative_threshold = true;
+    cache_params->reset_error_on_compute = true;
+    cache_params->Fn_compute_blocks = 8;
+    cache_params->Bn_compute_blocks = 0;
+    cache_params->residual_diff_threshold = 0.08f;
+    cache_params->max_warmup_steps = 8;
+    cache_params->max_cached_steps = -1;
+    cache_params->max_continuous_cached_steps = -1;
+    cache_params->taylorseer_n_derivatives = 1;
+    cache_params->taylorseer_skip_interval = 1;
+    cache_params->scm_mask = nullptr;
+    cache_params->scm_policy_dynamic = true;
 }
 void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
@@ -2662,7 +2889,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
     sd_img_gen_params->control_strength = 0.9f;
     sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
     sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
-    sd_easycache_params_init(&sd_img_gen_params->easycache);
+    sd_cache_params_init(&sd_img_gen_params->cache);
 }
 char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@@ -2706,12 +2933,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
             sd_img_gen_params->pm_params.id_images_count,
             SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
             BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
+    const char* cache_mode_str = "disabled";
+    if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) {
+        cache_mode_str = "easycache";
+    } else if (sd_img_gen_params->cache.mode == SD_CACHE_UCACHE) {
+        cache_mode_str = "ucache";
+    }
     snprintf(buf + strlen(buf), 4096 - strlen(buf),
-             "easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
-             sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
"enabled" : "disabled", - sd_img_gen_params->easycache.reuse_threshold, - sd_img_gen_params->easycache.start_percent, - sd_img_gen_params->easycache.end_percent); + "cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n", + cache_mode_str, + sd_img_gen_params->cache.reuse_threshold, + sd_img_gen_params->cache.start_percent, + sd_img_gen_params->cache.end_percent); free(sample_params_str); return buf; } @@ -2728,7 +2961,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; - sd_easycache_params_init(&sd_vid_gen_params->easycache); + sd_cache_params_init(&sd_vid_gen_params->cache); } struct sd_ctx_t { @@ -2806,7 +3039,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, bool increase_ref_index, ggml_tensor* concat_latent = nullptr, ggml_tensor* denoise_mask = nullptr, - const sd_easycache_params_t* easycache_params = nullptr) { + const sd_cache_params_t* cache_params = nullptr) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -3095,7 +3328,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, denoise_mask, nullptr, 1.0f, - easycache_params); + cache_params); int64_t sampling_end = ggml_time_ms(); if (x_0 != nullptr) { // print_ggml_tensor(x_0); @@ -3429,7 +3662,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->increase_ref_index, concat_latent, denoise_mask, - &sd_img_gen_params->easycache); + &sd_img_gen_params->cache); size_t t2 = ggml_time_ms(); @@ -3796,7 +4029,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache); + &sd_vid_gen_params->cache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); @@ -3833,7 +4066,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask, vace_context, sd_vid_gen_params->vace_strength, - &sd_vid_gen_params->easycache); + &sd_vid_gen_params->cache); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); diff --git a/stable-diffusion.h b/stable-diffusion.h index 9266ba437..de6485f5e 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -235,12 +235,34 @@ typedef struct { float style_strength; } sd_pm_params_t; // photo maker +enum sd_cache_mode_t { + SD_CACHE_DISABLED = 0, + SD_CACHE_EASYCACHE, + SD_CACHE_UCACHE, + SD_CACHE_DBCACHE, + SD_CACHE_TAYLORSEER, + SD_CACHE_CACHE_DIT, +}; + typedef struct { - bool enabled; + enum sd_cache_mode_t mode; float reuse_threshold; float start_percent; float end_percent; -} sd_easycache_params_t; + float error_decay_rate; + bool use_relative_threshold; + bool reset_error_on_compute; + int Fn_compute_blocks; + int Bn_compute_blocks; + float residual_diff_threshold; + int max_warmup_steps; + int max_cached_steps; + int max_continuous_cached_steps; + int taylorseer_n_derivatives; + int taylorseer_skip_interval; + const char* scm_mask; + bool scm_policy_dynamic; +} sd_cache_params_t; typedef struct { bool is_high_noise; @@ -270,7 +292,7 @@ typedef struct { float control_strength; sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; - 
sd_easycache_params_t easycache; + sd_cache_params_t cache; } sd_img_gen_params_t; typedef struct { @@ -292,7 +314,7 @@ typedef struct { int64_t seed; int video_frames; float vace_strength; - sd_easycache_params_t easycache; + sd_cache_params_t cache; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t; @@ -322,7 +344,7 @@ SD_API enum preview_t str_to_preview(const char* str); SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); -SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params); +SD_API void sd_cache_params_init(sd_cache_params_t* cache_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/ucache.hpp b/ucache.hpp new file mode 100644 index 000000000..9a39557d1 --- /dev/null +++ b/ucache.hpp @@ -0,0 +1,402 @@ +#ifndef __UCACHE_HPP__ +#define __UCACHE_HPP__ + +#include +#include +#include +#include + +#include "denoiser.hpp" +#include "ggml_extend.hpp" + +struct UCacheConfig { + bool enabled = false; + float reuse_threshold = 1.0f; + float start_percent = 0.15f; + float end_percent = 0.95f; + float error_decay_rate = 1.0f; + bool use_relative_threshold = true; + bool adaptive_threshold = true; + float early_step_multiplier = 0.5f; + float late_step_multiplier = 1.5f; + bool reset_error_on_compute = true; +}; + +struct UCacheCacheEntry { + std::vector diff; +}; + +struct UCacheState { + UCacheConfig config; + Denoiser* denoiser = nullptr; + float start_sigma = std::numeric_limits::max(); + float end_sigma = 0.0f; + bool initialized = false; + bool initial_step = true; + bool skip_current_step = false; + bool step_active = false; + const SDCondition* anchor_condition = nullptr; + std::unordered_map cache_diffs; + std::vector prev_input; + std::vector prev_output; + float output_prev_norm = 0.0f; + bool has_prev_input = false; + bool has_prev_output = false; + bool has_output_prev_norm = false; + bool has_relative_transformation_rate = false; + float relative_transformation_rate = 0.0f; + float cumulative_change_rate = 0.0f; + float last_input_change = 0.0f; + bool has_last_input_change = false; + int total_steps_skipped = 0; + int current_step_index = -1; + int steps_computed_since_active = 0; + float accumulated_error = 0.0f; + float reference_output_norm = 0.0f; + + struct BlockMetrics { + float sum_transformation_rate = 0.0f; + float sum_output_norm = 0.0f; + int sample_count = 0; + float min_change_rate = std::numeric_limits::max(); + float max_change_rate = 0.0f; + + void reset() { + sum_transformation_rate = 0.0f; + sum_output_norm = 0.0f; + sample_count = 0; + min_change_rate = std::numeric_limits::max(); + max_change_rate = 0.0f; + } + + void record(float change_rate, float output_norm) { + if (std::isfinite(change_rate) && change_rate > 0.0f) { + sum_transformation_rate += change_rate; + sum_output_norm += output_norm; + sample_count++; + if (change_rate < min_change_rate) min_change_rate = change_rate; + if (change_rate > max_change_rate) max_change_rate = change_rate; + } + } + + float avg_transformation_rate() const { + return (sample_count > 0) ? (sum_transformation_rate / sample_count) : 0.0f; + } + + float avg_output_norm() const { + return (sample_count > 0) ? 
+    BlockMetrics block_metrics;
+    int total_active_steps = 0;
+
+    void reset_runtime() {
+        initial_step = true;
+        skip_current_step = false;
+        step_active = false;
+        anchor_condition = nullptr;
+        cache_diffs.clear();
+        prev_input.clear();
+        prev_output.clear();
+        output_prev_norm = 0.0f;
+        has_prev_input = false;
+        has_prev_output = false;
+        has_output_prev_norm = false;
+        has_relative_transformation_rate = false;
+        relative_transformation_rate = 0.0f;
+        cumulative_change_rate = 0.0f;
+        last_input_change = 0.0f;
+        has_last_input_change = false;
+        total_steps_skipped = 0;
+        current_step_index = -1;
+        steps_computed_since_active = 0;
+        accumulated_error = 0.0f;
+        reference_output_norm = 0.0f;
+        block_metrics.reset();
+        total_active_steps = 0;
+    }
+
+    void init(const UCacheConfig& cfg, Denoiser* d) {
+        config = cfg;
+        denoiser = d;
+        initialized = cfg.enabled && d != nullptr;
+        reset_runtime();
+        if (initialized) {
+            start_sigma = percent_to_sigma(config.start_percent);
+            end_sigma = percent_to_sigma(config.end_percent);
+        }
+    }
+
+    void set_sigmas(const std::vector<float>& sigmas) {
+        if (!initialized || sigmas.size() < 2) {
+            return;
+        }
+        size_t n_steps = sigmas.size() - 1;
+
+        size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
+        size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
+
+        if (start_step >= n_steps) start_step = n_steps - 1;
+        if (end_step >= n_steps) end_step = n_steps - 1;
+
+        start_sigma = sigmas[start_step];
+        end_sigma = sigmas[end_step];
+
+        if (start_sigma < end_sigma) {
+            std::swap(start_sigma, end_sigma);
+        }
+    }
+
+    bool enabled() const {
+        return initialized && config.enabled;
+    }
+
+    float percent_to_sigma(float percent) const {
+        if (!denoiser) {
+            return 0.0f;
+        }
+        if (percent <= 0.0f) {
+            return std::numeric_limits<float>::max();
+        }
+        if (percent >= 1.0f) {
+            return 0.0f;
+        }
+        float t = (1.0f - percent) * (TIMESTEPS - 1);
+        return denoiser->t_to_sigma(t);
+    }
+
+    void begin_step(int step_index, float sigma) {
+        if (!enabled()) {
+            return;
+        }
+        if (step_index == current_step_index) {
+            return;
+        }
+        current_step_index = step_index;
+        skip_current_step = false;
+        has_last_input_change = false;
+        step_active = false;
+
+        // only activate inside the configured sigma window
+        if (sigma > start_sigma) {
+            return;
+        }
+        if (!(sigma > end_sigma)) {
+            return;
+        }
+        step_active = true;
+        total_active_steps++;
+    }
+
+    bool step_is_active() const {
+        return enabled() && step_active;
+    }
+
+    bool is_step_skipped() const {
+        return enabled() && step_active && skip_current_step;
+    }
+
+    float get_adaptive_threshold(int estimated_total_steps = 0) const {
+        float base_threshold = config.reuse_threshold;
+
+        if (!config.adaptive_threshold) {
+            return base_threshold;
+        }
+
+        int effective_total = estimated_total_steps;
+        if (effective_total <= 0) {
+            effective_total = std::max(20, steps_computed_since_active * 2);
+        }
+
+        float progress = (effective_total > 0)
+                             ? (static_cast<float>(steps_computed_since_active) / effective_total)
+                             : 0.0f;
+
+        float multiplier = 1.0f;
+        if (progress < 0.2f) {
+            multiplier = config.early_step_multiplier;
+        } else if (progress > 0.8f) {
+            multiplier = config.late_step_multiplier;
+        }
+
+        return base_threshold * multiplier;
+    }
+
+    bool has_cache(const SDCondition* cond) const {
+        auto it = cache_diffs.find(cond);
+        return it != cache_diffs.end() && !it->second.diff.empty();
+    }
+
+    void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
+        UCacheCacheEntry& entry = cache_diffs[cond];
+        size_t ne = static_cast<size_t>(ggml_nelements(output));
+        entry.diff.resize(ne);
+        float* out_data = (float*)output->data;
+        float* in_data = (float*)input->data;
+
+        for (size_t i = 0; i < ne; ++i) {
+            entry.diff[i] = out_data[i] - in_data[i];
+        }
+    }
+
+    void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
+        auto it = cache_diffs.find(cond);
+        if (it == cache_diffs.end() || it->second.diff.empty()) {
+            return;
+        }
+
+        copy_ggml_tensor(output, input);
+        float* out_data = (float*)output->data;
+        const std::vector<float>& diff = it->second.diff;
+        for (size_t i = 0; i < diff.size(); ++i) {
+            out_data[i] += diff[i];
+        }
+    }
+
+    bool before_condition(const SDCondition* cond,
+                          ggml_tensor* input,
+                          ggml_tensor* output,
+                          float sigma,
+                          int step_index) {
+        if (!enabled() || step_index < 0) {
+            return false;
+        }
+        if (step_index != current_step_index) {
+            begin_step(step_index, sigma);
+        }
+        if (!step_active) {
+            return false;
+        }
+
+        if (initial_step) {
+            anchor_condition = cond;
+            initial_step = false;
+        }
+
+        bool is_anchor = (cond == anchor_condition);
+
+        if (skip_current_step) {
+            if (has_cache(cond)) {
+                apply_cache(cond, input, output);
+                return true;
+            }
+            return false;
+        }
+
+        if (!is_anchor) {
+            return false;
+        }
+
+        if (!has_prev_input || !has_prev_output || !has_cache(cond)) {
+            return false;
+        }
+
+        size_t ne = static_cast<size_t>(ggml_nelements(input));
+        if (prev_input.size() != ne) {
+            return false;
+        }
+
+        float* input_data = (float*)input->data;
+        last_input_change = 0.0f;
+        for (size_t i = 0; i < ne; ++i) {
+            last_input_change += std::fabs(input_data[i] - prev_input[i]);
+        }
+        if (ne > 0) {
+            last_input_change /= static_cast<float>(ne);
+        }
+        has_last_input_change = true;
+
+        if (has_output_prev_norm && has_relative_transformation_rate &&
+            last_input_change > 0.0f && output_prev_norm > 0.0f) {
+            // estimate the relative output change from the observed input change,
+            // accumulate it, and skip the step while the estimate stays below the threshold
+            float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
+            accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
+
+            float effective_threshold = get_adaptive_threshold();
+            if (config.use_relative_threshold && reference_output_norm > 0.0f) {
+                effective_threshold = effective_threshold * reference_output_norm;
+            }
+
+            if (accumulated_error < effective_threshold) {
+                skip_current_step = true;
+                total_steps_skipped++;
+                apply_cache(cond, input, output);
+                return true;
+            } else if (config.reset_error_on_compute) {
+                accumulated_error = 0.0f;
+            }
+        }
+
+        return false;
+    }
+
+    void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
+        if (!step_is_active()) {
+            return;
+        }
+
+        update_cache(cond, input, output);
+
+        if (cond != anchor_condition) {
+            return;
+        }
+
+        size_t ne = static_cast<size_t>(ggml_nelements(input));
+        float* in_data = (float*)input->data;
+        prev_input.resize(ne);
+        for (size_t i = 0; i < ne; ++i) {
+            prev_input[i] = in_data[i];
+        }
+        has_prev_input = true;
+
+        float* out_data = (float*)output->data;
+        float output_change = 0.0f;
+        if (has_prev_output && prev_output.size() == ne) {
+            for (size_t i = 0; i < ne; ++i) {
+                output_change += std::fabs(out_data[i] - prev_output[i]);
+            }
+            if (ne > 0) {
+                output_change /= static_cast<float>(ne);
+            }
+        }
+
+        prev_output.resize(ne);
+        for (size_t i = 0; i < ne; ++i) {
+            prev_output[i] = out_data[i];
+        }
+        has_prev_output = true;
+
+        float mean_abs = 0.0f;
+        for (size_t i = 0; i < ne; ++i) {
+            mean_abs += std::fabs(out_data[i]);
+        }
+        output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
+        has_output_prev_norm = output_prev_norm > 0.0f;
+
+        if (reference_output_norm == 0.0f) {
+            reference_output_norm = output_prev_norm;
+        }
+
+        if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
+            float rate = output_change / last_input_change;
+            if (std::isfinite(rate)) {
+                relative_transformation_rate = rate;
+                has_relative_transformation_rate = true;
+                block_metrics.record(rate, output_prev_norm);
+            }
+        }
+
+        has_last_input_change = false;
+    }
+
+    void log_block_metrics() const {
+        if (block_metrics.sample_count > 0) {
+            LOG_INFO("UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f",
+                     block_metrics.sample_count,
+                     block_metrics.avg_transformation_rate(),
+                     block_metrics.min_change_rate,
+                     block_metrics.max_change_rate,
+                     block_metrics.avg_output_norm());
+        }
+    }
+};
+
+#endif // __UCACHE_HPP__
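
Usage note (not part of the patch): the sketch below shows how a caller might opt in to one of the new cache modes through the public API added above. It is a minimal, hypothetical example; the helper name txt2img_with_ucache is invented here, and it assumes the sd_ctx_t has already been created as before and that the remaining sd_img_gen_params_t fields (prompt, dimensions, sample parameters, ...) are filled in exactly as they were prior to this change.

    #include "stable-diffusion.h"

    // Hypothetical helper, not part of the patch: enable UCache for one generation.
    sd_image_t* txt2img_with_ucache(sd_ctx_t* ctx, sd_img_gen_params_t* params) {
        // Start from the defaults installed by sd_cache_params_init()
        // (mode = SD_CACHE_DISABLED, reuse_threshold = 1.0f, start/end = 0.15/0.95, ...).
        sd_cache_params_init(&params->cache);

        // Opt in to UCache: steps between start_percent and end_percent of the
        // schedule may reuse the cached output diff while the accumulated error
        // estimate stays below reuse_threshold.
        params->cache.mode            = SD_CACHE_UCACHE;
        params->cache.reuse_threshold = 0.3f;  // example value, tune per model
        params->cache.start_percent   = 0.15f;
        params->cache.end_percent     = 0.95f;

        // For SD_CACHE_DBCACHE / SD_CACHE_CACHE_DIT the relevant knobs are instead
        // Fn_compute_blocks, Bn_compute_blocks, residual_diff_threshold and
        // max_warmup_steps; for SD_CACHE_TAYLORSEER, taylorseer_n_derivatives and
        // taylorseer_skip_interval.
        return generate_image(ctx, params);
    }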