From 8541e996298aec4d2442412502310835e6aee130 Mon Sep 17 00:00:00 2001 From: caitianchi Date: Mon, 27 May 2024 04:27:54 +0800 Subject: [PATCH] better pos_embed in clip --- examples/minicpmv/clip.cpp | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/examples/minicpmv/clip.cpp b/examples/minicpmv/clip.cpp index d7ec7e5b1d6ba..6a726fa505c88 100644 --- a/examples/minicpmv/clip.cpp +++ b/examples/minicpmv/clip.cpp @@ -593,7 +593,7 @@ std::vector>> get_2d_sincos_pos_embed_from_grid(i return emb; } -struct ggml_tensor * get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size, struct ggml_context * ctx, struct ggml_tensor * pos_embed) { +std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { int grid_h_size = image_size.first; int grid_w_size = image_size.second; @@ -632,13 +632,7 @@ struct ggml_tensor * get_2d_sincos_pos_embed(int embed_dim, const std::pair(pos_embed->data); - for(int i=0;i load_image_size = {448, 448}) { @@ -708,8 +702,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; - struct ggml_tensor * pos_embed = get_2d_sincos_pos_embed(4096, std::make_pair(pos_w, pos_h), ctx0, model.mm_model_pos_embed_k); - pos_embed = ggml_view_3d(ctx0, pos_embed, 4096, pos_w * pos_h, 1, pos_embed->nb[1], pos_embed->nb[2], 0); + struct ggml_tensor * pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + ggml_set_name(pos_embed, "pos_embed"); + ggml_set_input(pos_embed); // // pre-layernorm // { @@ -2068,6 +2063,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima free(positions_data); } + { + struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); + int pos_w = image_size_width/patch_size; + int pos_h = image_size_height/patch_size; + int embed_dim = 4096; + auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); + + float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); + for(int i=0;i