2 changes: 1 addition & 1 deletion clip.hpp
@@ -664,7 +664,7 @@ class CLIPVisionEmbeddings : public GGMLBlock {
// concat(patch_embedding, class_embedding) + position_embedding
struct ggml_tensor* patch_embedding;
int64_t N = pixel_values->ne[3];
patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // patch_size, image_size // patch_size]
patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size, 0, 0, 1, 1, false, ctx->circular_x_enabled, ctx->circular_y_enabled); // [N, embed_dim, image_size // patch_size, image_size // patch_size]
patch_embedding = ggml_reshape_3d(ctx->ggml_ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx->ggml_ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
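For readers tracing the reshape chain above, a worked shape example follows; the hyperparameters (a ViT-L/14-style tower) are illustrative assumptions, not values read from this file:

```cpp
#include <cstdio>

// Illustrative shape walk-through for the patch-embedding chain above.
// Assumed hyperparameters (ViT-L/14-style): image_size = 224,
// patch_size = 14, embed_dim = 1024, batch N = 1.
int main() {
    const int N = 1, image_size = 224, patch_size = 14, embed_dim = 1024;
    const int side        = image_size / patch_size; // 16 patches per side
    const int num_patches = side * side;             // 256 patch tokens
    std::printf("conv2d     -> [N=%d, C=%d, %dx%d]\n", N, embed_dim, side, side);
    std::printf("reshape_3d -> [N=%d, C=%d, T=%d]\n", N, embed_dim, num_patches);
    std::printf("permute    -> [N=%d, T=%d, C=%d]\n", N, num_patches, embed_dim);
    std::printf("reshape_4d -> [N=%d, T=%d, C=%d, 1]\n", N, num_patches, embed_dim);
    return 0;
}
```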
4 changes: 3 additions & 1 deletion common.hpp
@@ -28,7 +28,9 @@ class DownSampleBlock : public GGMLBlock {
if (vae_downsample) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);

x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
// For VAE downsampling we manually pad by 1 before the stride-2 conv.
// Honor the global circular padding flags here to avoid seams in seamless mode.
x = sd_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
x = conv->forward(ctx, x);
} else {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
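For intuition about what the circular flags change here: zero padding inserts a hard border before the stride-2 conv, while circular padding copies pixels from the opposite edge so the output tiles without a seam. Below is a minimal single-channel sketch of that behavior; pad2d is a hypothetical stand-in for sd_pad, assuming ggml_pad's append-at-the-end padding convention:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for sd_pad's circular mode (illustration only).
// Appends pad_x columns on the right and pad_y rows at the bottom. In
// circular mode the new pixels are copied from the opposite edge instead
// of being zero-filled, so a stride-2 conv sees no seam at the border.
std::vector<float> pad2d(const std::vector<float>& src, int w, int h,
                         int pad_x, int pad_y, bool circular_x, bool circular_y) {
    const int out_w = w + pad_x, out_h = h + pad_y;
    std::vector<float> dst((size_t)out_w * out_h, 0.0f); // default: zero pad
    for (int y = 0; y < out_h; ++y) {
        if (y >= h && !circular_y) continue; // leave bottom pad rows at zero
        const int sy = y % h;                // wraps only when y >= h
        for (int x = 0; x < out_w; ++x) {
            if (x >= w && !circular_x) continue; // leave right pad at zero
            dst[(size_t)y * out_w + x] = src[(size_t)sy * w + x % w];
        }
    }
    return dst;
}
```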
25 changes: 25 additions & 0 deletions diffusion_model.hpp
@@ -39,6 +39,7 @@ struct DiffusionModel {
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};

struct UNetModel : public DiffusionModel {
@@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel {
unet.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
unet.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel {
mmdit.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
mmdit.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel {
flux.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
flux.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel {
wan.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
wan.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel {
qwen_image.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
qwen_image.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel {
z_image.set_flash_attention_enabled(enabled);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
z_image.set_circular_axes(circular_x, circular_y);
}

bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
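Every backbone gets the same one-line override, so callers can stay on the DiffusionModel interface. A minimal sketch of the fan-out, assuming a hypothetical helper name; the OR-combination mirrors the flag wiring in examples/cli/main.cpp below:

```cpp
// Hypothetical helper (name assumed, not part of this PR): combine the
// CLI flags and forward them to whichever backbone happens to be loaded.
void apply_circular_axes(DiffusionModel& model,
                         bool circular, bool circular_x, bool circular_y) {
    // --circular implies both axes; --circularx / --circulary pick one axis
    model.set_circular_axes(circular || circular_x, circular || circular_y);
}
```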
21 changes: 21 additions & 0 deletions examples/cli/main.cpp
@@ -518,6 +518,9 @@ struct SDContextParams {
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool circular = false;
bool circular_x = false;
bool circular_y = false;

bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
@@ -671,6 +674,18 @@ struct SDContextParams {
"--vae-conv-direct",
"use ggml_conv2d_direct in the vae model",
true, &vae_conv_direct},
{"",
"--circular",
"enable circular padding for convolutions",
true, &circular},
{"",
"--circularx",
"enable circular RoPE wrapping on x-axis (width) only",
true, &circular_x},
{"",
"--circulary",
"enable circular RoPE wrapping on y-axis (height) only",
true, &circular_y},
{"",
"--chroma-disable-dit-mask",
"disable dit mask for chroma",
@@ -934,6 +949,9 @@ struct SDContextParams {
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
<< " circular: " << (circular ? "true" : "false") << ",\n"
<< " circular_x: " << (circular_x ? "true" : "false") << ",\n"
<< " circular_y: " << (circular_y ? "true" : "false") << ",\n"
<< " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
<< " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
<< " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
@@ -995,6 +1013,9 @@ struct SDContextParams {
taesd_preview,
diffusion_conv_direct,
vae_conv_direct,
circular,
circular || circular_x,
circular || circular_y,
force_sdxl_vae_conv_scale,
chroma_use_dit_mask,
chroma_use_t5_mask,
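A hedged usage sketch for the new flags (binary name, model file, and prompts are placeholders, not taken from this PR):

```sh
# Wrap both axes to get a tileable texture
./sd --circular -m model.safetensors -p "seamless cobblestone texture"

# Wrap only the x-axis (width), e.g. for a 360-degree panorama
./sd --circularx -m model.safetensors -p "equirectangular mountain panorama"
```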
28 changes: 15 additions & 13 deletions flux.hpp
@@ -858,14 +858,14 @@ namespace Flux {
}
}

struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];

int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
x = sd_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}

@@ -891,11 +891,11 @@ namespace Flux {
return x;
}

struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}

@@ -1065,7 +1065,7 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;

auto img = pad_to_patch_size(ctx->ggml_ctx, x);
auto img = pad_to_patch_size(ctx, x);
auto orig_img = img;

auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
@@ -1128,16 +1128,16 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;

auto img = process_img(ctx->ggml_ctx, x);
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];

if (params.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != nullptr);
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);

img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
@@ -1146,21 +1146,21 @@ namespace Flux {
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));

masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
control = process_img(ctx->ggml_ctx, control);
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);
control = process_img(ctx, control);

img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr);

auto control = process_img(ctx->ggml_ctx, c_concat);
auto control = process_img(ctx, c_concat);
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
}

if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx->ggml_ctx, ref);
ref = process_img(ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
@@ -1441,6 +1441,8 @@ namespace Flux {
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
circular_y_enabled,
circular_x_enabled,
flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
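The two new arguments thread the circular flags into RoPE position-embedding generation; the hunk does not show how they are consumed. One common way to make RoPE wrap on an axis is sketched below as an assumption, not as this repo's gen_pe: snap each base frequency to a whole number of cycles over the axis, so the embedding at pos = n_pos matches pos = 0.

```cpp
#include <cmath>
#include <vector>

// Assumed mechanism (not necessarily the repo's): quantize each RoPE base
// frequency so its rotation completes an integer number of cycles across
// the n_pos tokens of a wrapped axis. The sin/cos features then repeat
// with period n_pos, so attention sees no discontinuity at the wrap seam.
std::vector<float> circular_rope_freqs(int half_dim, int n_pos, float theta) {
    const float TWO_PI = 6.2831853f;
    std::vector<float> freqs(half_dim);
    for (int i = 0; i < half_dim; ++i) {
        // standard RoPE base frequency for pair i: theta^(-i / half_dim)
        float f = std::pow(theta, -(float)i / (float)half_dim);
        // snap to the nearest whole number of cycles over the axis (>= 1)
        float cycles = std::fmax(1.0f, std::round(f * n_pos / TWO_PI));
        freqs[i] = TWO_PI * cycles / (float)n_pos;
    }
    return freqs;
}
```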
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 252 files