From ca7cdcc2652844381181ccdd3e1e8a5aca2aa0a8 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Tue, 30 Apr 2024 07:55:24 -0700 Subject: [PATCH] [Tokenizer] Support ByteLevel BPE in tokenizer token table (#2248) --- cpp/serve/engine.cc | 20 ++++- cpp/tokenizers.cc | 105 +++++++++++++++++++++---- cpp/tokenizers.h | 21 ++++- python/mlc_llm/interface/gen_config.py | 74 ++++++++++++++++- 4 files changed, 198 insertions(+), 22 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 9703dda472..755af998cd 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -56,9 +56,7 @@ class EngineImpl : public Engine { } this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); - this->token_table_ = tokenizer_->TokenTable(); - this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); @@ -100,6 +98,21 @@ class EngineImpl : public Engine { engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } + // Step 3. Initialize tokenizer and grammar + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + std::string token_table_postproc_method; + if (model_configs[0].count("token_table_postproc_method") == 0) { + // Backward compatibility: use "byte_fallback" by default + token_table_postproc_method = "byte_fallback"; + } else { + token_table_postproc_method = + model_configs[0].at("token_table_postproc_method").get<std::string>(); + } + this->token_table_ = + Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); + this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + + // Step 4. Initialize engine actions that represent state transitions. 
int max_num_tokens = engine_config->max_num_sequence; DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { @@ -113,7 +126,6 @@ class EngineImpl : public Engine { this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast<int>(this->models_.size()), trace_recorder); - // Step 3. Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index ef866f3bfc..6fe9217520 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -9,10 +9,12 @@ #include <tokenizers_cpp.h> #include <tvm/runtime/logging.h> +#include <array> #include <filesystem> #include <fstream> #include <string> +#include "./support/encoding.h" #include "./support/load_bytes_from_file.h" namespace mlc { @@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) { LOG(FATAL) << "Cannot find any tokenizer under: " << _path; } -/*! - * \brief Post-process a raw token (which may be a raw byte or contain lower - * one eights block) to the actual token. - * We do this in order to conform with the tokenizers' setup. - */ -inline std::string PostProcessToken(std::string token) { - // 1. The token represents a byte. +/*! \brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */ +inline std::string ByteFallbackDecoder(const std::string& token) { if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') { int byte = 0; for (int i = 0; i < 2; ++i) { @@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) { ICHECK(byte >= 0 && byte < 256); return std::string(/*n=*/1, static_cast<char>(byte)); } + return token; +} - // 2. The token contains "\u2581" which means space. 
- static const std::string& lower_one_eighth_block = "\u2581"; - size_t pos = token.find(lower_one_eighth_block); - while (pos != std::string::npos) { - token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" "); - pos = token.find(lower_one_eighth_block); +/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */ +inline std::string SpaceReplacerDecoder(const std::string& token) { + // \u2581 is the unicode for "lower one eighth block" + // UTF8 encoding for \u2581 is 0xE2 0x96 0x81 + std::string result; + for (size_t i = 0; i < token.size(); ++i) { + if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) && + token[i + 2] == char(0x81)) { + result += ' '; + i += 2; + } else { + result += token[i]; + } + } + return result; +} + +/*! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding + * process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + */ +inline std::string ByteLevelDecoder(const std::string& token) { + // clang-format off + // The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode. 
+ static const std::array<int, 324> unicode_to_byte_map = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1, + 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, + 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, + 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, + 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, + 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173 + }; + // clang-format on + + auto unicode_codepoints = ParseUTF8(token.c_str()); + std::string decoded; + + for (auto unicode_codepoint : unicode_codepoints) { + ICHECK(unicode_codepoint >= 0 && + unicode_codepoint < static_cast<int>(unicode_to_byte_map.size())); + int byte = unicode_to_byte_map[unicode_codepoint]; + if (byte == -1) { + // If there is no mapping, add the codepoint itself to the result string + // Some tokenizer like Phi-2 have raw tokens like \t\t + decoded 
+= static_cast<char>(unicode_codepoint); + } else { + decoded += static_cast<char>(byte); + } + } + return decoded; +} + +/*! + * \brief Post-process a raw token to the actual token with the given post-processing method. + */ +inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) { + if (postproc_method == "byte_fallback") { + return SpaceReplacerDecoder(ByteFallbackDecoder(token)); + } else if (postproc_method == "byte_level") { + return ByteLevelDecoder(token); + } else { + LOG(FATAL) << "Unknown post-processing method: " << postproc_method; } - return token; } const std::vector<std::string>& TokenizerObj::TokenTable() { @@ -127,12 +191,21 @@ const std::vector<std::string>& TokenizerObj::TokenTable() { int vocab_size = tokenizer->GetVocabSize(); token_table_.reserve(vocab_size); for (int32_t token_id = 0; token_id < vocab_size; ++token_id) { - std::string token = tokenizer->IdToToken(token_id); - token_table_.push_back(PostProcessToken(token)); + token_table_.push_back(tokenizer->IdToToken(token_id)); } return token_table_; } +std::vector<std::string> Tokenizer::PostProcessTokenTable( + const std::vector<std::string>& token_table, const std::string& postproc_method) { + std::vector<std::string> postprocessed_token_table; + postprocessed_token_table.reserve(token_table.size()); + for (const std::string& token : token_table) { + postprocessed_token_table.push_back(PostProcessToken(token, postproc_method)); + } + return postprocessed_token_table; +} + TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) { return Tokenizer::FromPath(path); }); diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index 16d9ba456b..36fc0c23db 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -30,7 +30,7 @@ class TokenizerObj : public Object { std::vector<int32_t> Encode(const std::string& text) const; /*! \brief Decode token ids into text. */ std::string Decode(const std::vector<int32_t>& token_ids) const; - /*! \brief Return the token table of the tokenizer. */ + /*! 
\brief Return the token table of the tokenizer. Special tokens are included. */ const std::vector<std::string>& TokenTable(); /*! @@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef { /*! \brief Create a tokenizer from a directory path on disk. */ MLC_LLM_DLL static Tokenizer FromPath(const String& path); + /*! + * \brief Convert raw tokens provided by the tokenizer to their original string to simplify + * later processing. E.g. For LLaMA-2, convert "▁of" to " of". + * + * \param token_table The raw token table. + * \param postproc_method The postprocessing method to use. Now we only support "byte_fallback" + * and "byte_level", which refers to the type of the decoder of the tokenizer. + * - "byte_fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used + * by LLaMA-2, Mixtral-7b, etc. This method: 1) transform tokens like <0x1B> to hex char + * byte 1B. (known as the byte-fallback method); 2) transform \\u2581 to space. + * - "byte_level": Use the decoding method in the byte-level BPE tokenizer. This is used by + * LLaMA-3, GPT-2, Phi-2, etc. This method inverses the bytes-to-unicode transformation in + * the encoding process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + * \returns The postprocessed token table containing the original strings. 
+ */ + static std::vector<std::string> PostProcessTokenTable(const std::vector<std::string>& token_table, + const std::string& postproc_method); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj); private: diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 8e617fc3d2..13f0e1215f 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -5,7 +5,7 @@ import re import shutil from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model @@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes pad_token_id: int = None bos_token_id: int = None eos_token_id: int = None + # Tokenizer configuration tokenizer_files: List[str] = dataclasses.field(default_factory=list) + # The method to post-process the token table. See + # cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = None # Version control version: str = VERSION @@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None: msgpack.pack(idx2token, f) +def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]: + """Detect the token table postprocessing method from tokenizer.json that is found under + output_path. If not detected, use ByteFallback as default. + + Check the decoder field of the tokenizer. If it uses ByteFallback decoder, return + "byte_fallback". If it uses ByteLevel decoder, return "byte_level". Otherwise, use + ByteFallback as default. + + See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable. 
+ """ + output_tokenizer_path = output_path / "tokenizer.json" + if not output_tokenizer_path.exists(): + logger.warning( + "Tokenizer token table postprocessing method is not detected as tokenizer.json " + "is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default" + ) + return "byte_fallback" + + with output_tokenizer_path.open("r", encoding="utf-8") as in_file: + tokenizer_json = json.load(in_file) + + # Find all decoders in tokenizer.json + decoders = [] + + if "decoder" not in tokenizer_json: + logger.warning( + "Decoder field is not found in tokenizer.json, use ByteFallback (the same as " + "LLaMA/LLaMA2) as the token table postprocessing method by default" + ) + return "byte_fallback" + + decoders_json = tokenizer_json["decoder"] + assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json" + if decoders_json["type"] == "Sequence": + assert "decoders" in decoders_json + decoders = decoders_json["decoders"] + else: + decoders = [decoders_json] + + is_byte_level = False + is_byte_fallback = False + + for decoder in decoders: + if decoder["type"] == "ByteLevel": + is_byte_level = True + if decoder["type"] == "ByteFallback": + is_byte_fallback = True + assert not ( + is_byte_level and is_byte_fallback + ), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback" + + if is_byte_level: + return "byte_level" + if is_byte_fallback: + return "byte_fallback" + + logger.warning( + "Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use " + "ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method " + "by default" + ) + return "byte_fallback" + + def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements config: Path, model: Model, @@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b except Exception: # pylint: disable=broad-exception-caught logger.exception("%s with the exception 
below. Skipping", FAILED) + # 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not + # detected, use "byte_fallback" as default. + mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output) + # Step 4. Load system default value mlc_chat_config.apply_defaults() # Step 5. Dump the configuration file to output directory