Skip to content

Commit

Permalink
add --skip-layers
Browse files Browse the repository at this point in the history
  • Loading branch information
zkh2016 committed Sep 18, 2024
1 parent 2f6f1a6 commit 1c445e4
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
9 changes: 9 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.skip_model = argv[i];
return true;
}
if (arg == "--skip-layers") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.skip_layers = std::stoi(argv[i]);
return true;
}
if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
return true;
Expand Down Expand Up @@ -1489,6 +1497,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
printf(" --skip-model SKIP_MODEL path to a skip model. see examples/llava/README.md\n");
printf(" --skip-layers SKIP_LAYERS the layers to skip. see examples/llava/README.md\n");
if (llama_supports_mlock()) {
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ struct gpt_params {
std::string mmproj = ""; // path to multimodal projector
std::string image = ""; // path to an image file
std::string skip_model = ""; //the skip model path
int skip_layers = 0; //the layers to skip
};

bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
Expand Down
1 change: 1 addition & 0 deletions examples/llava/minicpmv-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ static struct llava_context * llava_init_context(gpt_params * params) {
//llama_model * model2 = llama_load_model_from_file(params->model.c_str(), model_params);
//llama_model * model2 = llama_load_model_from_file("/Users/zkh/Downloads/last_16/ggml-model-Q4_0.gguf", model_params);
llama_model * model2 = llama_load_model_from_file(params->skip_model.c_str(), model_params);
llama_set_model_skip_layers(model2, params->skip_layers);
llama_set_model2(ctx_llama, model2);
}

Expand Down
15 changes: 13 additions & 2 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2165,6 +2165,7 @@ struct llama_model {
llama_split_mode split_mode;
int main_gpu;
int n_gpu_layers;
int skip_layers;

// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
Expand Down Expand Up @@ -6752,15 +6753,18 @@ struct llm_build_context {

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
const int skip_layers = model2 == nullptr ? 0 : model2->skip_layers;
int skip_idx = 0;

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

const llama_model *m = &model;
int local_il = il;
if(il >= 16 && model2 != nullptr){//TODO: && is_vit
if(il >= n_layer - skip_layers && model2 != nullptr){//TODO: && is_vit
m = model2;
// local_il = il - 16;
local_il = skip_idx;
skip_idx += 1;
}

// norm
Expand Down Expand Up @@ -14969,6 +14973,13 @@ int64_t llama_time_us(void) {
return ggml_time_us();
}

// Set the number of trailing transformer layers that should be taken from the
// secondary ("skip") model during graph construction — the build loop routes a
// layer to model2 when `il >= n_layer - skip_layers` (see llm_build_context).
//
//   model       - model to update; a nullptr model is ignored
//   skip_layers - number of trailing layers to skip; negative values clamp to 0
void llama_set_model_skip_layers(
        struct llama_model * model,
        int skip_layers
) {
    if (model == nullptr) {
        return;
    }
    // A negative count could never satisfy `il >= n_layer - skip_layers`,
    // but clamp anyway so code reading the field back sees a sane value.
    model->skip_layers = skip_layers < 0 ? 0 : skip_layers;
}

struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
Expand Down
4 changes: 4 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ extern "C" {
const char * path_model,
struct llama_model_params params);

LLAMA_API void llama_set_model_skip_layers(
struct llama_model * model,
int skip_layers);

LLAMA_API void llama_free_model(struct llama_model * model);

LLAMA_API struct llama_context * llama_new_context_with_model(
Expand Down

0 comments on commit 1c445e4

Please sign in to comment.