
Commit 6262e24

Qwen2-VL support added
1 parent 335250b commit 6262e24
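
In short: this commit adds a bringup config for a GGUF-quantized Qwen2-VL-2B-Instruct, makes Llama::reset() and Llama::eval() virtual so that Llava can hook vision state into both, and implements Qwen2-VL's M-RoPE (multi-section rotary position) handling for image and text tokens in llava_ros.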

File tree

llama_bringup/models/Qwen2-VL.yaml
llama_ros/include/llama_ros/llama.hpp
llama_ros/include/llava_ros/llava.hpp
llama_ros/src/llama_ros/llama.cpp
llama_ros/src/llava_ros/llava.cpp

5 files changed: +161 −50 lines

llama_bringup/models/Qwen2-VL.yaml

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+use_llava: True
+
+n_ctx: 8192
+n_batch: 512
+n_gpu_layers: 15
+n_threads: -1
+n_predict: 8192
+
+model_repo: "bartowski/Qwen2-VL-2B-Instruct-GGUF"
+model_filename: "Qwen2-VL-2B-Instruct-IQ2_M.gguf"
+
+mmproj_repo: "bartowski/Qwen2-VL-2B-Instruct-GGUF"
+mmproj_filename: "mmproj-Qwen2-VL-2B-Instruct-f16.gguf"
+
+image_prefix: "<|vision_start|>"
+image_suffix: "<|vision_end|>"
+
+system_prompt_type: "ChatML"
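
This config pulls the IQ2_M quantization of Qwen2-VL-2B-Instruct and its f16 multimodal projector (mmproj) from bartowski's Hugging Face repo, brackets image embeddings with Qwen2-VL's <|vision_start|>/<|vision_end|> tokens, and selects ChatML prompting. Presumably it is consumed like the other YAMLs under llama_bringup/models, e.g. via the package's existing launch helpers such as create_llama_launch_from_yaml from llama_bringup.utils (assuming that helper is present in this version of the package).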

llama_ros/include/llama_ros/llama.hpp

Lines changed: 4 additions & 3 deletions
@@ -169,14 +169,15 @@ using GenerateResponseCallback = std::function<void(struct CompletionOutput)>;
 class Llama {

 public:
-  Llama(const struct common_params &params, std::string system_prompt = "");
+  Llama(const struct common_params &params, std::string system_prompt = "",
+        bool initial_reset = true);
   virtual ~Llama();

   std::vector<llama_token> tokenize(const std::string &text, bool add_bos,
                                     bool special = false);
   std::string detokenize(const std::vector<llama_token> &tokens);

-  void reset();
+  virtual void reset();
   void cancel();

   std::string format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,

@@ -266,7 +267,7 @@ class Llama {
   virtual bool eval_prompt();
   bool eval_prompt(std::vector<llama_token> prompt_tokens);
   bool eval_token(llama_token token);
-  bool eval(std::vector<llama_token> tokens);
+  virtual bool eval(std::vector<llama_token> tokens);
   bool eval(struct llama_batch batch);

   std::vector<struct TokenProb> get_probs();
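
Two load-bearing changes here: reset() and eval() become virtual so the Llava subclass can intercept both, and the constructor gains an initial_reset flag. The flag matters because C++ resolves virtual calls made during a base-class constructor to the base implementation, so a reset() inside Llama's constructor could never reach Llava::reset(). A minimal standalone sketch of that pitfall, with hypothetical class names not taken from the repo:

  #include <iostream>

  // Hypothetical Base/Derived pair showing why Llama's constructor takes an
  // initial_reset flag: a virtual call made inside a base-class constructor
  // dispatches to the base version, since the derived part is not built yet.
  struct Base {
    Base() { reset(); } // always runs Base::reset(), never Derived::reset()
    virtual ~Base() = default;
    virtual void reset() { std::cout << "Base::reset\n"; }
  };

  struct Derived : Base {
    Derived() { reset(); } // safe: Derived is fully constructed by now
    void reset() override { std::cout << "Derived::reset\n"; }
  };

  int main() {
    Derived d; // prints "Base::reset" then "Derived::reset"
    return 0;
  }

Hence Llava passes initial_reset = false to the base constructor and calls this->reset() itself once the CLIP context is loaded.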

llama_ros/include/llava_ros/llava.hpp

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,7 @@ class Llava : public llama_ros::Llama {
         const struct LlavaParams &llava_params, std::string system_prompt = "");
   ~Llava();

+  void reset() override;
   bool load_image(std::string base64_str);
   struct llava_image_embed *
   base64_image_to_embed(const std::string &base64_str);

@@ -59,6 +60,7 @@ class Llava : public llama_ros::Llama {
                 bool add_sfx) override;
   bool eval_image(struct llava_image_embed *image_embed);
   bool eval_prompt();
+  bool eval(std::vector<llama_token> tokens) override;

   struct llava_image_embed *image_embed;
   struct clip_ctx *ctx_clip;

@@ -67,6 +69,7 @@ class Llava : public llama_ros::Llama {
 private:
   void free_image();
   int image_pose;
+  int st_pos_id;
 };

 } // namespace llava_ros
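
The new st_pos_id member carries the running M-RoPE position id across Llava::eval_image and the overridden Llava::eval, and the reset() override clears it together with the cached image before delegating to Llama::reset(); see the position-layout sketch after the llava.cpp diff below.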

llama_ros/src/llama_ros/llama.cpp

Lines changed: 45 additions & 41 deletions
@@ -34,7 +34,8 @@

 using namespace llama_ros;

-Llama::Llama(const struct common_params &params, std::string system_prompt)
+Llama::Llama(const struct common_params &params, std::string system_prompt,
+             bool initial_reset)
     : params(params), system_prompt(system_prompt) {

   print_build_info();

@@ -100,7 +101,9 @@ Llama::Llama(const struct common_params &params, std::string system_prompt)
   }

   // set initial values
-  this->reset();
+  if (initial_reset) {
+    this->reset();
+  }

   // show info
   LLAMA_LOG_INFO("llama.cpp: build = %d, commit = %s", LLAMA_BUILD_NUMBER,

@@ -148,6 +151,38 @@ Llama::~Llama() {
   this->threadpool_batch = nullptr;
 }

+/*
+*****************************
+*           RESET           *
+*          CANCEL           *
+*****************************
+*/
+void Llama::reset() {
+
+  llama_kv_cache_clear(this->ctx);
+
+  if (this->sampler != nullptr) {
+    common_sampler_reset(this->sampler);
+  }
+
+  this->canceled = false;
+  this->n_past = 0;
+  this->n_consumed = 0;
+  this->ga_i = 0;
+
+  this->prompt_tokens.clear();
+
+  // load system prompt
+  if (!this->eval_system_prompt()) {
+    LLAMA_LOG_ERROR("Failed to eval system prompt");
+  }
+
+  // number of tokens to keep when resetting context
+  if (this->params.n_keep < 0) {
+    this->params.n_keep = (int)this->prompt_tokens.size();
+  }
+}
+
 /*
 *****************************
 *         METADATA          *

@@ -339,38 +374,6 @@ struct Metadata Llama::get_metadata() {
   return metadata;
 }

-/*
-*****************************
-*           RESET           *
-*          CANCEL           *
-*****************************
-*/
-void Llama::reset() {
-
-  llama_kv_cache_clear(this->ctx);
-
-  if (this->sampler != nullptr) {
-    common_sampler_reset(this->sampler);
-  }
-
-  this->canceled = false;
-  this->n_past = 0;
-  this->n_consumed = 0;
-  this->ga_i = 0;
-
-  this->prompt_tokens.clear();
-
-  // load system prompt
-  if (!this->eval_system_prompt()) {
-    LLAMA_LOG_ERROR("Failed to eval system prompt");
-  }
-
-  // number of tokens to keep when resetting context
-  if (this->params.n_keep < 0) {
-    this->params.n_keep = (int)this->prompt_tokens.size();
-  }
-}
-
 /*
 *****************************
 *         TOKENIZE          *

@@ -911,6 +914,7 @@ bool Llama::eval_prompt() { return this->eval_prompt(this->prompt_tokens); }
 bool Llama::eval_prompt(std::vector<llama_token> prompt_tokens) {

   std::vector<llama_token> batch;
+  batch.reserve(this->params.n_batch);

   while (((int)prompt_tokens.size() > this->n_consumed)) {

@@ -941,13 +945,13 @@ bool Llama::eval(std::vector<llama_token> tokens) {

   // create batch
   struct llama_batch batch = {
-      int32_t(tokens.size()),
-      tokens.data(),
-      nullptr,
-      nullptr,
-      nullptr,
-      nullptr,
-      nullptr,
+      int32_t(tokens.size()), // n_tokens
+      tokens.data(),          // tokens
+      nullptr,                // embd
+      nullptr,                // pos
+      nullptr,                // n_seq_id
+      nullptr,                // seq_id
+      nullptr,                // logits
   };

   return this->eval(batch);
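
Most of the +45/−41 churn in this file is a verbatim move: the RESET/CANCEL block relocates from below get_metadata() to just after the destructor, its body unchanged. The substantive edits are the initial_reset guard in the constructor, the reserve() that pre-allocates the prompt batch to n_batch, and the field-name comments on the llama_batch aggregate initializer.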

llama_ros/src/llava_ros/llava.cpp

Lines changed: 91 additions & 6 deletions
@@ -34,19 +34,32 @@ using namespace llava_ros;

 Llava::Llava(const struct common_params &params,
              const struct LlavaParams &llava_params, std::string system_prompt)
-    : llama_ros::Llama(params, system_prompt), llava_params(llava_params) {
+    : llama_ros::Llama(params, system_prompt, false),
+      llava_params(llava_params), image_pose(0), st_pos_id(-1) {

   // load clip model
   const char *clip_path = this->params.mmproj.c_str();
   this->ctx_clip = clip_model_load(clip_path, 1);
   this->image_embed = nullptr;
+
+  // set initial values
+  this->reset();
 }

 Llava::~Llava() {
+  this->image_pose = 0;
+  this->st_pos_id = -1;
   clip_free(this->ctx_clip);
   this->free_image();
 }

+void Llava::reset() {
+  this->image_pose = 0;
+  this->st_pos_id = -1;
+  this->free_image();
+  Llama::reset();
+}
+
 /*
 *****************************
 *        LOAD IMAGE         *

@@ -150,13 +163,40 @@ bool Llava::eval_image(struct llava_image_embed *image_embed) {
   int n_embd = this->get_n_embd();
   bool succ = true;

-  for (int i = 0; i < image_embed->n_image_pos; i += this->params.n_batch) {
+  // for qwen2-vl
+  auto img_tokens = image_embed->n_image_pos;
+
+  std::vector<llama_pos> mrope_pos;
+  mrope_pos.resize(img_tokens * 4);
+
+  std::vector<llama_pos> batch_mrope_pos;
+  batch_mrope_pos.resize(img_tokens * 4);

-    int n_eval = image_embed->n_image_pos - i;
+  // fill mrope if qwen2-vl
+  if (clip_is_qwen2vl(this->ctx_clip)) {
+    auto image_size = clip_get_load_image_size(this->ctx_clip);
+    const int patch_size = 14 * 2;

-    if (n_eval > this->params.n_batch) {
-      n_eval = this->params.n_batch;
+    const int ph =
+        image_size->height / patch_size + (image_size->height % patch_size > 0);
+    const int pw =
+        image_size->width / patch_size + (image_size->width % patch_size > 0);
+
+    for (int y = 0; y < ph; y++) {
+      for (int x = 0; x < pw; x++) {
+        int i = y * pw + x;
+        mrope_pos[i] = this->st_pos_id;
+        mrope_pos[i + img_tokens] = this->st_pos_id + y;
+        mrope_pos[i + img_tokens * 2] = this->st_pos_id + x;
+        mrope_pos[i + img_tokens * 3] = 0;
+      }
     }
+    this->st_pos_id += std::max(pw, ph);
+  }
+
+  for (int i = 0; i < image_embed->n_image_pos; i += this->params.n_batch) {
+
+    int n_eval = std::min(this->params.n_batch, image_embed->n_image_pos - i);

     struct llama_batch batch = {
         int32_t(n_eval), // n_tokens

@@ -168,7 +208,19 @@ bool Llava::eval_image(struct llava_image_embed *image_embed) {
         nullptr // logits
     };

-    if (!this->eval(batch)) {
+    if (clip_is_qwen2vl(this->ctx_clip)) {
+      std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
+      memcpy(batch_mrope_pos.data(), &mrope_pos[i], n_eval * sizeof(llama_pos));
+      memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + i],
+             n_eval * sizeof(llama_pos));
+      memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + i],
+             n_eval * sizeof(llama_pos));
+      memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + i],
+             n_eval * sizeof(llama_pos));
+      batch.pos = batch_mrope_pos.data();
+    }
+
+    if (!Llama::eval(batch)) {
       LLAMA_LOG_ERROR("Failed in image eval");
       succ = false;
       break;

@@ -212,3 +264,36 @@ bool Llava::eval_prompt() {

   return true;
 }
+
+bool Llava::eval(std::vector<llama_token> tokens) {
+
+  std::vector<llama_pos> pos;
+
+  // create batch
+  struct llama_batch batch = {
+      int32_t(tokens.size()), // n_tokens
+      tokens.data(),          // tokens
+      nullptr,                // embd
+      nullptr,                // pos
+      nullptr,                // n_seq_id
+      nullptr,                // seq_id
+      nullptr,                // logits
+  };
+
+  if (clip_is_qwen2vl(this->ctx_clip)) {
+    pos.resize(batch.n_tokens * 4);
+    std::fill(pos.begin(), pos.end(), 0);
+    for (int j = 0; j < batch.n_tokens * 3; j++) {
+      pos[j] = this->st_pos_id + (j % batch.n_tokens);
+    }
+    batch.pos = pos.data();
+  }
+
+  if (!Llama::eval(batch)) {
+    return false;
+  }
+
+  this->st_pos_id += batch.n_tokens;
+
+  return true;
+}
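
The heart of the change is Qwen2-VL's M-RoPE position scheme. Instead of one position id per token, each token carries four: temporal, height, width, and an unused fourth section, stored as four consecutive blocks of n_tokens entries in the batch's pos buffer. For image tokens, the height/width ids encode the token's row and column in the patch grid (14-px patches merged 2x2, hence patch_size = 28, with ph/pw as ceiling divisions: a 448x336 image gives pw = 16, ph = 12, so 192 tokens); for text tokens, all three used sections carry the same sequential id. Note that after an image, st_pos_id advances by max(pw, ph) rather than by the token count. A self-contained sketch of just this layout, with no llama.cpp dependencies (function names are illustrative only):

  #include <cstdio>
  #include <vector>

  // Standalone sketch of the M-RoPE layout used by Llava::eval_image and
  // Llava::eval. pos holds 4 * n_tokens entries laid out as
  // [temporal | height | width | unused], each section n_tokens long.
  std::vector<int> image_positions(int ph, int pw, int st_pos_id) {
    const int n = ph * pw; // one token per merged patch
    std::vector<int> pos(n * 4, 0);
    for (int y = 0; y < ph; y++) {
      for (int x = 0; x < pw; x++) {
        const int i = y * pw + x;
        pos[i] = st_pos_id;             // temporal: constant within one image
        pos[i + n] = st_pos_id + y;     // height: patch row
        pos[i + n * 2] = st_pos_id + x; // width: patch column
        // fourth section stays 0
      }
    }
    return pos;
  }

  std::vector<int> text_positions(int n_tokens, int st_pos_id) {
    std::vector<int> pos(n_tokens * 4, 0);
    for (int j = 0; j < n_tokens * 3; j++) { // same id in all three sections
      pos[j] = st_pos_id + (j % n_tokens);
    }
    return pos;
  }

  int main() {
    // 448x336 image: pw = 16, ph = 12 -> 192 image tokens; afterwards
    // st_pos_id would advance by max(pw, ph) = 16, not by 192.
    auto img = image_positions(12, 16, 5);
    std::printf("image token 0: t=%d h=%d w=%d\n", img[0], img[192], img[384]);

    // four text tokens starting at position id 21
    auto txt = text_positions(4, 21);
    std::printf("text token 3: t=%d h=%d w=%d\n", txt[3], txt[7], txt[11]);
    return 0;
  }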
