Skip to content

Commit 9d41787

Browse files
committed
Add multiple EOS support
Ticket 157352
1 parent 158f662 commit 9d41787

File tree

4 files changed

+39
-10
lines changed

4 files changed

+39
-10
lines changed

src/cpp/include/openvino/genai/generation_config.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ enum class StopCriteria { EARLY, HEURISTIC, NEVER };
3838
* @param eos_token_id token_id of <eos> (end of sentence)
3939
* @param min_new_tokens set 0 probability for eos_token_id for the first min_new_tokens generated tokens. Ignored for non continuous batching.
4040
*
41-
* @param stop_strings vector of strings that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
41+
* @param stop_strings A set of strings that will cause pipeline to stop generating further tokens.
4242
* @param include_stop_str_in_output if set to true, the stop string that caused generation to stop will be included in the generation output (default: false)
43-
* @param stop_token_ids vector of tokens that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
43+
* @param stop_token_ids A set of tokens that will cause pipeline to stop generating further tokens.
4444
* @param echo if set to true, output will include user prompt (default: false).
4545
* @param logprobs number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
4646
* Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
@@ -154,9 +154,9 @@ static constexpr ov::Property<size_t> max_new_tokens{"max_new_tokens"};
154154
static constexpr ov::Property<size_t> max_length{"max_length"};
155155
static constexpr ov::Property<bool> ignore_eos{"ignore_eos"};
156156
static constexpr ov::Property<size_t> min_new_tokens{"min_new_tokens"};
157-
static constexpr ov::Property<std::vector<std::string>> stop_strings{"stop_strings"};
157+
static constexpr ov::Property<std::set<std::string>> stop_strings{"stop_strings"};
158158
static constexpr ov::Property<bool> include_stop_str_in_output{"include_stop_str_in_output"};
159-
static constexpr ov::Property<std::vector<std::vector<int64_t>>> stop_token_ids{"stop_token_ids"};
159+
static constexpr ov::Property<std::set<int64_t>> stop_token_ids{"stop_token_ids"};
160160

161161
static constexpr ov::Property<size_t> num_beam_groups{"num_beam_groups"};
162162
static constexpr ov::Property<size_t> num_beams{"num_beams"};

src/python/py_generation_config.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ char generation_config_docstring[] = R"(
4141
ignore_eos: if set to true, then generation will not stop even if <eos> token is met.
4242
eos_token_id: token_id of <eos> (end of sentence)
4343
min_new_tokens: set 0 probability for eos_token_id for the first min_new_tokens generated tokens. Ignored for non continuous batching.
44-
stop_strings: list of strings that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
44+
stop_strings: a set of strings that will cause pipeline to stop generating further tokens.
4545
include_stop_str_in_output: if set to true, the stop string that caused generation to stop will be included in the generation output (default: false)
46-
stop_token_ids: list of tokens that will cause pipeline to stop generating further tokens. Ignored for non continuous batching.
46+
stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens.
4747
echo: if set to true, the model will echo the prompt in the output.
4848
logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
4949
Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
@@ -87,6 +87,9 @@ void init_generation_config(py::module_& m) {
8787
.def_readwrite("max_length", &GenerationConfig::max_length)
8888
.def_readwrite("ignore_eos", &GenerationConfig::ignore_eos)
8989
.def_readwrite("min_new_tokens", &GenerationConfig::min_new_tokens)
90+
.def_readwrite("stop_strings", &GenerationConfig::stop_strings)
91+
.def_readwrite("include_stop_str_in_output", &GenerationConfig::include_stop_str_in_output)
92+
.def_readwrite("stop_token_ids", &GenerationConfig::stop_token_ids)
9093
.def_readwrite("num_beam_groups", &GenerationConfig::num_beam_groups)
9194
.def_readwrite("num_beams", &GenerationConfig::num_beams)
9295
.def_readwrite("diversity_penalty", &GenerationConfig::diversity_penalty)

src/python/py_utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ ov::genai::OptionalGenerationConfig update_config_from_kwargs(const ov::genai::O
248248
res_config.stop_strings = py::cast<std::set<std::string>>(value);
249249
} else if (key == "include_stop_str_in_output") {
250250
res_config.include_stop_str_in_output = py::cast<bool>(value);
251-
} else if (key == "include_stop_str_in_output") {
251+
} else if (key == "stop_token_ids") {
252252
res_config.stop_token_ids = py::cast<std::set<int64_t>>(value);
253253
} else if (key == "max_length") {
254254
res_config.max_length = py::cast<int>(item.second);
@@ -311,11 +311,11 @@ bool generation_config_param_to_property(std::string key, py::object value, ov::
311311
} else if (key == "min_new_tokens") {
312312
map.insert(ov::genai::min_new_tokens(py::cast<int>(value)));
313313
} else if (key == "stop_strings") {
314-
map.insert(ov::genai::stop_strings(py::cast<std::vector<std::string>>(value)));
314+
map.insert(ov::genai::stop_strings(py::cast<std::set<std::string>>(value)));
315315
} else if (key == "include_stop_str_in_output") {
316316
map.insert(ov::genai::include_stop_str_in_output(py::cast<bool>(value)));
317-
} else if (key == "include_stop_str_in_output") {
318-
map.insert(ov::genai::stop_token_ids(py::cast<std::vector<std::vector<int64_t>>>(value)));
317+
} else if (key == "stop_token_ids") {
318+
map.insert(ov::genai::stop_token_ids(py::cast<std::set<int64_t>>(value)));
319319
} else if (key == "num_beam_groups") {
320320
map.insert(ov::genai::num_beam_groups(py::cast<int>(value)));
321321
} else if (key == "num_beams") {

tests/python_tests/test_generate_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,3 +848,29 @@ def test_batch_switch():
848848
pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
849849
pipe.generate(["a"], max_new_tokens=2)
850850
pipe.generate(["1", "2"], max_new_tokens=2)
851+
852+
853+
@pytest.mark.precommit
854+
@pytest.mark.nightly
855+
def test_stop_token_ids():
856+
pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
857+
res = pipe.generate(
858+
ov.Tensor([(1,)]),
859+
max_new_tokens=3,
860+
stop_token_ids={-1, 9935},
861+
include_stop_str_in_output=False
862+
)
863+
assert 2 == len(res.tokens[0])
864+
assert 9935 in res.tokens[0]
865+
866+
867+
@pytest.mark.precommit
868+
@pytest.mark.nightly
869+
def test_stop_strings():
870+
pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4]
871+
res = pipe.generate(
872+
"",
873+
max_new_tokens=5,
874+
stop_strings={"ignored", "боль"}
875+
)
876+
assert "боль" not in res

0 commit comments

Comments
 (0)