From 4ed7e83eabf576612d3f779b33ccc5c5b092ecea Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 26 Sep 2024 15:06:49 +0200 Subject: [PATCH] do not keep ends if --do-not-add-special-tokens --- .../src/make_combine_segments_stateful.cpp | 66 +++---------------- .../src/make_combine_segments_stateful.hpp | 8 +++ src/cpp/src/tokenizer.cpp | 2 +- tests/python_tests/test_chat_generate_api.py | 1 - 4 files changed, 18 insertions(+), 59 deletions(-) diff --git a/src/cpp/src/make_combine_segments_stateful.cpp b/src/cpp/src/make_combine_segments_stateful.cpp index 3c9360b98a..2285c172dc 100644 --- a/src/cpp/src/make_combine_segments_stateful.cpp +++ b/src/cpp/src/make_combine_segments_stateful.cpp @@ -3,52 +3,15 @@ #include "make_combine_segments_stateful.hpp" #include "openvino/op/constant.hpp" -#include "openvino/op/multiply.hpp" #include "openvino/op/select.hpp" #include "openvino/op/read_value.hpp" -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/op/assign.hpp" -namespace { - const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens"; -} using namespace ov; using namespace ov::op; -/* -* For the newly converted models input(1) to CombineSegments stores default mode, -* in that case we just need to insert ReadValue in between DefaultMode -> Select. -* -* +--------------+ +--------+ +------------------+ -* | DefaultMode | | ends | | const value = 0 | -* +--------------+ +--------+ +------------------+ -* \ | / -* \ | / -* v v v -* +--------------+ -* | Select | -* +--------------+ -* | -* v -* +-------------------------+ -* | CombineSegments | -* +-------------------------+ -* -* If IR is old, then default mode to add special tokens is true, and we insert -* the whole new subgraph with Select. -* -* +------------+ -* | ends | -* +------------+ -* | -* v -* +-------------------------+ -* | CombineSegments | -* +-------------------------+ -*/ -bool MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr& model) { +bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr& model) { std::shared_ptr combine_seg_node; for (auto node: model->get_ordered_ops()) { @@ -61,31 +24,20 @@ bool MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr } std::shared_ptr input_1_const = std::dynamic_pointer_cast(combine_seg_node->get_input_node_shared_ptr(1)); - std::shared_ptr input_1_select = std::dynamic_pointer_cast(combine_seg_node->get_input_node_shared_ptr(1)); - if (!input_1_const && !input_1_select) { + if (!input_1_const) { return false; } op::util::VariableInfo var_info{ov::Shape{}, ov::element::boolean, ADD_SPECIAL_TOKENS_VAR_ID}; auto variable = std::make_shared(var_info); - std::shared_ptr read_value; - if (input_1_select) { - // Select already exists, need just to insert to input(0) ReadValue - // instead of default mode Const. - read_value = std::make_shared(input_1_select->input_value(0), variable); - input_1_select->input(0).replace_source_output(read_value->output(0)); - } else { - // If there is end then default mode is add_special_tokens. - bool add_special_tokens = true; - auto default_mode_const = std::make_shared(ov::element::boolean, ov::Shape{}, std::vector{add_special_tokens}); - read_value = std::make_shared(default_mode_const, variable); - auto zero_constant = std::make_shared(ov::element::i32, ov::Shape{}, std::vector{0}); - auto select_node = std::make_shared(read_value, input_1_const, zero_constant); - combine_seg_node->input(1).replace_source_output(select_node->output(0)); - } + // Default mode is add_special_tokens. + auto default_mode_const = std::make_shared(ov::element::boolean, ov::Shape{}, std::vector{true}); + auto read_value = std::make_shared(default_mode_const, variable); + auto zero_constant = std::make_shared(ov::element::i32, ov::Shape{}, std::vector{0}); + auto select_node = std::make_shared(read_value, input_1_const, zero_constant); + combine_seg_node->input(1).replace_source_output(select_node->output(0)); - // here need to store. auto assign = std::make_shared(read_value, variable); model->add_sinks({assign}); diff --git a/src/cpp/src/make_combine_segments_stateful.hpp b/src/cpp/src/make_combine_segments_stateful.hpp index dbced771b1..f81f8f08d6 100644 --- a/src/cpp/src/make_combine_segments_stateful.hpp +++ b/src/cpp/src/make_combine_segments_stateful.hpp @@ -4,6 +4,9 @@ #include "openvino/op/constant.hpp" #include "openvino/pass/pass.hpp" +namespace ov { +namespace genai { + /** * @brief This pass modifies tokenizer ov::Model so that special tokens adding will be * enabled or diabled depending on stateful value. @@ -34,3 +37,8 @@ class MakeCombineSegmentsSatateful : public ov::pass::ModelPass { OPENVINO_RTTI("MakeCombineSegmentsSatateful", "0"); bool run_on_model(const std::shared_ptr& model) override; }; + +const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens"; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index de77ee035b..8563ab26d2 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -97,7 +97,7 @@ class Tokenizer::TokenizerImpl { *add_special_tensor.data() = add_special_tokens; for (auto& state: infer_request_guard.get().query_state()) { - if (state.get_name().find("ADD_SPECIAL_TOKENS") == std::string::npos) { + if (state.get_name().find(ov::genai::ADD_SPECIAL_TOKENS_VAR_ID) == std::string::npos) { // It's not add_special_tokens flag state. continue; } diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py index 57b72c85c0..b68de6372d 100644 --- a/tests/python_tests/test_chat_generate_api.py +++ b/tests/python_tests/test_chat_generate_api.py @@ -221,5 +221,4 @@ def test_add_special_tokens(add_special_tokens, prompt): # Calling encode with add_special_tokens will set state flag. res_genai = genai_tokenzier.encode(prompt, add_special_tokens).input_ids.data res_hf = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=add_special_tokens)["input_ids"] - # breakpoint() assert np.all(res_genai == res_hf)