diff --git a/clip.hpp b/clip.hpp index 24c94f1bb..c5d7a19c6 100644 --- a/clip.hpp +++ b/clip.hpp @@ -664,7 +664,7 @@ class CLIPVisionEmbeddings : public GGMLBlock { // concat(patch_embedding, class_embedding) + position_embedding struct ggml_tensor* patch_embedding; int64_t N = pixel_values->ne[3]; - patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size] + patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size, 0, 0, 1, 1, false, ctx->circular_x_enabled, ctx->circular_y_enabled); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size] patch_embedding = ggml_reshape_3d(ctx->ggml_ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches] patch_embedding = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim] patch_embedding = ggml_reshape_4d(ctx->ggml_ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1] diff --git a/common.hpp b/common.hpp index 33d499fb1..3741e975a 100644 --- a/common.hpp +++ b/common.hpp @@ -28,7 +28,9 @@ class DownSampleBlock : public GGMLBlock { if (vae_downsample) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + // For VAE downsampling we manually pad by 1 before the stride-2 conv. + // Honor the global circular padding flags here to avoid seams in seamless mode. + x = sd_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); x = conv->forward(ctx, x); } else { auto conv = std::dynamic_pointer_cast(blocks["op"]); diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 8c741fdc4..0724cc938 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -39,6 +39,7 @@ struct DiffusionModel { virtual void set_weight_adapter(const std::shared_ptr& adapter){}; virtual int64_t get_adm_in_channels() = 0; virtual void set_flash_attn_enabled(bool enabled) = 0; + virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; }; struct UNetModel : public DiffusionModel { @@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel { unet.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + unet.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel { mmdit.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + mmdit.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel { flux.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + flux.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel { wan.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + wan.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct 
ggml_tensor** output = nullptr, @@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel { qwen_image.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + qwen_image.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel { z_image.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + z_image.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 49b202fda..e472ca2e6 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -518,6 +518,9 @@ struct SDContextParams { bool diffusion_flash_attn = false; bool diffusion_conv_direct = false; bool vae_conv_direct = false; + bool circular = false; + bool circular_x = false; + bool circular_y = false; bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; @@ -671,6 +674,18 @@ struct SDContextParams { "--vae-conv-direct", "use ggml_conv2d_direct in the vae model", true, &vae_conv_direct}, + {"", + "--circular", + "enable circular padding and RoPE wrapping (seamless mode) on both axes", + true, &circular}, + {"", + "--circularx", + "enable circular padding and RoPE wrapping on the x-axis (width) only", + true, &circular_x}, + {"", + "--circulary", + "enable circular padding and RoPE wrapping on the y-axis (height) only", + true, &circular_y}, {"", "--chroma-disable-dit-mask", "disable dit mask for chroma", @@ -934,6 +949,9 @@ struct SDContextParams { << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" + << " circular: " << (circular ? "true" : "false") << ",\n" + << " circular_x: " << (circular_x ? "true" : "false") << ",\n" + << " circular_y: " << (circular_y ? "true" : "false") << ",\n" << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" << " chroma_use_t5_mask: " << (chroma_use_t5_mask ?
"true" : "false") << ",\n" << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" @@ -995,6 +1013,9 @@ struct SDContextParams { taesd_preview, diffusion_conv_direct, vae_conv_direct, + circular, + circular || circular_x, + circular || circular_y, force_sdxl_vae_conv_scale, chroma_use_dit_mask, chroma_use_t5_mask, diff --git a/flux.hpp b/flux.hpp index 1df2874ae..2038fe152 100644 --- a/flux.hpp +++ b/flux.hpp @@ -858,14 +858,14 @@ namespace Flux { } } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = sd_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -891,11 +891,11 @@ namespace Flux { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -1065,7 +1065,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = pad_to_patch_size(ctx->ggml_ctx, x); + auto img = pad_to_patch_size(ctx, x); auto orig_img = img; auto img_in_patch = std::dynamic_pointer_cast(blocks["img_in_patch"]); @@ -1128,7 +1128,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t img_tokens = img->ne[1]; if (params.version == VERSION_FLUX_FILL) { @@ -1136,8 +1136,8 @@ namespace Flux { ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = process_img(ctx->ggml_ctx, masked); - mask = process_img(ctx->ggml_ctx, mask); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); } else if (params.version == VERSION_FLEX_2) { @@ -1146,21 +1146,21 @@ namespace Flux { ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); - masked = process_img(ctx->ggml_ctx, masked); - mask = process_img(ctx->ggml_ctx, mask); - control = process_img(ctx->ggml_ctx, control); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); + control = process_img(ctx, control); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); } else if (params.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); - auto 
control = process_img(ctx->ggml_ctx, c_concat); + auto control = process_img(ctx, c_concat); img = ggml_concat(ctx->ggml_ctx, img, control, 0); } if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -1441,6 +1441,8 @@ namespace Flux { increase_ref_index, flux_params.ref_index_scale, flux_params.theta, + circular_y_enabled, + circular_x_enabled, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); diff --git a/ggml b/ggml index 2d3876d55..d80bac55f 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 2d3876d554551d35c06dccc5852be50d5fd2a275 +Subproject commit d80bac55f6d0c57e57143f80cbb6e3155dec1cc7 diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 07b9bfbf0..663012d5b 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -19,6 +19,11 @@ #include #include #include +#include + +#ifndef GGML_KQ_MASK_PAD +#define GGML_KQ_MASK_PAD 1 +#endif #include "ggml-alloc.h" #include "ggml-backend.h" @@ -1007,6 +1012,50 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, return x; } +__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx, + struct ggml_tensor* x, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3, + bool circular_x = false, + bool circular_y = false) { + if ((circular_x && circular_y) || (!circular_x && !circular_y)) { + return circular_x && circular_y ? ggml_pad_ext_circular(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) + : ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + } + + if (circular_x && (lp0 != 0 || rp0 != 0)) { + x = ggml_pad_ext_circular(ctx, x, lp0, rp0, 0, 0, 0, 0, 0, 0); + lp0 = rp0 = 0; + } + if (circular_y && (lp1 != 0 || rp1 != 0)) { + x = ggml_pad_ext_circular(ctx, x, 0, 0, lp1, rp1, 0, 0, 0, 0); + lp1 = rp1 = 0; + } + + if (lp0 != 0 || rp0 != 0 || lp1 != 0 || rp1 != 0 || lp2 != 0 || rp2 != 0 || lp3 != 0 || rp3 != 0) { + x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + } + return x; +} + +__STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx, + struct ggml_tensor* x, + int pad_w, + int pad_h, + int pad_t = 0, + int pad_d = 0, + bool circular_x = false, + bool circular_y = false) { + + return sd_pad_ext(ctx, x, pad_w, pad_w, pad_h, pad_h, pad_t, pad_t, pad_d, pad_d, circular_x, circular_y); +} + // w: [OC,IC, KH, KW] // x: [N, IC, IH, IW] // b: [OC,] @@ -1015,20 +1064,31 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* w, struct ggml_tensor* b, - int s0 = 1, - int s1 = 1, - int p0 = 0, - int p1 = 0, - int d0 = 1, - int d1 = 1, - bool direct = false, - float scale = 1.f) { + int s0 = 1, + int s1 = 1, + int p0 = 0, + int p1 = 0, + int d0 = 1, + int d1 = 1, + bool direct = false, + bool circular_x = false, + bool circular_y = false, + float scale = 1.f) { if (scale != 1.f) { x = ggml_scale(ctx, x, scale); } if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) { w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]); } + + // use circular padding (on a torus, x and y wrap around) for seamless textures + // see https://github.com/leejet/stable-diffusion.cpp/pull/914 + if ((p0 != 0 || p1 != 0) && (circular_x || circular_y)) { + x = sd_pad(ctx, x, p0, p1, 0, 0, circular_x, circular_y); + p0 = 0; + p1 = 0; + } + if (direct) { x = ggml_conv_2d_direct(ctx, w, 
x, s0, s1, p0, p1, d0, d1); } else { @@ -1538,7 +1598,9 @@ struct WeightAdapter { int d0 = 1; int d1 = 1; bool direct = false; - float scale = 1.f; + bool circular_x = false; + bool circular_y = false; + float scale = 1.f; } conv2d; }; virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0; @@ -1556,6 +1618,8 @@ struct GGMLRunnerContext { ggml_context* ggml_ctx = nullptr; bool flash_attn_enabled = false; bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; std::shared_ptr weight_adapter = nullptr; }; @@ -1592,6 +1656,8 @@ struct GGMLRunner { bool flash_attn_enabled = false; bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; void alloc_params_ctx() { struct ggml_init_params params; @@ -1869,6 +1935,8 @@ struct GGMLRunner { runner_ctx.backend = runtime_backend; runner_ctx.flash_attn_enabled = flash_attn_enabled; runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled; + runner_ctx.circular_x_enabled = circular_x_enabled; + runner_ctx.circular_y_enabled = circular_y_enabled; runner_ctx.weight_adapter = weight_adapter; return runner_ctx; } @@ -2013,6 +2081,11 @@ struct GGMLRunner { conv2d_direct_enabled = enabled; } + void set_circular_axes(bool circular_x, bool circular_y) { + circular_x_enabled = circular_x; + circular_y_enabled = circular_y; + } + void set_weight_adapter(const std::shared_ptr& adapter) { weight_adapter = adapter; } @@ -2284,7 +2357,9 @@ class Conv2d : public UnaryBlock { forward_params.conv2d.d0 = dilation.second; forward_params.conv2d.d1 = dilation.first; forward_params.conv2d.direct = ctx->conv2d_direct_enabled; - forward_params.conv2d.scale = scale; + forward_params.conv2d.circular_x = ctx->circular_x_enabled; + forward_params.conv2d.circular_y = ctx->circular_y_enabled; + forward_params.conv2d.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } return ggml_ext_conv_2d(ctx->ggml_ctx, @@ -2298,6 +2373,8 @@ class Conv2d : public UnaryBlock { dilation.second, dilation.first, ctx->conv2d_direct_enabled, + ctx->circular_x_enabled, + ctx->circular_y_enabled, scale); } }; diff --git a/lora.hpp b/lora.hpp index b847f044c..7d83ec5cd 100644 --- a/lora.hpp +++ b/lora.hpp @@ -599,6 +599,8 @@ struct LoraModel : public GGMLRunner { forward_params.conv2d.d0, forward_params.conv2d.d1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); if (lora_mid) { lx = ggml_ext_conv_2d(ctx, @@ -612,6 +614,8 @@ struct LoraModel : public GGMLRunner { 1, 1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); } lx = ggml_ext_conv_2d(ctx, @@ -625,6 +629,8 @@ struct LoraModel : public GGMLRunner { 1, 1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); } @@ -779,6 +785,8 @@ struct MultiLoraAdapter : public WeightAdapter { forward_params.conv2d.d0, forward_params.conv2d.d1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); } for (auto& lora_model : lora_models) { diff --git a/mmdit.hpp b/mmdit.hpp index 38bdc2e74..eeb74a268 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -983,4 +983,4 @@ struct MMDiTRunner : public GGMLRunner { } }; -#endif \ No newline at end of file 
+#endif diff --git a/qwen_image.hpp b/qwen_image.hpp index eeb823d50..847f61171 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -354,14 +354,14 @@ namespace Qwen { blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels)); } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = sd_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -387,10 +387,10 @@ namespace Qwen { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -466,12 +466,12 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t img_tokens = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -565,6 +565,8 @@ namespace Qwen { ref_latents, increase_ref_index, qwen_image_params.theta, + circular_y_enabled, + circular_x_enabled, qwen_image_params.axes_dim); int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); @@ -684,4 +686,4 @@ namespace Qwen { } // namespace name -#endif // __QWEN_IMAGE_HPP__ \ No newline at end of file +#endif // __QWEN_IMAGE_HPP__ diff --git a/rope.hpp b/rope.hpp index 4abc51469..682cb641a 100644 --- a/rope.hpp +++ b/rope.hpp @@ -1,6 +1,8 @@ #ifndef __ROPE_HPP__ #define __ROPE_HPP__ +#include +#include #include #include "ggml_extend.hpp" @@ -39,32 +41,51 @@ namespace Rope { return flat_vec; } - __STATIC_INLINE__ std::vector> rope(const std::vector& pos, int dim, int theta) { + __STATIC_INLINE__ std::vector> rope(const std::vector& pos, + int dim, + int theta, + const std::vector* wrap_dims = nullptr) { assert(dim % 2 == 0); int half_dim = dim / 2; + std::vector> result(pos.size(), std::vector(half_dim * 4)); + std::vector scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim); std::vector omega(half_dim); for (int i = 0; i < half_dim; ++i) { - omega[i] = 1.0 / std::pow(theta, scale[i]); + omega[i] = 1.0f / std::pow(theta, scale[i]); } - int pos_size = pos.size(); - std::vector> out(pos_size, std::vector(half_dim)); - for (int i = 0; i < pos_size; ++i) { + for (size_t i = 0; i < pos.size(); ++i) { + float position = pos[i]; for (int j = 0; j < half_dim; ++j) { - out[i][j] = pos[i] * omega[j]; - } - } - - std::vector> result(pos_size, std::vector(half_dim * 4)); - for (int i = 0; i < pos_size; ++i) { - for (int j = 0; j < half_dim; ++j) { - result[i][4 * j] = std::cos(out[i][j]); - result[i][4 * j + 1] = -std::sin(out[i][j]); - result[i][4 * j + 2] = std::sin(out[i][j]); - result[i][4 * j + 3] = std::cos(out[i][j]); + float omega_val = omega[j]; + float original_angle = position * omega_val; + float angle = original_angle; + int wrap_dim = 0; + if (wrap_dims != nullptr && !wrap_dims->empty()) { + size_t wrap_size = 
wrap_dims->size(); + // mod batch size since we only store this for one item in the batch + size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0; + wrap_dim = (*wrap_dims)[wrap_idx]; + } + if (wrap_dim > 0) { + constexpr float TWO_PI = 6.28318530717958647692f; + float wrap_f = static_cast(wrap_dim); + float cycles = omega_val * wrap_f / TWO_PI; + // closest periodic harmonic, necessary to ensure things neatly tile + // without this round, things don't tile at the boundaries and you end up + // with the model knowing what is "center" + float rounded = std::round(cycles); + angle = position * TWO_PI * rounded / wrap_f; + } + float sin_val = std::sin(angle); + float cos_val = std::cos(angle); + result[i][4 * j] = cos_val; + result[i][4 * j + 1] = -sin_val; + result[i][4 * j + 2] = sin_val; + result[i][4 * j + 3] = cos_val; } } @@ -137,7 +158,8 @@ namespace Rope { __STATIC_INLINE__ std::vector embed_nd(const std::vector>& ids, int bs, int theta, - const std::vector& axes_dim) { + const std::vector& axes_dim, + const std::vector>* wrap_dims = nullptr) { std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size() / bs; int num_axes = axes_dim.size(); @@ -152,7 +174,12 @@ namespace Rope { std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); int offset = 0; for (int i = 0; i < num_axes; ++i) { - std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + const std::vector* axis_wrap_dims = nullptr; + if (wrap_dims != nullptr && i < (int)wrap_dims->size()) { + axis_wrap_dims = &(*wrap_dims)[i]; + } + std::vector> rope_emb = + rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] for (int b = 0; b < bs; ++b) { for (int j = 0; j < pos_len; ++j) { for (int k = 0; k < rope_emb[0].size(); ++k) { @@ -239,6 +266,8 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_flux_ids(h, w, @@ -250,7 +279,48 @@ namespace Rope { ref_latents, increase_ref_index, ref_index_scale); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + size_t cursor = context_len; // text first + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + cursor += img_tokens; + // reference latents + for (ggml_tensor* ref : ref_latents) { + if (ref == nullptr) { + continue; + } + int ref_h = static_cast(ref->ne[1]); + int ref_w = static_cast(ref->ne[0]); + int ref_h_l = (ref_h + (patch_size / 2)) / patch_size; + int ref_w_l = (ref_w + (patch_size / 2)) / patch_size; + size_t ref_tokens = static_cast(ref_h_l) * static_cast(ref_w_l); + for (size_t token_i = 0; token_i < ref_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = ref_h_l; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = ref_w_l; + } + } + cursor += ref_tokens; + } + } + } + const std::vector>* wraps_ptr = wrap_dims.empty() ? 
nullptr : &wrap_dims; + return embed_nd(ids, bs, theta, axes_dim, wraps_ptr); } __STATIC_INLINE__ std::vector> gen_qwen_image_ids(int h, @@ -289,9 +359,58 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + // This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int pad_h = (patch_size - (h % patch_size)) % patch_size; + int pad_w = (patch_size - (w % patch_size)) % patch_size; + int h_len = (h + pad_h) / patch_size; + int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { + const size_t total_tokens = ids.size(); + // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic. + wrap_dims.assign(axes_dim.size(), std::vector(total_tokens / bs, 0)); + size_t cursor = context_len; // ignore text tokens + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + cursor += img_tokens; + // For each reference image, store wrap sizes as well + for (ggml_tensor* ref : ref_latents) { + if (ref == nullptr) { + continue; + } + int ref_h = static_cast(ref->ne[1]); + int ref_w = static_cast(ref->ne[0]); + int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size; + int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size; + int ref_h_len = (ref_h + ref_pad_h) / patch_size; + int ref_w_len = (ref_w + ref_pad_w) / patch_size; + size_t ref_n_tokens = static_cast(ref_h_len) * static_cast(ref_w_len); + for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = ref_h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = ref_w_len; + } + } + cursor += ref_n_tokens; + } + } + } + const std::vector>* wraps_ptr = wrap_dims.empty() ? 
nullptr : &wrap_dims; + return embed_nd(ids, bs, theta, axes_dim, wraps_ptr); } __STATIC_INLINE__ std::vector> gen_vid_ids(int t, @@ -428,9 +547,34 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int pad_h = (patch_size - (h % patch_size)) % patch_size; + int pad_w = (patch_size - (w % patch_size)) % patch_size; + int h_len = (h + pad_h) / patch_size; + int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding) + size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + } + } + + const std::vector>* wraps_ptr = wrap_dims.empty() ? nullptr : &wrap_dims; + return embed_nd(ids, bs, theta, axes_dim, wraps_ptr); } __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1ef851247..d94134602 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -127,6 +127,9 @@ class StableDiffusionGGML { bool use_tiny_autoencoder = false; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; bool offload_params_to_cpu = false; + bool circular = false; + bool circular_x = false; + bool circular_y = false; bool stacked_id = false; bool is_using_v_parameterization = false; @@ -210,6 +213,10 @@ class StableDiffusionGGML { taesd_path = SAFE_STR(sd_ctx_params->taesd_path); use_tiny_autoencoder = taesd_path.size() > 0; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; + circular = sd_ctx_params->circular; + circular_x = sd_ctx_params->circular_x || circular; + circular_y = sd_ctx_params->circular_y || circular; + bool circular_any = circular || circular_x || circular_y; rng = get_rng(sd_ctx_params->rng_type); if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) { @@ -387,6 +394,10 @@ class StableDiffusionGGML { vae_decode_only = false; } + if (circular_any) { + LOG_INFO("Using circular padding for convolutions"); + } + bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; { @@ -402,6 +413,7 @@ class StableDiffusionGGML { diffusion_model = std::make_shared(backend, offload_params_to_cpu, tensor_storage_map); + diffusion_model->set_circular_axes(circular_x, circular_y); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -442,6 +454,7 @@ class StableDiffusionGGML { tensor_storage_map, version, sd_ctx_params->chroma_use_dit_mask); + diffusion_model->set_circular_axes(circular_x, circular_y); } else if (sd_version_is_flux2(version)) { bool is_chroma = false; cond_stage_model = std::make_shared(clip_backend, @@ -449,10 +462,11 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + 
offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + diffusion_model->set_circular_axes(circular_x, circular_y); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -465,12 +479,14 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + diffusion_model->set_circular_axes(circular_x, circular_y); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, tensor_storage_map, "model.high_noise_diffusion_model", version); + high_noise_diffusion_model->set_circular_axes(circular_x, circular_y); } if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B" || @@ -497,6 +513,7 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + diffusion_model->set_circular_axes(circular_x, circular_y); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -507,6 +524,7 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + diffusion_model->set_circular_axes(circular_x, circular_y); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (int i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -526,15 +544,17 @@ class StableDiffusionGGML { embbeding_map, version); } - diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version); - if (sd_ctx_params->diffusion_conv_direct) { - LOG_INFO("Using Conv2d direct in the diffusion model"); - std::dynamic_pointer_cast(diffusion_model)->unet.set_conv2d_direct_enabled(true); - } - } + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version); + if (sd_ctx_params->diffusion_conv_direct) { + LOG_INFO("Using Conv2d direct in the diffusion model"); + std::dynamic_pointer_cast(diffusion_model)->unet.set_conv2d_direct_enabled(true); + } + diffusion_model->set_circular_axes(circular_x, circular_y); + std::dynamic_pointer_cast(diffusion_model)->unet.set_circular_axes(circular_x, circular_y); + } if (sd_ctx_params->diffusion_flash_attn) { LOG_INFO("Using flash attention in the diffusion model"); @@ -570,6 +590,7 @@ class StableDiffusionGGML { "first_stage_model", vae_decode_only, version); + first_stage_model->set_circular_axes(circular_x, circular_y); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else if (version == VERSION_CHROMA_RADIANCE) { @@ -596,6 +617,7 @@ class StableDiffusionGGML { vae_conv_2d_scale); first_stage_model->set_conv2d_scale(vae_conv_2d_scale); } + first_stage_model->set_circular_axes(circular_x, circular_y); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } @@ -610,6 +632,7 @@ class StableDiffusionGGML { LOG_INFO("Using Conv2d direct in the tae model"); tae_first_stage->set_conv2d_direct_enabled(true); } + tae_first_stage->set_circular_axes(circular_x, circular_y); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); @@ -629,6 +652,7 @@ class StableDiffusionGGML { LOG_INFO("Using Conv2d direct in the control net"); control_net->set_conv2d_direct_enabled(true); } + control_net->set_circular_axes(circular_x, circular_y); } if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { @@ -2511,6 +2535,9 @@ void 
sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->keep_control_net_on_cpu = false; sd_ctx_params->keep_vae_on_cpu = false; sd_ctx_params->diffusion_flash_attn = false; + sd_ctx_params->circular = false; + sd_ctx_params->circular_x = false; + sd_ctx_params->circular_y = false; sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_t5_mask_pad = 1; @@ -2551,6 +2578,9 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "keep_control_net_on_cpu: %s\n" "keep_vae_on_cpu: %s\n" "diffusion_flash_attn: %s\n" + "circular: %s\n" + "circular_x: %s\n" + "circular_y: %s\n" "chroma_use_dit_mask: %s\n" "chroma_use_t5_mask: %s\n" "chroma_t5_mask_pad: %d\n", @@ -2581,6 +2611,9 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), BOOL_STR(sd_ctx_params->keep_vae_on_cpu), BOOL_STR(sd_ctx_params->diffusion_flash_attn), + BOOL_STR(sd_ctx_params->circular), + BOOL_STR(sd_ctx_params->circular_x), + BOOL_STR(sd_ctx_params->circular_y), BOOL_STR(sd_ctx_params->chroma_use_dit_mask), BOOL_STR(sd_ctx_params->chroma_use_t5_mask), sd_ctx_params->chroma_t5_mask_pad); diff --git a/stable-diffusion.h b/stable-diffusion.h index 2da70bd77..4ef3799b0 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -189,6 +189,9 @@ typedef struct { bool tae_preview_only; bool diffusion_conv_direct; bool vae_conv_direct; + bool circular; + bool circular_x; + bool circular_y; bool force_sdxl_vae_conv_scale; bool chroma_use_dit_mask; bool chroma_use_t5_mask; diff --git a/wan.hpp b/wan.hpp index 75333bfe1..8e5984622 100644 --- a/wan.hpp +++ b/wan.hpp @@ -75,7 +75,7 @@ namespace WAN { lp2 -= (int)cache_x->ne[2]; } - x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0); + x = sd_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), 0, 0, 0, @@ -206,9 +206,9 @@ namespace WAN { } else if (mode == "upsample3d") { x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "downsample2d") { - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + x = sd_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); } else if (mode == "downsample3d") { - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + x = sd_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); } x = resample_1->forward(ctx, x); x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w) @@ -1826,7 +1826,7 @@ namespace WAN { } } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; @@ -1835,8 +1835,7 @@ namespace WAN { int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size); int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size); int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size); - x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w] - + x = sd_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -1986,14
+1985,14 @@ namespace WAN { int64_t T = x->ne[2]; int64_t C = x->ne[3]; - x = pad_to_patch_size(ctx->ggml_ctx, x); + x = pad_to_patch_size(ctx, x); int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); if (time_dim_concat != nullptr) { - time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat); + time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); } diff --git a/z_image.hpp b/z_image.hpp index bc554f177..5a53fe675 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -324,14 +324,14 @@ namespace ZImage { blocks["final_layer"] = std::make_shared(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size; int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = sd_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -357,10 +357,10 @@ namespace ZImage { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -473,12 +473,12 @@ namespace ZImage { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t n_img_token = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -552,6 +552,8 @@ namespace ZImage { ref_latents, increase_ref_index, z_image_params.theta, + circular_y_enabled, + circular_x_enabled, z_image_params.axes_dim); int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len);
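
Note on the two mechanisms this patch combines. For convolutions, sd_pad/sd_pad_ext request circular (wrap-around) padding per axis instead of zero padding; for DiT models, Rope::rope() snaps each rotary frequency to the nearest harmonic of 2*pi/wrap so that positions p and p + wrap receive identical rotations. The sketch below is a minimal, self-contained C++ illustration of both ideas, independent of ggml and of this codebase; pad_row and wrapped_angle are hypothetical stand-ins, not functions from the patch.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// 1) Circular padding: out-of-range samples wrap around instead of being zero,
//    mirroring what sd_pad_ext does on an axis with circular padding enabled.
static std::vector<float> pad_row(const std::vector<float>& row, int left, int right, bool circular) {
    const int n = (int)row.size();
    std::vector<float> out;
    out.reserve(n + left + right);
    for (int i = -left; i < n + right; ++i) {
        if (i >= 0 && i < n) {
            out.push_back(row[i]);
        } else {
            out.push_back(circular ? row[((i % n) + n) % n] : 0.0f);  // wrap vs. zero
        }
    }
    return out;
}

// 2) RoPE wrapping: round each frequency to the closest periodic harmonic of
//    2*pi/wrap so the rotation at position p equals the rotation at p + wrap,
//    mirroring the rounding done in Rope::rope() when a wrap length is given.
static float wrapped_angle(float position, float omega, int wrap) {
    constexpr float TWO_PI = 6.28318530717958647692f;
    if (wrap <= 0) {
        return position * omega;               // non-circular axis: plain RoPE angle
    }
    float cycles  = omega * (float)wrap / TWO_PI;
    float rounded = std::round(cycles);        // closest periodic harmonic
    return position * TWO_PI * rounded / (float)wrap;
}

int main() {
    // Zero padding injects seam values (0); circular padding repeats the
    // opposite edge, so a stride-2 conv sees one continuous ring of pixels.
    std::vector<float> row = {1, 2, 3, 4};
    for (bool circular : {false, true}) {
        auto padded = pad_row(row, 1, 1, circular);
        std::printf("%s pad:", circular ? "circular" : "zero    ");
        for (float v : padded) std::printf(" %g", v);
        std::printf("\n");
    }

    // With wrap = 8 latent columns, positions 3 and 3 + 8 get identical
    // rotations, so the left and right image edges share the same phase.
    const int w_len = 8;
    const float omega = 1.0f;
    for (int p : {3, 3 + w_len}) {
        float a = wrapped_angle((float)p, omega, w_len);
        std::printf("p=%2d cos=% .6f sin=% .6f\n", p, std::cos(a), std::sin(a));
    }
    return 0;
}
```

The rounding step is what makes tiling exact: for very low frequencies the nearest harmonic can be zero, which degenerates to a constant rotation, but it still satisfies angle(p) == angle(p + wrap), so tokens near the seam carry no "edge versus center" signal.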