Skip to content

Commit c7b1e9e

Browse files
committed
Fixes
1 parent 6419fd0 commit c7b1e9e

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

text_generation/causal_lm/cpp/chat_model_lm.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include "group_beam_searcher.hpp"
77
#include "openvino/openvino.hpp"
8+
#include <iostream>
9+
#include <fstream>
810

911
namespace {
1012
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
@@ -39,15 +41,27 @@ int main(int argc, char* argv[]) try {
3941
// Compile models
4042
ov::Core core;
4143
core.add_extension(OPENVINO_TOKENIZERS_PATH); // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
44+
auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");
4245
// tokenizer and detokenizer work on CPU only
4346
ov::InferRequest tokenizer =
44-
core.compile_model(std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
47+
core.compile_model(tokenizer_model, "CPU").create_infer_request();
4548
ov::InferRequest detokenizer =
4649
core.compile_model(std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
4750
// The model can be compiled for GPU as well
4851
ov::InferRequest lm =
4952
core.compile_model(std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();
5053

54+
// Look up the EOS token ID in the runtime info (rt_info) of the tokenizer model read above; fail if it is absent
55+
auto rt_info = tokenizer_model->get_rt_info(); //Get the runtime info for the model
56+
int64_t SPECIAL_EOS_TOKEN;
57+
58+
if (rt_info.count("eos_token_id") > 0) { //check if the runtime information has a valid EOS token ID
59+
SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();
60+
61+
} else {
62+
throw std::runtime_error("EOS token ID not found in model's runtime information.");
63+
}
64+
5165
int64_t total_positions = 0;
5266
int32_t global_beam_idx = 0;
5367
std::string prompt;
@@ -84,12 +98,13 @@ int main(int argc, char* argv[]) try {
8498
lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {1}, &global_beam_idx});
8599

86100
const int64_t* prompt_data = input_ids.data<const int64_t>();
87-
Parameters parameters{std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()}};
101+
Parameters parameters{{{prompt_data, prompt_data + input_ids.get_size()}}, SPECIAL_EOS_TOKEN};
88102
GroupBeamSearcher group_beam_searcher{parameters};
89103
std::vector<int64_t> next_tokens;
90104
std::vector<int32_t> next_beams;
105+
lm.infer();
106+
91107
for (size_t length_count = 0; length_count < parameters.max_new_tokens; ++length_count) {
92-
lm.infer();
93108
std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits"));
94109
if (next_tokens.empty()) {
95110
break;
@@ -105,11 +120,13 @@ int main(int argc, char* argv[]) try {
105120
std::fill_n(attention_mask.data<int64_t>(), ov::shape_size(mask_shape), 1);
106121
lm.get_tensor("position_ids").set_shape({batch_size, 1});
107122
std::fill_n(lm.get_tensor("position_ids").data<int64_t>(), batch_size, total_positions++);
123+
lm.infer();
108124
}
109125

110126
Beam answer;
111-
float highest_score = std::numeric_limits<float>().min();
112-
for (const std::vector<Beam>& group : finalize(std::move(group_beam_searcher))) {
127+
float highest_score = std::numeric_limits<float>().lowest();
128+
auto all_groups = finalize(std::move(group_beam_searcher));
129+
for (const std::vector<Beam>& group : all_groups[0]) {
113130
for (const Beam& beam : group) {
114131
if (beam.score > highest_score) {
115132
highest_score = beam.score;
@@ -119,7 +136,7 @@ int main(int argc, char* argv[]) try {
119136
}
120137

121138
auto answer_str = detokenize(detokenizer, answer.tokens);
122-
answer_str = answer_str.substr(0, answer_str.find("<eos>"));
139+
//answer_str = answer_str.substr(0, answer_str.find("<eos>"));
123140
std::cout << "Answer: " << answer_str << "\n_______\n";
124141
global_beam_idx = answer.global_beam_idx;
125142

0 commit comments

Comments
 (0)