Skip to content

Commit

Permalink
do not keep ends if --do-not-add-special-tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Sep 26, 2024
1 parent 33f2b3c commit 4ed7e83
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 59 deletions.
66 changes: 9 additions & 57 deletions src/cpp/src/make_combine_segments_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,15 @@

#include "make_combine_segments_stateful.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/op/assign.hpp"

namespace {
// Identifier given to the state variable that holds the add_special_tokens flag.
const std::string ADD_SPECIAL_TOKENS_VAR_ID{"add_special_tokens"};
}  // namespace

using namespace ov;
using namespace ov::op;

/*
* For newly converted models, input(1) of CombineSegments already stores the default mode;
* in that case we only need to insert a ReadValue between DefaultMode and Select.
*
* +--------------+ +--------+ +------------------+
* | DefaultMode | | ends | | const value = 0 |
* +--------------+ +--------+ +------------------+
* \ | /
* \ | /
* v v v
* +--------------+
* | Select |
* +--------------+
* |
* v
* +-------------------------+
* | CombineSegments |
* +-------------------------+
*
* If the IR is old, the default mode is to add special tokens (true), and we insert
* the whole new subgraph with Select.
*
* +------------+
* | ends |
* +------------+
* |
* v
* +-------------------------+
* | CombineSegments |
* +-------------------------+
*/
bool MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {

std::shared_ptr<ov::Node> combine_seg_node;
for (auto node: model->get_ordered_ops()) {
Expand All @@ -61,31 +24,20 @@ bool MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>
}

std::shared_ptr<v0::Constant> input_1_const = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(1));
std::shared_ptr<v1::Select> input_1_select = std::dynamic_pointer_cast<v1::Select>(combine_seg_node->get_input_node_shared_ptr(1));
if (!input_1_const && !input_1_select) {
if (!input_1_const) {
return false;
}

op::util::VariableInfo var_info{ov::Shape{}, ov::element::boolean, ADD_SPECIAL_TOKENS_VAR_ID};
auto variable = std::make_shared<op::util::Variable>(var_info);

std::shared_ptr<v6::ReadValue> read_value;
if (input_1_select) {
// Select already exists, need just to insert to input(0) ReadValue
// instead of default mode Const.
read_value = std::make_shared<v6::ReadValue>(input_1_select->input_value(0), variable);
input_1_select->input(0).replace_source_output(read_value->output(0));
} else {
// If there is end then default mode is add_special_tokens.
bool add_special_tokens = true;
auto default_mode_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{}, std::vector{add_special_tokens});
read_value = std::make_shared<v6::ReadValue>(default_mode_const, variable);
auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
auto select_node = std::make_shared<v1::Select>(read_value, input_1_const, zero_constant);
combine_seg_node->input(1).replace_source_output(select_node->output(0));
}
// Default mode is add_special_tokens.
auto default_mode_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{}, std::vector{true});
auto read_value = std::make_shared<v6::ReadValue>(default_mode_const, variable);
auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
auto select_node = std::make_shared<v1::Select>(read_value, input_1_const, zero_constant);
combine_seg_node->input(1).replace_source_output(select_node->output(0));

// Write the value back to the variable through an Assign node; it is registered as a model sink below.
auto assign = std::make_shared<v6::Assign>(read_value, variable);

model->add_sinks({assign});
Expand Down
8 changes: 8 additions & 0 deletions src/cpp/src/make_combine_segments_stateful.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "openvino/op/constant.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace genai {

/**
* @brief This pass modifies tokenizer ov::Model so that special tokens adding will be
* enabled or disabled depending on the stateful value.
Expand Down Expand Up @@ -34,3 +37,8 @@ class MakeCombineSegmentsSatateful : public ov::pass::ModelPass {
OPENVINO_RTTI("MakeCombineSegmentsSatateful", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};

/// Identifier of the stateful variable that carries the add_special_tokens flag.
/// Declared `inline` so that all translation units including this header share
/// one definition instead of each getting its own internal-linkage copy.
inline const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";

} // namespace genai
} // namespace ov
2 changes: 1 addition & 1 deletion src/cpp/src/tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Tokenizer::TokenizerImpl {
*add_special_tensor.data<bool>() = add_special_tokens;

for (auto& state: infer_request_guard.get().query_state()) {
if (state.get_name().find("ADD_SPECIAL_TOKENS") == std::string::npos) {
if (state.get_name().find(ov::genai::ADD_SPECIAL_TOKENS_VAR_ID) == std::string::npos) {
// It's not add_special_tokens flag state.
continue;
}
Expand Down
1 change: 0 additions & 1 deletion tests/python_tests/test_chat_generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,4 @@ def test_add_special_tokens(add_special_tokens, prompt):
# Calling encode with add_special_tokens will set state flag.
res_genai = genai_tokenzier.encode(prompt, add_special_tokens).input_ids.data
res_hf = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=add_special_tokens)["input_ids"]
# breakpoint()
assert np.all(res_genai == res_hf)

0 comments on commit 4ed7e83

Please sign in to comment.