291 changes: 212 additions & 79 deletions conditioner.hpp

Large diffs are not rendered by default.

70 changes: 52 additions & 18 deletions flux.hpp
@@ -88,14 +88,19 @@ namespace Flux {

public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true)
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true,
bool diffusers_style = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
if (diffusers_style) {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
}

std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
@@ -258,15 +263,16 @@ namespace Flux {
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;

if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@@ -279,7 +285,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@@ -421,6 +427,7 @@ namespace Flux {
bool use_yak_mlp;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
bool diffusers_style = false;

public:
SingleStreamBlock(int64_t hidden_size,
@@ -432,7 +439,8 @@
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
@@ -444,8 +452,11 @@
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
}

blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
if (diffusers_style) {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
} else {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
}
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@@ -771,6 +782,7 @@ namespace Flux {
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
bool diffusers_style = false;
ChromaRadianceParams chroma_radiance_params;
};

@@ -817,7 +829,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

for (int i = 0; i < params.depth_single_blocks; i++) {
@@ -830,7 +843,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}

if (params.version == VERSION_CHROMA_RADIANCE) {
@@ -877,6 +891,11 @@ namespace Flux {
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
if (params.patch_size == 1) {
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
return x;
}
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
@@ -911,6 +930,12 @@ namespace Flux {
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;

if (params.patch_size == 1) {
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
return x;
}

GGML_ASSERT(C * p * p == x->ne[0]);

x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
@@ -1281,6 +1306,9 @@ namespace Flux {
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
}
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
@@ -1290,6 +1318,9 @@
// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) {
flux_params.diffusers_style = true;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
@@ -1319,6 +1350,10 @@ namespace Flux {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}

if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style attention blocks");
}

flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix);
}
@@ -1430,7 +1465,6 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}

pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
@@ -1441,10 +1475,10 @@
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);
flux_params.axes_dim,
sd_version_is_longcat(version));
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;
77 changes: 77 additions & 0 deletions ggml_extend.hpp
@@ -2173,6 +2173,83 @@ class Linear : public UnaryBlock {
}
};

class SplitLinear : public Linear {
Owner:
If this part has no effect, I think we can remove the related code. In fact, even if it does have some effect, additional work is required to handle it when LoRA uses QKV format, so I wouldn’t really recommend this approach.

Contributor Author:
This is used when loading Flux diffusion models with the diffusers naming convention, which stores the q, k, and v matrices as individual linear layers rather than one big fused linear layer. For some reason it is not quite working yet, not sure why.
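
For reference, here is a minimal standalone sketch (plain C++, no ggml) of the equivalence the split path relies on: running to_q, to_k, and to_v separately and concatenating the outputs should match a single fused qkv projection whose weight is the row-wise concatenation of the three weight matrices. The `linear` helper and the toy shapes below are illustrative assumptions only, not code from this patch.

```cpp
// Standalone illustration (no ggml): the outputs of three separate q/k/v
// projections, concatenated, equal one fused projection whose weight is
// the row-wise concatenation of the three weight matrices.
#include <cassert>
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>;  // [out_features][in_features]

static std::vector<float> linear(const Mat& w, const std::vector<float>& x) {
    std::vector<float> y(w.size(), 0.f);
    for (size_t o = 0; o < w.size(); o++)
        for (size_t i = 0; i < x.size(); i++)
            y[o] += w[o][i] * x[i];
    return y;
}

int main() {
    const int dim = 4;
    // three separate projections, as a diffusers-style checkpoint stores them
    Mat wq(dim, std::vector<float>(dim)), wk = wq, wv = wq;
    for (int o = 0; o < dim; o++) {
        for (int i = 0; i < dim; i++) {
            wq[o][i] = 0.01f * (o + i);
            wk[o][i] = 0.02f * (o - i);
            wv[o][i] = 0.03f * (o * i + 1);
        }
    }
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};

    // split path: run each projection, then concatenate the outputs
    std::vector<float> split_out = linear(wq, x);
    for (float v : linear(wk, x)) split_out.push_back(v);
    for (float v : linear(wv, x)) split_out.push_back(v);

    // fused path: concatenate the weights along the output dimension, run once
    Mat w_fused = wq;
    w_fused.insert(w_fused.end(), wk.begin(), wk.end());
    w_fused.insert(w_fused.end(), wv.begin(), wv.end());
    std::vector<float> fused_out = linear(w_fused, x);

    assert(fused_out.size() == split_out.size());
    for (size_t i = 0; i < fused_out.size(); i++) {
        assert(fused_out[i] == split_out[i]);  // same accumulation order, so exact match
    }
    printf("split and fused qkv projections agree on all %zu outputs\n", fused_out.size());
    return 0;
}
```

This only demonstrates the intended equivalence of the two layouts; it does not locate whatever is misbehaving in the ggml path.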

protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;

void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}

public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
if (ctx->weight_adapter) {
// concat all weights and biases together so it runs in one linear layer
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
}

return out;
}
};

__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
29 changes: 18 additions & 11 deletions model.cpp
@@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
TensorStorage token_embedding_weight, input_block_weight, context_embedding_weight;

bool has_multiple_encoders = false;
bool is_unet = false;
@@ -1041,7 +1041,7 @@

for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
@@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
}
if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") {
context_embedding_weight = tensor_storage;
}
}
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
@@ -1135,16 +1138,20 @@
}

if (is_flux) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
if (context_embedding_weight.ne[0] == 3584) {
return VERSION_LONGCAT;
} else {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}
return VERSION_FLUX;
}

if (token_embedding_weight.ne[0] == 768) {
11 changes: 10 additions & 1 deletion model.h
@@ -46,6 +46,7 @@ enum SDVersion {
VERSION_FLUX2,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_LONGCAT,
VERSION_COUNT,
};

@@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false;
}

static inline bool sd_version_is_longcat(SDVersion version) {
if (version == VERSION_LONGCAT) {
return true;
}
return false;
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
@@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
sd_version_is_z_image(version) ||
sd_version_is_longcat(version)) {
return true;
}
return false;
8 changes: 7 additions & 1 deletion name_conversion.cpp
@@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) {
static std::unordered_map<std::string, std::string> flux_name_map;

if (flux_name_map.empty()) {
// --- time_embed (longcat) ---
flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias";

// --- time_text_embed ---
flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
@@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_unet_to_original_sdxl(name);
} else if (sd_version_is_sd3(version)) {
name = convert_diffusers_dit_to_original_sd3(name);
} else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) {
} else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
name = convert_diffusers_dit_to_original_flux(name);
} else if (sd_version_is_z_image(version)) {
name = convert_diffusers_dit_to_original_lumina2(name);