diff --git a/3rdparty/tvm b/3rdparty/tvm
index d694451c58..ced07e8878 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit d694451c580a931116a2c93571f21f7d791c7fa0
+Subproject commit ced07e88781c0d6416e276d9cd084bb46aaf3da5
diff --git a/android/library/prepare_libs.sh b/android/library/prepare_libs.sh
index a06e9f067d..c089927d09 100755
--- a/android/library/prepare_libs.sh
+++ b/android/library/prepare_libs.sh
@@ -27,6 +27,7 @@ cmake .. \
   -DMLC_LLM_INSTALL_STATIC_LIB=ON \
   -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \
   -DUSE_OPENCL=ON \
+  -DUSE_OPENCL_ENABLE_HOST_PTR=ON \
   -DUSE_CUSTOM_LOGGING=ON \
 
 cmake --build . --target tvm4j_runtime_packed --config release
diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/config.cc
similarity index 85%
rename from cpp/json_ffi/conv_template.cc
rename to cpp/json_ffi/config.cc
index 02e0b3bdbd..8f5c0e1062 100644
--- a/cpp/json_ffi/conv_template.cc
+++ b/cpp/json_ffi/config.cc
@@ -1,4 +1,6 @@
-#include "conv_template.h"
+#include "config.h"
+
+#include
 
 #include "../metadata/json_parser.h"
 
@@ -8,6 +10,29 @@ namespace json_ffi {
 
 using namespace mlc::llm;
 
+/****************** Model-defined generation config ******************/
+
+TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);
+
+ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p,
+                                                           double frequency_penalty,
+                                                           double presence_penalty) {
+  ObjectPtr<ModelDefinedGenerationConfigNode> n = make_object<ModelDefinedGenerationConfigNode>();
+  n->temperature = temperature;
+  n->top_p = top_p;
+  n->frequency_penalty = frequency_penalty;
+  n->presence_penalty = presence_penalty;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig")
+    .set_body_typed([](double temperature, double top_p, double frequency_penalty,
+                       double presence_penalty) {
+      return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty);
+    });
+
+/****************** Conversation template ******************/
+
 std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
     {MessagePlaceholders::SYSTEM, "{system_message}"},
     {MessagePlaceholders::USER, "{user_message}"},
@@ -308,6 +333,25 @@ std::optional<Conversation> Conversation::FromJSON(const std::string& json_str,
   }
   return Conversation::FromJSON(json_obj.value(), err);
 }
+
+/****************** JSON FFI engine config ******************/
+
+TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode);
+
+JSONFFIEngineConfig::JSONFFIEngineConfig(
+    String conv_template, Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
+  ObjectPtr<JSONFFIEngineConfigNode> n = make_object<JSONFFIEngineConfigNode>();
+  n->conv_template = conv_template;
+  n->model_generation_cfgs = model_generation_cfgs;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
+    .set_body_typed([](String conv_template,
+                       Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
+      return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs));
+    });
+
 }  // namespace json_ffi
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/config.h
similarity index 67%
rename from cpp/json_ffi/conv_template.h
rename to cpp/json_ffi/config.h
index d3a1d1de2f..fe5e4e42e2 100644
--- a/cpp/json_ffi/conv_template.h
+++ b/cpp/json_ffi/config.h
@@ -1,5 +1,9 @@
-#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
-#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
+#ifndef MLC_LLM_JSON_FFI_CONFIG_H
+#define MLC_LLM_JSON_FFI_CONFIG_H
+
+#include
+#include
+#include
 
 #include
 #include
@@ -18,6 +22,32 @@ namespace mlc {
 namespace llm {
 namespace json_ffi {
 
+/****************** Model-defined generation config ******************/
+
+class ModelDefinedGenerationConfigNode : public Object {
+ public:
+  double temperature;
+  double top_p;
+  double frequency_penalty;
+  double presence_penalty;
+
+  static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object);
+};
+
+class ModelDefinedGenerationConfig : public ObjectRef {
+ public:
+  explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty,
+                                        double presence_penalty);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef,
+                                ModelDefinedGenerationConfigNode);
+};
+
+/****************** Conversation template ******************/
+
 enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };
 
 MessagePlaceholders messagePlaceholderFromString(const std::string& role);
 
@@ -114,6 +144,27 @@ struct Conversation {
   static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
 };
 
+/****************** JSON FFI engine config ******************/
+
+class JSONFFIEngineConfigNode : public Object {
+ public:
+  String conv_template;
+  Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
+
+  static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object);
+};
+
+class JSONFFIEngineConfig : public ObjectRef {
+ public:
+  explicit JSONFFIEngineConfig(String conv_template,
+                               Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
+};
+
 }  // namespace json_ffi
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc
index 0e21735e2f..d5fc53b8fa 100644
--- a/cpp/json_ffi/json_ffi_engine.cc
+++ b/cpp/json_ffi/json_ffi_engine.cc
@@ -83,8 +83,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
   Array<Data> inputs = inputs_obj.value();
 
   // generation_cfg
-  Optional<GenerationConfig> generation_cfg =
-      GenerationConfig::FromJSON(request_json_str, &err_, conv_template);
+  Optional<GenerationConfig> generation_cfg = GenerationConfig::Create(
+      request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]);
   if (!generation_cfg.defined()) {
     return false;
   }
@@ -122,14 +122,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
   TVM_MODULE_VTABLE_END();
 
-  void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config,
-                            Optional<PackedFunc> request_stream_callback,
+  void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
+                            Device device, Optional<PackedFunc> request_stream_callback,
                             Optional<EventTraceRecorder> trace_recorder) {
-    std::optional<Conversation> conv_template = Conversation::FromJSON(conv_template_str, &err_);
+    std::optional<Conversation> conv_template =
+        Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
     if (!conv_template.has_value()) {
       LOG(FATAL) << "Invalid conversation template JSON: " << err_;
     }
     this->conv_template_ = conv_template.value();
+    this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs;
 
     // Todo(mlc-team): decouple InitBackgroundEngine into two functions
     // by removing `engine_config` from arguments, after properly handling
@@ -148,7 +150,7 @@
class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine(std::move(request_stream_callback), + this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), std::move(trace_recorder)); this->engine_->Reload(std::move(engine_config)); } diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 2c7501c337..d57384abb5 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,7 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" -#include "conv_template.h" +#include "config.h" #include "openai_api_protocol.h" namespace mlc { @@ -49,6 +49,7 @@ class JSONFFIEngine { PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request Conversation conv_template_; + Map model_generation_cfgs; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index bed225d3d0..429050da3c 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,7 +13,7 @@ #include #include -#include "conv_template.h" +#include "config.h" #include "picojson.h" namespace mlc { @@ -90,8 +90,8 @@ class ChatCompletionRequest { public: std::vector messages; std::string model; - double frequency_penalty = 0.0; - double presence_penalty = 0.0; + std::optional frequency_penalty = std::nullopt; + std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; std::optional> logit_bias = std::nullopt; @@ -100,8 +100,8 @@ class ChatCompletionRequest { std::optional seed = std::nullopt; std::optional> stop = std::nullopt; bool stream = false; - double temperature = 1.0; - double top_p = 1.0; + std::optional temperature = std::nullopt; + std::optional top_p = std::nullopt; std::optional> tools = std::nullopt; std::optional tool_choice = std::nullopt; std::optional user = std::nullopt; diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index f6ff10e1ac..99a284fc42 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end()) { + return default_value; + } + + if (it->second.is()) { + return default_value; + } + + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 7379bad7ed..3bb809ad67 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -161,15 +161,26 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } -Optional GenerationConfig::FromJSON(const std::string& json_str, std::string* err, - const Conversation& conv_template) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !json_obj.has_value()) { +Optional GenerationConfig::Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& 
model_defined_gen_config) { + std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !optional_json_obj.has_value()) { return NullOpt; } + picojson::object& json_obj = optional_json_obj.value(); ObjectPtr n = make_object(); - // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. + n->temperature = + json::LookupOrDefault(json_obj, "temperature", model_defined_gen_config->temperature); + n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); + n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", + model_defined_gen_config->frequency_penalty); + n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", + model_defined_gen_config->presence_penalty); + n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); + n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); + n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); // Copy stop str from conversation template to generation config for (auto& stop_str : conv_template.stop_str) { @@ -179,9 +190,6 @@ Optional GenerationConfig::FromJSON(const std::string& json_st n->stop_token_ids.push_back(stop_token_id); } - if (!err->empty()) { - return NullOpt; - } GenerationConfig gen_config; gen_config.data_ = std::move(n); return gen_config; @@ -236,37 +244,85 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, - int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, SpeculativeMode speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, + SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); n->additional_models = std::move(additional_models); n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->device = device; n->kv_cache_page_size = kv_cache_page_size; n->max_num_sequence = max_num_sequence; n->max_total_sequence_length = max_total_sequence_length; n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; + n->max_history_size = max_history_size; + n->kv_state_kind = kv_state_kind; n->spec_draft_length = spec_draft_length; n->speculative_mode = speculative_mode; data_ = std::move(n); } +EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { + picojson::value config_json; + std::string err = picojson::parse(config_json, json_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + + // Get json fields. 
+ picojson::object config = config_json.get(); + String model = json::Lookup(config, "model"); + String model_lib_path = json::Lookup(config, "model_lib_path"); + std::vector additional_models; + std::vector additional_model_lib_paths; + int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); + int max_num_sequence = json::Lookup(config, "max_num_sequence"); + int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); + int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); + int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); + int max_history_size = json::Lookup(config, "max_history_size"); + KVStateKind kv_state_kind = + static_cast(json::Lookup(config, "kv_state_kind")); + SpeculativeMode speculative_mode = + static_cast(json::Lookup(config, "speculative_mode")); + int spec_draft_length = json::Lookup(config, "spec_draft_length"); + + picojson::array additional_models_arr = + json::Lookup(config, "additional_models"); + picojson::array additional_model_lib_paths_arr = + json::Lookup(config, "additional_model_lib_paths"); + CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) + << "The number of additional model lib paths does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_lib_paths.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_lib_paths.push_back( + json::Lookup(additional_model_lib_paths_arr, i)); + } + + return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, + additional_model_lib_paths, kv_cache_page_size, max_num_sequence, + max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +} + TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, int max_history_size, + int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), device, kv_cache_page_size, + std::move(additional_model_lib_paths), kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), + SpeculativeMode(speculative_mode), spec_draft_length); }); } // namespace serve diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 41ddb3c6e4..fd76dd49f0 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -11,7 +11,7 @@ #include -#include "../json_ffi/conv_template.h" +#include "../json_ffi/config.h" namespace mlc { namespace llm { @@ -63,11 +63,13 @@ class GenerationConfig : public ObjectRef { explicit GenerationConfig(String config_json_str); /*! 
- * \brief Parse the generation config from the given JSON string. - * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + * \brief Create a generation config from a ChatCompletionRequest. + * If the request does not contain a generation config, the model-defined + * generation config will be used. */ - static Optional FromJSON(const std::string& json_str, std::string* err, - const Conversation& conv_template); + static Optional Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; @@ -84,6 +86,12 @@ enum class SpeculativeMode : int { kEagle = 2, }; +/*! \brief The kind of cache. */ +enum KVStateKind { + kAttention = 0, + kRNNState = 1, +}; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -98,11 +106,6 @@ class EngineConfigNode : public Object { /*! \brief The path to the additional models' libraries. */ Array additional_model_lib_paths; - /*************** Device ***************/ - - /*! \brief The device where the models run. */ - DLDevice device; - /*************** KV cache config and engine capacities ***************/ /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ @@ -121,6 +124,10 @@ class EngineConfigNode : public Object { int max_single_sequence_length; /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; + /*! \brief The maximum history size for RNN state. KV cache does not need this. */ + int max_history_size; + /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ + KVStateKind kv_state_kind; /*************** Speculative decoding ***************/ @@ -140,11 +147,15 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON string. */ + static EngineConfig FromJSONString(const std::string& json_str); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; diff --git a/cpp/serve/draft_token_workspace_manager.cc b/cpp/serve/draft_token_workspace_manager.cc new file mode 100644 index 0000000000..185b899e14 --- /dev/null +++ b/cpp/serve/draft_token_workspace_manager.cc @@ -0,0 +1,54 @@ +/*! 
+ * Copyright (c) 2024 by Contributors
+ * \file serve/draft_token_workspace_manager.cc
+ */
+
+#include "draft_token_workspace_manager.h"
+
+#include "model.h"
+
+namespace mlc {
+namespace llm {
+namespace serve {
+
+DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,
+                                                             int hidden_size,
+                                                             DLDataType hidden_states_dtype,
+                                                             DLDevice device,
+                                                             const FunctionTable& ft)
+    : max_num_tokens_(max_num_tokens),
+      vocab_size_(vocab_size),
+      hidden_size_(hidden_size),
+      hidden_states_dtype_(hidden_states_dtype),
+      device_(device),
+      ft_(ft) {
+  free_slots_.resize(max_num_tokens);
+  std::iota(free_slots_.begin(), free_slots_.end(), 0);
+}
+
+void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
+  ICHECK_LE(num_slots, free_slots_.size());
+  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
+  std::vector<int> allocated(free_slots_.begin(), free_slots_.begin() + num_slots);
+  free_slots_.resize(free_slots_.size() - num_slots);
+}
+
+void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
+  std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_));
+}
+
+void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
+                                                   bool require_hidden_states) {
+  workspace->draft_probs =
+      NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
+  workspace->draft_probs_storage =
+      NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
+  if (require_hidden_states) {
+    workspace->draft_hidden_states_storage =
+        NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
+  }
+}
+
+}  // namespace serve
+}  // namespace llm
+}  // namespace mlc
diff --git a/cpp/serve/draft_token_workspace_manager.h b/cpp/serve/draft_token_workspace_manager.h
new file mode 100644
index 0000000000..1a1dfbc8e0
--- /dev/null
+++ b/cpp/serve/draft_token_workspace_manager.h
@@ -0,0 +1,95 @@
+/*!
+ * Copyright (c) 2024 by Contributors
+ * \file serve/draft_token_workspace_manager.h
+ */
+
+#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
+#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
+#include
+
+#include
+#include
+#include
+
+#include "data.h"
+#include "function_table.h"
+namespace mlc {
+namespace llm {
+namespace serve {
+
+using tvm::Device;
+using namespace tvm::runtime;
+
+struct ModelWorkspace;
+
+/*!
+ * \brief Managing the workspace for draft token generation.
+ *
+ * The workspace is used to store the associated states for each draft token, including the
+ * probability distribution of the draft token, the hidden states, etc. The workspace manager
+ * maintains a pool of slots for the draft tokens to store the states.
+ */
+class DraftTokenWorkspaceManagerObj : public Object {
+ public:
+  /*!
+   * \brief Constructor
+   * \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.
+   * \param vocab_size The size of the vocabulary.
+   * \param hidden_size The size of the hidden states.
+   * \param hidden_states_dtype The data type of the hidden states.
+   * \param device The device running the model.
+   * \param ft The function table.
+   */
+  DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,
+                                DLDataType hidden_states_dtype, DLDevice device,
+                                const FunctionTable& ft);
+
+  /*!
+   * \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure.
+   * \param workspace The object to store the allocated draft token workspace.
+ * \param require_hidden_states Whether to allocate workspace for the hidden states. + */ + void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states); + + /*! + * \brief Allocate slots for the draft tokens. + * \param num_slots The number of slots to allocate. + * \param result The vector to store the allocated slots. + */ + void AllocSlots(int num_slots, std::vector* result); + + /*! + * \brief Free the slots. + * \param slots The slots to free. + */ + void FreeSlots(const std::vector& slots); + + static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager"; + + private: + std::vector free_slots_; + int max_num_tokens_; + int vocab_size_; + int hidden_size_; + DataType hidden_states_dtype_; + DLDevice device_; + const FunctionTable& ft_; +}; + +class DraftTokenWorkspaceManager : public ObjectRef { + public: + DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size, + DLDataType hidden_states_dtype, DLDevice device, + const FunctionTable& ft) { + data_ = make_object(max_num_tokens, vocab_size, hidden_size, + hidden_states_dtype, device, ft); + } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef, + DraftTokenWorkspaceManagerObj); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 8568c6ce94..755af998cd 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -44,7 +45,8 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + explicit EngineImpl(EngineConfig engine_config, DLDevice device, + Optional request_stream_callback, Optional trace_recorder) { // Step 1. Initialize metadata and singleton states inside the engine this->estate_->Reset(); @@ -54,22 +56,30 @@ class EngineImpl : public Engine { } this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); - this->token_table_ = tokenizer_->TokenTable(); - this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + // Step 2. Initialize each model independently. // Create the logit processor and sampler. 
this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, - const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, - engine_config->max_num_sequence, + std::vector model_configs; + model_configs.push_back(Model::LoadModelConfig(engine_config->model)); + for (const auto& model_path : engine_config->additional_models) { + model_configs.push_back(Model::LoadModelConfig(model_path)); + } + + Optional session = CreateDiscoSession(model_configs, device); + + auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, + &session](const String& model_path, const String& model_lib_path, + int model_index) { + Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], + device, engine_config->max_num_sequence, session, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size); + engine_config->prefill_chunk_size, engine_config->max_history_size, + engine_config->kv_state_kind); CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " @@ -79,53 +89,78 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); }; - f_create_model(engine_config->model, engine_config->model_lib_path); + f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); CHECK_EQ(engine_config->additional_models.size(), engine_config->additional_model_lib_paths.size()) << "The additional model and lib path list has mismatched size."; for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i]); + engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); + } + + // Step 3. Initialize tokenizer and grammar + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + std::string token_table_postproc_method; + if (model_configs[0].count("token_table_postproc_method") == 0) { + // Backward compatibility: use "byte-fallback" by default + token_table_postproc_method = "byte-fallback"; + } else { + token_table_postproc_method = + model_configs[0].at("token_table_postproc_method").get(); } + this->token_table_ = + Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); + this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + // Step 4. Initialize engine actions that represent state transitions. 
int max_num_tokens = engine_config->max_num_sequence; + DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; + draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + draft_token_workspace_manager->AllocWorkspace( + &model_workspaces_[0], + /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); - // Step 3. Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft( - this->models_, logit_processor, sampler, this->model_workspaces_, - this->trace_recorder_, engine_config->spec_draft_length), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = { + EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + draft_token_workspace_manager, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_, + engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; break; default: - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->trace_recorder_), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - engine_config, this->trace_recorder_)}; + this->actions_ = { + EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // @@ -285,6 +320,51 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) 
but it does not."; } + /************** Utility Functions **************/ + Optional CreateDiscoSession(std::vector model_configs, Device device) { + const auto& base_model_config = model_configs[0]; + + auto f_get_num_shards = [](const picojson::object& model_config) -> int { + constexpr auto kNumShardsKey = "tensor_parallel_shards"; + if (model_config.count(kNumShardsKey)) { + const auto& val = model_config.at(kNumShardsKey); + CHECK(val.is()); + return static_cast(val.get()); + } else { + LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; + } + throw; + }; + + int num_shards = std::transform_reduce( + model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); }, + f_get_num_shards); + Optional session = NullOpt; + if (num_shards > 1) { + constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; + if (Registry::Get(f_create_process_pool) == nullptr) { + LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " + << "Multi-GPU inference depends on MLC LLM Python API to launch process."; + } + std::string ccl; + if (device.device_type == kDLCUDA) { + ccl = "nccl"; + } else if (device.device_type == kDLROCM) { + ccl = "rccl"; + } else { + LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) + << " is not supported. Currently, only NCCL and RCCL are integrated."; + } + std::vector device_ids(num_shards); + for (int i = 0; i < num_shards; ++i) { + device_ids[i] = i; + } + session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); + session.value()->InitCCL(ccl, ShapeTuple(device_ids)); + } + return session; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -338,10 +418,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, +std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + return std::make_unique(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } @@ -367,10 +448,10 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. */ - void Init(EngineConfig engine_config, Optional request_stream_callback, + void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_ = Engine::Create(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index bcc1b80988..2fc0a4d730 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -51,11 +51,12 @@ class Engine { /*! * \brief Create an engine in unique pointer. * \param engine_config The engine config. + * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. * \return The created Engine in pointer. 
*/ - static std::unique_ptr Create(EngineConfig engine_config, + static std::unique_ptr Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder); diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 79359c5741..c69c508810 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -8,6 +8,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ #include "../config.h" +#include "../draft_token_workspace_manager.h" #include "../engine_state.h" #include "../event_trace_recorder.h" #include "../model.h" @@ -72,15 +73,16 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder); + static EngineAction EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -104,13 +106,16 @@ class EngineAction : public ObjectRef { * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, - int draft_length = 4); + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length = 4); /*! * \brief Create the action that runs one-step speculative draft proposal for @@ -120,12 +125,14 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length = 4); @@ -135,13 +142,17 @@ class EngineAction : public ObjectRef { * accordingly when it is impossible to decode all the running requests. * \param models The model to run decode in. 
When there are multiple * models, the `Step` function of the created action will not take effect. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param sampler The sampler to sample new tokens. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder); /*! @@ -152,6 +163,7 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. @@ -159,6 +171,7 @@ class EngineAction : public ObjectRef { static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder); diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 6eb7a3d84a..af0dfe978d 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -142,9 +142,10 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder) { +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder) { ICHECK(!estate->running_queue.empty()); Request request = estate->running_queue.back(); @@ -168,8 +169,12 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // - Update `inputs` for future prefill. 
RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); rsentry->status = RequestStateStatus::kPending; + std::vector draft_token_slots; for (RequestModelState mstate : rsentry->mstates) { - mstate->RemoveAllDraftTokens(); + if (draft_token_workspace_manager.defined()) { + mstate->RemoveAllDraftTokens(&draft_token_slots); + draft_token_workspace_manager.value()->FreeSlots(draft_token_slots); + } std::vector committed_token_ids; committed_token_ids.reserve(mstate->committed_tokens.size()); for (const SampleResult& committed_token : mstate->committed_tokens) { diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 78e3937d0b..07bef2d2d9 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -7,6 +7,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ #include "../../tokenizers.h" +#include "../draft_token_workspace_manager.h" #include "../engine.h" #include "../engine_state.h" #include "../event_trace_recorder.h" @@ -52,12 +53,14 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder); +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder); /*! \brief Get the running request entries from the engine state. */ inline std::vector GetRunningRequestStateEntries(const EngineState& estate) { diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 36acc6b06e..ecff914baa 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -48,7 +48,7 @@ class BatchDecodeActionObj : public EngineActionObj { running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index c1ddeb6e4e..513a0fe447 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -23,10 +23,14 @@ namespace serve { class BatchDraftActionObj : public EngineActionObj { public: explicit BatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -41,8 +45,8 @@ class BatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. 
std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -123,8 +127,11 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); for (int i = 0; i < num_rsentries; ++i) { - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i]); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -156,18 +163,27 @@ class BatchDraftActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder), - draft_length)); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 42c9bbe018..6f27a50394 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -28,11 +28,15 @@ namespace serve { class BatchVerifyActionObj : public EngineActionObj { public: explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -61,14 +65,13 @@ class BatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_verify_length); 
verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -76,18 +79,22 @@ class BatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!verify_lengths.empty()); ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); - ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_token_slots.size() + 1); // the last committed token + all the draft tokens. + draft_token_slots_.push_back(0); // placeholder for the last committed token all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( @@ -123,7 +130,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -149,7 +156,8 @@ class BatchVerifyActionObj : public EngineActionObj { // clear the draft model state entries for (int i = 0; i < num_rsentries; ++i) { - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); } auto tend = std::chrono::high_resolution_clock::now(); @@ -194,8 +202,8 @@ class BatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -222,6 +230,10 @@ class BatchVerifyActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! 
\brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ @@ -232,14 +244,20 @@ class BatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config), - std::move(trace_recorder))); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index fde314a5c5..7ad66a045c 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -24,11 +24,13 @@ class EagleBatchDraftActionObj : public EngineActionObj { public: explicit EagleBatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -43,8 +45,8 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -81,21 +83,20 @@ class EagleBatchDraftActionObj : public EngineActionObj { mstates.push_back(rsentry->mstates[model_id]); } // draft_length_ rounds of draft proposal. 
- NDArray hidden_states_nd{nullptr}; ObjectRef last_hidden_states{nullptr}; - ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; + NDArray hidden_states = Downcast(model_workspaces_[model_id].hidden_states); // Concat last hidden_states - std::vector previous_hidden_on_device; - for (int i = 0; i < num_rsentries; ++i) { - previous_hidden_on_device.push_back(mstates[i]->draft_last_hidden_on_device.back()); + draft_token_slots_.clear(); + if (draft_length_ > 1) { + for (int i = 0; i < num_rsentries; ++i) { + draft_token_slots_.push_back(mstates[i]->draft_token_slots.back()); + } + hidden_states = Downcast(models_[model_id]->GatherHiddenStates( + model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states)); + ICHECK(hidden_states->ndim == 2); + last_hidden_states = hidden_states.CreateView( + {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); } - hidden_states_nd = - models_[model_id]->ConcatLastHidden(previous_hidden_on_device, &hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); - last_hidden_states = hidden_states_nd; // The first draft token has been generated in prefill/verify stage for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { // prepare new input tokens @@ -115,17 +116,17 @@ class EagleBatchDraftActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = + hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); - last_hidden_states = hidden_states_nd; + last_hidden_states = hidden_states; NDArray logits; if (models_[model_id]->CanGetLogits()) { - logits = models_[model_id]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -152,12 +153,12 @@ class EagleBatchDraftActionObj : public EngineActionObj { ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + // No need to save hidden states as they are not used by subsequent engine actions for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -183,26 +184,6 @@ class EagleBatchDraftActionObj : public EngineActionObj { return true; } - /*! 
- * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The model to run draft generation in speculative decoding. */ Array models_; /*! \brief The logit processor. */ @@ -211,20 +192,26 @@ class EagleBatchDraftActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(trace_recorder), draft_length)); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index b259417050..d52f60d5c7 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -29,12 +29,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { public: explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -70,7 +72,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -78,19 +80,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!draft_lengths.empty()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - 
ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_token_slots.size()); // the last committed token + all the draft tokens but the last one. all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + draft_token_slots_.push_back(0); // placeholder for the last committed token for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); + std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); std::vector verify_lengths; @@ -135,10 +142,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); - std::vector last_hidden_states; + std::vector last_accepted_hidden_positions; + last_accepted_hidden_positions.reserve(num_rsentries); for (int i = 0; i < num_rsentries; ++i) { const std::vector& sample_results = sample_results_arr[i]; int accept_length = sample_results.size(); @@ -163,24 +171,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); } // clear the draft model state entries - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = - GetTokenHidden(hidden_states, (cum_verify_lengths[i] + accept_length - 1)); - last_hidden_states.push_back(last_hidden_on_device); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); + // - Slice and save hidden_states_for_sample + last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } { // One step draft for the following steps - NDArray hidden_states_nd{nullptr}; - ObjectRef next_hidden_states = model_workspaces_[draft_model_id_].hidden_states; - // Concat last hidden_states - hidden_states_nd = - models_[draft_model_id_]->ConcatLastHidden(last_hidden_states, &next_hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + NDArray last_hidden_states_nd = hidden_states.CreateView( + {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, + hidden_states->dtype); + + hidden_states = Downcast(models_[draft_model_id_]->GatherHiddenStates( + last_hidden_states_nd, last_accepted_hidden_positions, + 
&model_workspaces_[draft_model_id_].hidden_states)); + ICHECK(hidden_states->ndim == 2); + hidden_states = hidden_states.CreateView( + {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); std::vector input_tokens; Array mstates; @@ -203,17 +211,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( - embeddings, hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, - request_internal_ids); + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, + request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. - logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + logits = models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -239,13 +246,21 @@ class EagleBatchVerifyActionObj : public EngineActionObj { renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[draft_model_id_]->ScatterDraftProbs( + renormalized_probs, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs_storage); + ICHECK(hidden_states->ndim == 3); + hidden_states = hidden_states.CreateView( + {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, + hidden_states->dtype); + models_[draft_model_id_]->ScatterHiddenStates( + hidden_states, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_hidden_states_storage); // - Add draft token to the state. for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -292,8 +307,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_draft_length -= draft_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -342,6 +357,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. 
*/ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ @@ -352,16 +369,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleBatchVerify( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index a687e7eb7f..57310f7986 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -24,12 +24,14 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -107,7 +109,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. if (rsentry->parent_idx == -1) { @@ -286,8 +288,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. 
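Across these engine actions the per-draft-token state is keyed by integer slots: AllocSlots hands out one slot per sampled draft token, the slot index is recorded in the request state, and RemoveAllDraftTokens plus FreeSlots return the slots once verification has consumed them. The DraftTokenWorkspaceManager implementation is not shown in this section, so the following is only a sketch of the assumed alloc/free contract:

#include <vector>

// Minimal free-list bookkeeping matching the AllocSlots/FreeSlots calls above
// (assumed interface; the real class lives in cpp/serve/draft_token_workspace_manager.{h,cc}).
class SlotAllocatorSketch {
 public:
  explicit SlotAllocatorSketch(int max_num_tokens) {
    free_slots_.reserve(max_num_tokens);
    for (int i = max_num_tokens - 1; i >= 0; --i) free_slots_.push_back(i);
  }
  // Pop `num` free slots; the caller scatters probs/hidden states into these rows.
  void AllocSlots(int num, std::vector<int>* slots) {
    slots->clear();
    for (int i = 0; i < num; ++i) {
      slots->push_back(free_slots_.back());
      free_slots_.pop_back();
    }
  }
  // Return slots once the corresponding draft tokens are verified or discarded.
  void FreeSlots(const std::vector<int>& slots) {
    free_slots_.insert(free_slots_.end(), slots.begin(), slots.end());
  }

 private:
  std::vector<int> free_slots_;
};

The probabilities and hidden states for a slot live at row `slot` of the shared draft_probs_storage and draft_hidden_states_storage tensors in the model workspace.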
auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - if (model_id == 0) { + if (model_id == 0) { + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { for (int mid = 0; mid < static_cast(models_.size()); ++mid) { rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); if (!rsentry_activated[i]) { @@ -301,13 +303,24 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { rsentries_for_sample[i]->tprefill_finish = tnow; } - } else { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_for_sample, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], prob_dist[i], - last_hidden_on_device); + } + } else { + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), + &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->spec_draft_length > 1) { + hidden_states_for_sample = hidden_states_for_sample.CreateView( + {hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1], + hidden_states_for_sample->shape[2]}, + hidden_states_for_sample->dtype); + models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], + draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -582,20 +595,25 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleNewRequestPrefill(Array models, - LogitProcessor logit_processor, Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c80c5e0ede..f801b1e282 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -100,7 +100,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. if (rsentry->parent_idx == -1) { @@ -396,6 +396,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { int num_running_rsentries) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + // No exceeding of the maximum allowed requests that can // run simultaneously. 
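The early return added above reflects that an RNN state reserves a fixed number of state slots at construction time, so KV page accounting does not apply; only paged-KV models go through the page budget check that follows. A heavily simplified sketch of that decision (the real check below also folds in the speculative-decoding factor and the engine's sequence limit):

// Hypothetical condensed form, for illustration only.
enum class CacheKind { kPagedKV, kRNNState };

bool CanPrefillSketch(CacheKind kind, int num_required_pages, int num_available_pages) {
  if (kind == CacheKind::kRNNState) {
    // RNN state slots are reserved up front; there is no page budget to exceed.
    return true;
  }
  return num_required_pages <= num_available_pages;
}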
int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable diff --git a/cpp/serve/event_trace_recorder.cc b/cpp/serve/event_trace_recorder.cc index 8a930002fe..e0311716fd 100644 --- a/cpp/serve/event_trace_recorder.cc +++ b/cpp/serve/event_trace_recorder.cc @@ -51,7 +51,7 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { void AddEvent(const Array& request_ids, const std::string& event) final { double event_time = std::chrono::duration_cast>( std::chrono::system_clock::now().time_since_epoch()) - .count(); + .count(); // in seconds { std::lock_guard lock(mutex_); @@ -96,16 +96,16 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { name = event; phase = "i"; } - int64_t event_time_in_ms = static_cast(event_time * 1e6); + int64_t event_time_in_us = static_cast(event_time * 1e6); picojson::object event_json; event_json["name"] = picojson::value(name); event_json["ph"] = picojson::value(phase); - event_json["ts"] = picojson::value(event_time_in_ms); + event_json["ts"] = picojson::value(event_time_in_us); event_json["pid"] = picojson::value(static_cast(1)); event_json["tid"] = picojson::value(request_id); - events_to_sort.push_back({event_time_in_ms, picojson::value(event_json)}); + events_to_sort.push_back({event_time_in_us, picojson::value(event_json)}); } std::sort(events_to_sort.begin(), events_to_sort.end(), fcmp_events); for (auto [timestamp, event] : events_to_sort) { diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index b33d3709e8..4e0301eb2d 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,33 +86,14 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; - if (Registry::Get(f_create_process_pool) == nullptr) { - LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " - << "Multi-GPU inference depends on MLC LLM Python API to launch process."; - } - std::string ccl; - if (device.device_type == kDLCUDA) { - ccl = "nccl"; - } else if (device.device_type == kDLROCM) { - ccl = "rccl"; - } else { - LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) - << " is not supported. 
Currently, only NCCL and RCCL are integrated."; - } - std::vector device_ids(num_shards); - for (int i = 0; i < num_shards; ++i) { - device_ids[i] = i; - } + this->sess = session.value(); this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); - this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), reload_lib_path, null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { - DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, false); + DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, true); bool exists = (func->DebugGetFromRemote(0).operator PackedFunc()) != nullptr; if (!exists) { return PackedFunc(nullptr); @@ -244,7 +226,12 @@ void FunctionTable::_InitFunctions() { this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { - this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + if (f_create_rnn_state.defined()) { + this->create_kv_cache_func_ = f_create_rnn_state; + } else { + this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + } ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); @@ -272,6 +259,11 @@ void FunctionTable::_InitFunctions() { this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; + + this->gather_probs_func_ = mod->GetFunction("gather_probs", true); + this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true); + this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true); + this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true); } ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const { @@ -285,8 +277,8 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) } ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape) { - if (this->use_disco) { + ShapeTuple max_reserved_shape, bool local_only) { + if (this->use_disco && !local_only) { Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); auto it = this->cached_buffers.find(buffer_cache_key); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index b6ea3287ad..e368edcb9c 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,8 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(String reload_lib_path, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session); ObjectRef LoadParams(const std::string& model_path, Device device); @@ -49,8 +50,18 @@ struct FunctionTable { ObjectRef Empty(ShapeTuple shape, DataType dtype, Device device) const; + /*! + * \brief Copy a host array to the worker or local gpu. + * \param host_array The host array to be copied. 
+ * \param buffer_cache_key The key to the buffer cache. + * \param max_reserved_shape The maximum shape to be reserved in the buffer cache. + * \param local_only Whether to copy the array to the local gpu only. If true, the use_disco + * flag will be ignored. This can be useful for functions that run only on the + * local gpu when disco is enabled. + * \return The array on the worker or local gpu. + */ ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape); + ShapeTuple max_reserved_shape, bool local_only = false); void DebugCallFuncOnAllAllWorker(const String& func_name) const; @@ -109,6 +120,11 @@ struct FunctionTable { PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; + // Auxiliary functions for speculative decoding. + PackedFunc gather_probs_func_; + PackedFunc scatter_probs_func_; + PackedFunc gather_hidden_states_func_; + PackedFunc scatter_hidden_states_func_; }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 1ece99099e..55ab0a1dff 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -156,14 +156,14 @@ int32_t EBNFParserImpl::ParseCharacterClass() { continue; } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_, kCustomEscapeMap); + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { - ThrowParseError("Invalid utf8 sequence"); + ThrowParseError("Invalid UTF8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); if (past_is_hyphen) { ICHECK(!elements.empty()); if (elements.back().lower > codepoint) { @@ -194,14 +194,15 @@ int32_t EBNFParserImpl::ParseString() { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); + + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { ThrowParseError("Invalid utf8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); } if (character_classes.empty()) { diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index fd41517863..c3c2c88baa 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -59,12 +59,12 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { result += "^"; } for (auto i = 0; i < rule_expr.data_len; i += 2) { - result += CodepointToPrintable(rule_expr[i], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; } result += "-"; - result += CodepointToPrintable(rule_expr[i + 1], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap); } result += "]"; return result; diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 5c4ef98efe..451127e746 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -510,7 +510,7 @@ 
TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") bool MatchCompleteString(GrammarStateMatcher matcher, String str) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = Utf8StringToCodepoints(str.c_str()); + auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; for (auto codepoint : codepoints) { if (!mutable_node->AcceptCodepoint(codepoint, false)) { @@ -553,9 +553,9 @@ void PrintAcceptedRejectedTokens( // First cast to unsigned, then cast to int std::cerr << static_cast(static_cast(token[0])); } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; @@ -571,9 +571,9 @@ void PrintAcceptedRejectedTokens( if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { std::cerr << (int)(unsigned char)token[0]; } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 55c986bb10..5b774d33a4 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -156,15 +156,15 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Rejected" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" + << std::endl; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Accepted" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" + << std::endl; std::cout << "Stack after accepting: " << PrintStackState() << std::endl; } #if TVM_LOG_DEBUG diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index c853ac7e04..f63eee2c5c 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -268,7 +268,7 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC ptr->special_token_ids.push_back(i); } else { // First replace the special underscore with space. 
- auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); DCHECK(!codepoints.empty() && codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) << "Invalid token: " << token; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index f7190d50ac..7ce70a0d26 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -289,7 +289,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (num_token_to_process != 1) { @@ -368,7 +368,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (token_number != 1) { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 3583b5d84b..8918cecdc4 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -13,6 +13,7 @@ #include +#include "config.h" #include "logit_processor.h" namespace mlc { @@ -25,10 +26,27 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) { - return Model( - make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); +Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, + max_num_sequence, session, trace_enabled)); +} + +picojson::object Model::LoadModelConfig(const String& model_path) { + picojson::object model_config; + std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); + std::ostringstream config_ostream; + ICHECK(config_istream); + config_ostream << config_istream.rdbuf(); + std::string config_str = config_ostream.str(); + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + picojson::object config = config_json.get(); + return config; } class ModelImpl : public ModelObj { @@ -37,23 +55,16 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) + explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) : device_(device) { // Step 1. Process model config json string. - picojson::object model_config; - { - std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); - std::ostringstream config_ostream; - ICHECK(config_istream); - config_ostream << config_istream.rdbuf(); - std::string config_str = config_ostream.str(); - model_config = LoadModelConfigJSON(config_str); - } + LoadModelConfigJSON(model_config); // Step 2. 
Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. - this->ft_.Init(reload_lib_path, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config, session); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. Set max_num_sequence @@ -68,6 +79,12 @@ class ModelImpl : public ModelObj { token_ids_storage_ = memory::Storage( allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); + // Step 7. Set model type + if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + this->kind = KVStateKind::kRNNState; + } else { + this->kind = KVStateKind::kAttention; + } } /*********************** Model Computation ***********************/ @@ -229,14 +246,8 @@ class ModelImpl : public ModelObj { } NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - // This step runs on the engine thread. - // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device - // tensor without actually copying to the worker. - bool use_disco = ft_.use_disco; - ft_.use_disco = false; - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - ft_.use_disco = use_disco; + ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos_local", + {max_num_sequence_}, /*local_only=*/true); CHECK(ft_.batch_select_last_hidden_func_.defined()) << "`batch_select_last_hidden_states` function is not found in the model."; @@ -739,16 +750,26 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) final { - IntTuple max_num_sequence_tuple{max_num_sequence}; - IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; - IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; - IntTuple page_size_tuple{page_size}; - IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, - prefill_chunk_size_tuple, page_size_tuple, - support_sliding_window); - local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) final { + if (kv_state_kind == KVStateKind::kAttention) { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); + local_kv_cache_ = + ft_.use_disco ? 
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -775,11 +796,21 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ int GetNumAvailablePages() const final { - return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not introduce new page at runtime + return std::numeric_limits::max(); + } else { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } } int GetCurrentTotalSequenceLength() const final { - return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not have a total sequence length limit + return 0; + } else { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + } } /*********************** Utilities ***********************/ @@ -833,20 +864,21 @@ class ModelImpl : public ModelObj { // Allocate the hidden_states tensor. // Use the same function as embeddings. ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); + NDArray hidden_states_nd{nullptr}; // Get the shape of the hidden_states tensor for hidden size. - ShapeTuple hidden_states_shape; if (ft_.use_disco) { ICHECK(hidden_states->IsInstance()); - ObjectRef shape_ref = ft_.nd_get_shape_func_(hidden_states); - hidden_states_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + hidden_states_nd = Downcast(hidden_states)->DebugGetFromRemote(0); } else { - NDArray hidden_states_nd = Downcast(hidden_states); - hidden_states_shape = hidden_states_nd.Shape(); + hidden_states_nd = Downcast(hidden_states); } + ShapeTuple hidden_states_shape = hidden_states_nd.Shape(); ICHECK_EQ(hidden_states_shape.size(), 2); ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; - return hidden_states; + this->hidden_states_dtype_ = hidden_states_nd->dtype; + // TODO(wuwei): We can keep hidden_states on the worker after refactor + return hidden_states_nd; } void Reset() final { @@ -856,6 +888,59 @@ class ModelImpl : public ModelObj { } } + /********************** Utilities for speculative decoding **********************/ + + DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_tokens) { + return DraftTokenWorkspaceManager(max_num_tokens, vocab_size_, hidden_size_, + hidden_states_dtype_, device_, ft_); + } + + ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + NDArray dst_view = Downcast(*dst).CreateView( + {static_cast(indices.size()), hidden_size_}, hidden_states_dtype_); + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.gather_hidden_states_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + 
ft_.scatter_hidden_states_func_(input, indices_device, *dst); + } + + NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray dst_view = + dst->CreateView({static_cast(indices.size()), vocab_size_}, DataType::Float(32)); + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.gather_probs_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.scatter_probs_func_(input, indices_device, *dst); + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -864,15 +949,7 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. */ - picojson::object LoadModelConfigJSON(const std::string& config_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. - picojson::object config = config_json.get(); + picojson::object LoadModelConfigJSON(picojson::object config) { if (config.count("context_window_size")) { CHECK(config["context_window_size"].is()); this->max_window_size_ = config["context_window_size"].get(); @@ -922,6 +999,7 @@ class ModelImpl : public ModelObj { int max_num_sequence_ = -1; int prefill_chunk_size_ = -1; int hidden_size_ = -1; + DLDataType hidden_states_dtype_; int vocab_size_ = -1; int image_embed_size_ = -1; //---------------------------- @@ -946,6 +1024,8 @@ class ModelImpl : public ModelObj { NDArray logit_pos_arr_{nullptr}; // A boolean indicating if tracing is enabled. bool trace_enabled_; + // An enum indicating whether it's RNN-based. + KVStateKind kind; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index da532f83e8..d672739581 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,6 +12,7 @@ #include "../base.h" #include "config.h" +#include "draft_token_workspace_manager.h" #include "event_trace_recorder.h" #include "function_table.h" #include "logit_processor.h" @@ -40,10 +41,26 @@ struct ModelWorkspace { */ ObjectRef embeddings{nullptr}; /*! - * \brief The hidden_states tensor. It can be either an NDArray when tensor + * \brief The hidden_states tensor for the current batch. It can be either an NDArray when tensor * model parallelism is not enabled, or a DRef when using tensor model parallelism. */ ObjectRef hidden_states{nullptr}; + + /*! + * \brief The draft token probabilities tensor for the current batch. + */ + NDArray draft_probs{nullptr}; + + /*! + * \brief The hidden_states tensor storing the hidden_states of draft tokens of all requests. + */ + ObjectRef draft_hidden_states_storage{nullptr}; + + /*! + * \brief The draft token probabilities tensor storing the probabilities of draft tokens of all + * requests. + */ + NDArray draft_probs_storage{nullptr}; }; /*! 
@@ -234,9 +251,13 @@ class ModelObj : public Object { * in the engine. * \param prefill_chunk_size The maximum total number of tokens whose KV data * are allowed to exist in the KV cache at any time. + * \param max_history_size The maximum history size for RNN state to roll back. + * The KV cache does not need this. + * \param kv_state_kind The kind of cache. It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) = 0; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; @@ -298,6 +319,27 @@ class ModelObj : public Object { /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; + /*********************** Utilities for speculative decoding. ***********************/ + + virtual DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_token) = 0; + + /*! \brief Gather the hidden_states of the given indices and in-place update the dst tensor. */ + virtual ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Scatter the hidden_states of the given indices to the dst tensor. */ + virtual void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Gather the draft token probabilities of the given indices and in-place update the dst + * tensor. */ + virtual NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + + /*! \brief Scatter the draft token probabilities of the given indices to the dst tensor. */ + virtual void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + /************** Debug/Profile **************/ /*! \brief Call the given global function on all workers. Only for debug purpose. */ @@ -315,13 +357,24 @@ class Model : public ObjectRef { * \brief Create the runtime module for LLM functions. * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. + * \param model_config The model config json object. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled); + TVM_DLL static Model Create(String reload_lib_path, String model_path, + const picojson::object& model_config, DLDevice device, + int max_num_sequence, const Optional& session, + bool trace_enabled); + + /*! + * Load the model config from the given model path. + * \param model_path The path to the model weight parameters. + * \return The model config json object. 
+ */ + static picojson::object LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index b1f5ae27a2..4c59ae52a2 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -59,11 +59,9 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { } } -void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray last_hidden_on_device) { +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot) { draft_output_tokens.push_back(std::move(sampled_token)); - draft_output_prob_dist.push_back(std::move(prob_dist)); - draft_last_hidden_on_device.push_back(std::move(last_hidden_on_device)); + draft_token_slots.push_back(draft_token_slot); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } @@ -71,14 +69,17 @@ void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); auto it = appeared_token_ids.find(draft_output_tokens.back().sampled_token_id.first); draft_output_tokens.pop_back(); - draft_output_prob_dist.pop_back(); CHECK(it != appeared_token_ids.end()); if (--it->second == 0) { appeared_token_ids.erase(it); } } -void RequestModelStateNode::RemoveAllDraftTokens() { +void RequestModelStateNode::RemoveAllDraftTokens(std::vector* removed_draft_token_slots) { + if (removed_draft_token_slots != nullptr) { + removed_draft_token_slots->assign(draft_token_slots.begin(), draft_token_slots.end()); + } + draft_token_slots.clear(); while (!draft_output_tokens.empty()) { RemoveLastDraftToken(); } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 950bb6e290..79abcb1a24 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -62,20 +62,8 @@ class RequestModelStateNode : public Object { * result of speculation. */ std::vector draft_output_tokens; - /*! - * \brief The probability distribution on each position in the - * draft. We keep the distributions for stochastic sampling when merging - * speculations from multiple models. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_output_prob_dist; - /*! - * \brief The last hidden_states used to get probs in drafting. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_last_hidden_on_device; + /*! \brief The storage slots for the associated states of draft tokens. */ + std::vector draft_token_slots; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -101,17 +89,18 @@ class RequestModelStateNode : public Object { /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray draft_last_hidden_on_device = NDArray()); - /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ - void RemoveLastDraftToken(); + void AddDraftToken(SampleResult sampled_token, int draft_token_slot); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. 
*/ - void RemoveAllDraftTokens(); + void RemoveAllDraftTokens(std::vector* removed_draft_token_slots = nullptr); static constexpr const char* _type_key = "mlc.serve.RequestModelState"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(RequestModelStateNode, Object); + + private: + /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ + void RemoveLastDraftToken(); }; class RequestModelState : public ObjectRef { diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 98080c979d..196a6dd695 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -430,7 +430,7 @@ class CPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { // probs_on_host: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_host->ndim, 2); @@ -438,8 +438,8 @@ class CPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + NDArray draft_probs_on_host = draft_probs_on_device.CopyTo(DLDevice{kDLCPU, 0}); std::vector> sample_results; sample_results.resize(num_sequence); @@ -451,6 +451,7 @@ class CPUSampler : public SamplerObj { [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; + int cur_token_idx = 0; // Sub 1 to ignore the last prediction. for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) { @@ -477,12 +478,9 @@ class CPUSampler : public SamplerObj { // normalize a new probability distribution double sum_v = 0.0; - NDArray q_dist = draft_output_prob_dist[i][cur_token_idx]; - ICHECK(q_dist->device.device_type == kDLCPU); - ICHECK(q_dist->ndim == 1); - ICHECK(vocab_size == q_dist->shape[q_dist->ndim - 1]); const float* __restrict p_qdist = - static_cast(__builtin_assume_aligned(q_dist->data, 4)); + static_cast(__builtin_assume_aligned(draft_probs_on_host->data, 4)) + + (verify_start + cur_token_idx + 1) * vocab_size; for (int j = 0; j < vocab_size; ++j) { p_probs[j] = std::max(p_probs[j] - p_qdist[j], 0.0f); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index c80a846b19..c6f463eb32 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -51,6 +51,9 @@ class GPUSampler : public SamplerObj { ICHECK(gpu_sample_with_top_p_func_.defined()); ICHECK(gpu_sampler_take_probs_func_.defined()); + flashinfer_multinomial_sample_func_ = + Registry::Get("flashinfer.sampling.parallel_sampling_from_prob"); + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; // We support at most 5 top prob results for each sequence. // Initialize auxiliary arrays on CPU. 
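In the CPU verification path above, the per-token q distributions now live in one concatenated (total_verify_length, vocab_size) tensor that is copied to host once; row (verify_start + cur_token_idx + 1) holds q for the current draft token, where the shift by one accounts for the last committed token whose row is undefined. When a draft token is rejected, sampling continues from the residual distribution max(p - q, 0), renormalized. Illustrative helper, not patch code:

#include <algorithm>
#include <cstddef>
#include <vector>

// In-place residual update after a rejected draft token: p <- normalize(max(p - q, 0)).
void ResidualAfterRejection(std::vector<float>* p, const float* q_row) {
  float sum = 0.0f;
  for (std::size_t j = 0; j < p->size(); ++j) {
    (*p)[j] = std::max((*p)[j] - q_row[j], 0.0f);
    sum += (*p)[j];
  }
  if (sum <= 0.0f) return;  // degenerate case; the caller decides how to resample
  for (float& v : *p) v /= sum;
}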
@@ -76,6 +79,7 @@ class GPUSampler : public SamplerObj { token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + sampled_token_ids_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. @@ -163,7 +167,7 @@ class GPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); std::vector> sample_results; // probs_on_device: (n, v) @@ -173,38 +177,27 @@ class GPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); sample_results.resize(num_sequence); int num_nodes = cum_verify_lengths.back(); + CHECK_EQ(draft_probs_on_device->shape[0], num_nodes); NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); - NDArray draft_probs_device = - draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); - // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + // Copy draft tokens to GPU + int* p_draft_tokens_host = static_cast(draft_tokens_host->data); for (int i = 0; i < num_sequence; i++) { const std::vector& draft_output_tokens_i = draft_output_tokens[i]; - const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; int start = cum_verify_lengths[i]; int end = cum_verify_lengths[i + 1]; // start/end is the range of the sequence i in probs_on_device, which includes the prob dist // of the draft tokens and the last committed token ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); - ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); for (int j = 0; j < end - start - 1; j++) { - // Copy prob dist - ICHECK_EQ(draft_probs_device->dtype.bits, 32); - float* p_draft_probs = - static_cast(draft_probs_device->data) + - (j + start + 1) * - vocab_size_; // shift by one, q of the last committed token is undefined // Copy sampled token id - draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); - *(static_cast(draft_tokens_host->data) + j + start + 1) = - draft_output_tokens_i[j].sampled_token_id.first; + p_draft_tokens_host[start + j + 1] = draft_output_tokens_i[j].sampled_token_id.first; } } CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); @@ -258,7 +251,7 @@ class GPUSampler : public SamplerObj { SyncCopyStream(device_, compute_stream_, copy_stream_); - gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + gpu_verify_draft_tokens_func_(draft_probs_on_device, draft_tokens_device, probs_on_device, token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device, 
token_tree_parent_ptr_device); @@ -311,14 +304,20 @@ class GPUSampler : public SamplerObj { int vocab_size = probs_on_device->shape[1]; if (output_prob_dist != nullptr) { ICHECK(output_prob_dist->empty()); - output_prob_dist->reserve(num_probs); - for (int i = 0; i < num_probs; ++i) { + output_prob_dist->reserve(num_samples); + for (int i = 0; i < num_samples; ++i) { NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); - float* p_prob = static_cast(probs_on_device->data) + i * vocab_size; + float* p_prob = static_cast(probs_on_device->data) + sample_indices[i] * vocab_size; prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); output_prob_dist->push_back(std::move(prob_dist)); } } + if (num_samples == 0) { + // This synchronization is necessary for making sure that this round + // of model forward is finished. + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + return {}; + } ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); @@ -489,8 +488,15 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. SyncCopyStream(device_, compute_stream_, copy_stream_); - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, top_prob_indices_device}; } @@ -525,8 +531,15 @@ class GPUSampler : public SamplerObj { uniform_samples_device, sample_indices_device, top_p_device); } else { // - Sample without top_p. - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } } if (need_prob_values) { @@ -580,7 +593,7 @@ class GPUSampler : public SamplerObj { } // Synchronize for CPU to get the correct array results. 
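In the short sampling path above, the sampler now prefers the optional FlashInfer kernel registered as ``flashinfer.sampling.parallel_sampling_from_prob`` when it is present in the global registry, and otherwise falls back to ``gpu_multinomial_from_uniform_func_``. Conceptually, both map a pre-drawn uniform number plus a probability row (selected by ``sample_indices``) to a token id; a CPU reference of that mapping using inverse-CDF sampling (one straightforward way to realize it, not necessarily the kernels' internal algorithm):

.. code:: python

    import numpy as np

    def multinomial_from_uniform(probs, uniform_samples, sample_indices):
        """probs: (num_probs, vocab); uniform_samples, sample_indices: (num_samples,).
        The sampled id is the first index whose cumulative probability exceeds
        the pre-drawn uniform sample, i.e. inverse-CDF sampling."""
        cdf = np.cumsum(probs[sample_indices], axis=1)  # (num_samples, vocab)
        return np.array(
            [int(np.searchsorted(cdf[i], u, side="right")) for i, u in enumerate(uniform_samples)],
            dtype=np.int32,
        )

    probs = np.array([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
    print(multinomial_from_uniform(probs, np.array([0.25, 0.9]), np.array([0, 1])))
    # [1 1]: 0.25 lands in the [0.1, 0.3) bucket of row 0, 0.9 in the [0.5, 1.0) bucket of row 1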
- TVMSynchronize(device_.device_type, device_.device_id, nullptr); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host}; } @@ -598,6 +611,7 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_sampler_take_probs_func_; PackedFunc gpu_verify_draft_tokens_func_; PackedFunc gpu_renormalize_by_top_p_func_; + const PackedFunc* flashinfer_multinomial_sample_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; @@ -621,6 +635,7 @@ NDArray token_tree_first_child_device_; NDArray token_tree_next_sibling_device_; NDArray token_tree_parent_ptr_device_; + NDArray sampled_token_ids_device_; // The event trace recorder for requests. */ Optional trace_recorder_; // The device stream for the default computation operations. diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 7943231e55..59e433ac47 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -108,15 +108,16 @@ class SamplerObj : public Object { * \param rngs The random number generator of each sequence. * \param draft_output_tokens The draft tokens generated by the small model for * each sequence. - * \param draft_output_prob_dist The probability distribution computed from the - * small model for each sequence. + * \param draft_probs_on_device The probability distribution computed from the + * small model for each sequence. Concatenated tensor of shape (total_verify_length, vocab_size). + * It includes the slot for the last committed token, which has an undefined probability value. * \return The list of accepted tokens for each request. */ virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) = 0; + NDArray draft_probs_on_device) = 0; static constexpr const char* _type_key = "mlc.serve.Sampler"; static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index f234dfbbc3..2f6f77a3a0 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -36,8 +36,9 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. */ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) final { + device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); @@ -231,7 +232,7 @@ class ThreadedEngineImpl : public ThreadedEngine { }; Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create(std::move(engine_config), + background_engine_ = Engine::Create(std::move(engine_config), device_, std::move(request_stream_callback), trace_recorder_); } @@ -247,6 +248,8 @@ class ThreadedEngineImpl : public ThreadedEngine { } } + /*! \brief The device to run models on. */ + Device device_; /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*!
\brief The request stream callback. */ diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index f3d9c2b70c..49ba8f2175 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,10 +35,11 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. + * \param device The device where to run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(Optional request_stream_callback, + virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) = 0; /*! diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index 0509c1eb2a..d9420bbbd5 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -11,7 +11,7 @@ namespace mlc { namespace llm { -std::string CodepointToUtf8(TCodepoint codepoint) { +std::string PrintAsUTF8(TCodepoint codepoint) { ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint; std::string utf8; if (codepoint <= 0x7F) { @@ -36,8 +36,8 @@ std::string CodepointToUtf8(TCodepoint codepoint) { return utf8; } -std::string CodepointToPrintable( - TCodepoint codepoint, const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped(TCodepoint codepoint, + const std::unordered_map& custom_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, @@ -63,10 +63,10 @@ std::string CodepointToPrintable( return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; } -std::pair Utf8ToCodepoint(const char* utf8) { - const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; +std::pair ParseNextUTF8(const char* utf8) { + static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off - const std::array kUtf8Bytes = { + static const std::array kUtf8Bytes = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -89,7 +89,7 @@ std::pair Utf8ToCodepoint(const char* utf8) { auto bytes = kUtf8Bytes[static_cast(utf8[0])]; if (bytes == -1) { // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; } TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; @@ -100,23 +100,23 @@ std::pair Utf8ToCodepoint(const char* utf8) { } res = (res << 6) | (static_cast(utf8[i]) & 0x3F); } - return {res, bytes}; + return {res, utf8 + bytes}; } -std::vector Utf8StringToCodepoints(const char* utf8) { +std::vector ParseUTF8(const char* utf8) { std::vector codepoints; while (*utf8 != 0) { - auto [codepoint, bytes] = Utf8ToCodepoint(utf8); + TCodepoint codepoint; + std::tie(codepoint, utf8) = ParseNextUTF8(utf8); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { return {codepoint}; } codepoints.push_back(codepoint); - utf8 += bytes; } return codepoints; } -int HexCharToInt(char c) { +inline int HexCharToInt(char c) { if (c >= '0' && c <= '9') { return c - '0'; } else if (c >= 'a' && c <= 'f') { @@ -128,22 +128,22 @@ int HexCharToInt(char c) { } } -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", 
'\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return Utf8ToCodepoint(utf8); + return ParseNextUTF8(utf8); } auto escape_sequence = std::string(utf8, 2); if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (utf8[1] == 'x') { @@ -159,9 +159,9 @@ std::pair Utf8OrEscapeToCodepoint( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { // 4- or 8-digit hex int len = utf8[1] == 'u' ? 4 : 8; @@ -170,13 +170,13 @@ std::pair Utf8OrEscapeToCodepoint( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } codepoint = codepoint * 16 + digit; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index f28aae6d74..790040e97e 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -21,7 +21,7 @@ using TCodepoint = int32_t; * \param codepoint The codepoint. * \return The UTF-8 string. */ -std::string CodepointToUtf8(TCodepoint codepoint); +std::string PrintAsUTF8(TCodepoint codepoint); /*! * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be @@ -29,10 +29,10 @@ std::string CodepointToUtf8(TCodepoint codepoint); * specify more escape sequences using custom_escape_map. * \param codepoint The codepoint. * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {'-', "\\-"}. + * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. * \return The printable string. */ -std::string CodepointToPrintable( +std::string PrintAsEscaped( TCodepoint codepoint, const std::unordered_map& custom_escape_map = {}); @@ -53,9 +53,9 @@ enum class CharHandlingError : TCodepoint { * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the * function returns (CharHandlingError::kInvalidUtf8, 0). */ -std::pair Utf8ToCodepoint(const char* utf8); +std::pair ParseNextUTF8(const char* utf8); -std::vector Utf8StringToCodepoints(const char* utf8); +std::vector ParseUTF8(const char* utf8); /*! * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function @@ -63,12 +63,12 @@ std::vector Utf8StringToCodepoints(const char* utf8); * using custom_escape_map. * \param utf8 The UTF-8 string or the escape sequence. * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {"\\-", '-'}. + * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. 
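The renamed parsers above also change the calling convention: instead of a byte count, they return a pointer to the first character after the consumed UTF-8 character or escape sequence, and they return the unmodified input pointer on error. A rough Python analogue that uses an index as the cursor (an illustrative sketch, not the C++ implementation):

.. code:: python

    INVALID_UTF8 = -1

    def parse_next_utf8(data: bytes, pos: int):
        """Return (codepoint, next_pos); next_pos stays at pos on error,
        mirroring how the C++ version now returns the input pointer on failure."""
        first = data[pos]
        if first < 0x80:
            n, cp = 1, first
        elif first >> 5 == 0b110:
            n, cp = 2, first & 0x1F
        elif first >> 4 == 0b1110:
            n, cp = 3, first & 0x0F
        elif first >> 3 == 0b11110:
            n, cp = 4, first & 0x07
        else:
            return INVALID_UTF8, pos
        if pos + n > len(data):
            return INVALID_UTF8, pos
        for b in data[pos + 1 : pos + n]:
            if b >> 6 != 0b10:  # every continuation byte must look like 10xxxxxx
                return INVALID_UTF8, pos
            cp = (cp << 6) | (b & 0x3F)
        return cp, pos + n

    def parse_utf8(data: bytes):
        pos, codepoints = 0, []
        while pos < len(data):
            cp, pos = parse_next_utf8(data, pos)
            if cp == INVALID_UTF8:
                return [INVALID_UTF8]
            codepoints.append(cp)
        return codepoints

    print(parse_utf8("▁of".encode("utf-8")))  # [9601, 111, 102]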
* \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape * sequence is invalid, the function returns * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). */ -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map = {}); } // namespace llm diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index ef866f3bfc..6fe9217520 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -9,10 +9,12 @@ #include #include +#include #include #include #include +#include "./support/encoding.h" #include "./support/load_bytes_from_file.h" namespace mlc { @@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) { LOG(FATAL) << "Cannot find any tokenizer under: " << _path; } -/*! - * \brief Post-process a raw token (which may be a raw byte or contain lower - * one eights block) to the actual token. - * We do this in order to conform with the tokenizers' setup. - */ -inline std::string PostProcessToken(std::string token) { - // 1. The token represents a byte. +/*! \brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */ +inline std::string ByteFallbackDecoder(const std::string& token) { if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') { int byte = 0; for (int i = 0; i < 2; ++i) { @@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) { ICHECK(byte >= 0 && byte < 256); return std::string(/*n=*/1, static_cast(byte)); } + return token; +} - // 2. The token contains "\u2581" which means space. - static const std::string& lower_one_eighth_block = "\u2581"; - size_t pos = token.find(lower_one_eighth_block); - while (pos != std::string::npos) { - token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" "); - pos = token.find(lower_one_eighth_block); +/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */ +inline std::string SpaceReplacerDecoder(const std::string& token) { + // \u2581 is the unicode for "lower one eighth block" + // UTF8 encoding for \u2581 is 0xE2 0x96 0x81 + std::string result; + for (size_t i = 0; i < token.size(); ++i) { + if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) && + token[i + 2] == char(0x81)) { + result += ' '; + i += 2; + } else { + result += token[i]; + } + } + return result; +} + +/*! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding + * process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + */ +inline std::string ByteLevelDecoder(const std::string& token) { + // clang-format off + // The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode. 
+ static const std::array unicode_to_byte_map = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1, + 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, + 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, + 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, + 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, + 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173 + }; + // clang-format on + + auto unicode_codepoints = ParseUTF8(token.c_str()); + std::string decoded; + + for (auto unicode_codepoint : unicode_codepoints) { + ICHECK(unicode_codepoint >= 0 && + unicode_codepoint < static_cast(unicode_to_byte_map.size())); + int byte = unicode_to_byte_map[unicode_codepoint]; + if (byte == -1) { + // If there is no mapping, add the codepoint itself to the result string + // Some tokenizer like Phi-2 have raw tokens like \t\t + decoded += static_cast(unicode_codepoint); + } else { + decoded += static_cast(byte); + } + } + return decoded; +} + +/*! + * \brief Post-process a raw token to the actual token with the given post-processing method. 
+ */ +inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) { + if (postproc_method == "byte_fallback") { + return SpaceReplacerDecoder(ByteFallbackDecoder(token)); + } else if (postproc_method == "byte_level") { + return ByteLevelDecoder(token); + } else { + LOG(FATAL) << "Unknown post-processing method: " << postproc_method; } - return token; } const std::vector& TokenizerObj::TokenTable() { @@ -127,12 +191,21 @@ const std::vector& TokenizerObj::TokenTable() { int vocab_size = tokenizer->GetVocabSize(); token_table_.reserve(vocab_size); for (int32_t token_id = 0; token_id < vocab_size; ++token_id) { - std::string token = tokenizer->IdToToken(token_id); - token_table_.push_back(PostProcessToken(token)); + token_table_.push_back(tokenizer->IdToToken(token_id)); } return token_table_; } +std::vector Tokenizer::PostProcessTokenTable( + const std::vector& token_table, const std::string& postproc_method) { + std::vector postprocessed_token_table; + postprocessed_token_table.reserve(token_table.size()); + for (const std::string& token : token_table) { + postprocessed_token_table.push_back(PostProcessToken(token, postproc_method)); + } + return postprocessed_token_table; +} + TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) { return Tokenizer::FromPath(path); }); diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index 16d9ba456b..36fc0c23db 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -30,7 +30,7 @@ class TokenizerObj : public Object { std::vector Encode(const std::string& text) const; /*! \brief Decode token ids into text. */ std::string Decode(const std::vector& token_ids) const; - /*! \brief Return the token table of the tokenizer. */ + /*! \brief Return the token table of the tokenizer. Special tokens are included. */ const std::vector& TokenTable(); /*! @@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef { /*! \brief Create a tokenizer from a directory path on disk. */ MLC_LLM_DLL static Tokenizer FromPath(const String& path); + /*! + * \brief Convert raw tokens provided by the tokenizer to their original string to simplify + * later processing. E.g. For LLaMA-2, convert "▁of" to " of". + * + * \param token_table The raw token table. + * \param postproc_method The postprocessing method to use. Now we only support "byte-fallback" + * and "byte-level", which refers to the type of the decoder of the tokenizer. + * - "byte-fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used + * by LLaMA-2, Mixtral-7b, etc. This method: 1) transform tokens like <0x1B> to hex char + * byte 1B. (known as the byte-fallback method); 2) transform \\u2581 to space. + * - "byte-level": Use the decoding method in the byte-level BPE tokenizer. This is used by + * LLaMA-3, GPT-2, Phi-2, etc. This method inverses the bytes-to-unicode transformation in + * the encoding process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + * \returns The postprocessed token table containing the original strings. 
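The two post-processing modes above are small enough to reproduce in Python. The sketch below rebuilds the GPT-2 bytes-to-unicode table with the same construction as the transformers file linked in the comment and applies its inverse for ``byte_level``, and pairs the ``<0xXX>`` byte decoding with the U+2581 space replacement for ``byte_fallback`` (a reference illustration only; the engine uses the C++ decoders above):

.. code:: python

    def bytes_to_unicode():
        """GPT-2 byte -> unicode-codepoint table; the unicode_to_byte_map above
        is its inverse."""
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, cs))

    UNICODE_TO_BYTE = {chr(c): b for b, c in bytes_to_unicode().items()}

    def byte_level_decode(token: str) -> str:
        # Codepoints with no mapping (e.g. Phi-2's raw "\t\t" tokens) pass through.
        data = bytes(UNICODE_TO_BYTE.get(ch, ord(ch)) for ch in token)
        return data.decode("utf-8", errors="replace")

    def byte_fallback_decode(token: str) -> str:
        # <0xXX> tokens become the raw byte; U+2581 becomes a space.
        if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
            return chr(int(token[3:5], 16))
        return token.replace("\u2581", " ")

    print(byte_fallback_decode("▁of"))           # ' of'   (LLaMA-2 style)
    print(repr(byte_fallback_decode("<0x1B>")))  # '\x1b'
    print(byte_level_decode("Ġof"))              # ' of'   (GPT-2 / LLaMA-3 style)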
+ */ + static std::vector PostProcessTokenTable(const std::vector& token_table, + const std::string& postproc_method); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj); private: diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index c0217db9e9..75a5cdbdc7 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -341,10 +341,24 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con mlc_llm gen_config ./dist/models/phi-2/ \ --quantization q4f16_1 --conv-template phi-2 \ -o dist/phi-2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json + # 2. mkdir: create a directory to store the compiled model library + mkdir -p dist/libs + # 3. compile: compile model library with specification in mlc-chat-config.json mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar +Given the compiled library, it is possible to calculate an upper bound for the VRAM +usage during runtime. This is useful to better understand whether a model is able to fit +particular hardware. +This information is displayed at the end of the console log when ``compile`` is executed. +It might look something like this: + +.. code:: shell + + [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. Temporary buffer: 133.28 MB) + [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` + [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar + .. note:: When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ @@ -388,21 +402,7 @@ This would result in something like `phi-2-q4f16_1-MLC `_. -**Step 4. Calculate estimated VRAM usage** - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. We can calculate this estimate using the following command: - -.. code:: shell - - ~/mlc-llm > python -m mlc_llm.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ - > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json - INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) - INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - - -**Step 5. Register as a ModelRecord** +**Step 4. Register as a ModelRecord** Finally, we update the code snippet for `app-config.json `__ diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 7b64dce9fb..ce15616957 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -214,7 +214,9 @@ There are two ways to do so: .. code-tab :: bash Install via environment variable - export PYTHONPATH=/path-to-mlc-llm/python:$PYTHONPATH + export MLC_LLM_HOME=/path-to-mlc-llm + export PYTHONPATH=$MLC_LLM_HOME/python:$PYTHONPATH + alias mlc_llm="python -m mlc_llm" ..
code-tab :: bash Install via pip local project diff --git a/python/mlc_llm/cli/lib_delivery.py b/python/mlc_llm/cli/lib_delivery.py new file mode 100644 index 0000000000..a5d678fbe2 --- /dev/null +++ b/python/mlc_llm/cli/lib_delivery.py @@ -0,0 +1,200 @@ +"""Continuous model delivery for MLC LLM models.""" + +import argparse +import dataclasses +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List + +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + quantization: str + device: str + # overrides the `context_window_size`, `prefill_chunk_size`, + # `sliding_window_size`, `attention_sink_size`, `max_batch_size` + # and `tensor_parallel_shards in mlc-chat-config.json + overrides: Dict[str, int] + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool: + """Run the compilation of the model library.""" + + def get_lib_ext(device: str) -> str: + if device in ["cuda", "vulkan", "metal"]: + return ".so" + if device in ["android", "ios"]: + return ".tar" + if device in ["webgpu"]: + return ".wasm" + + return "" + + succeeded = True + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir: + log_path = Path(temp_dir) / "logs.txt" + model_lib_name = f"{model_info.model_id}-{model_info.quantization}-{model_info.device}" + lib_ext = get_lib_ext(model_info.device) + if lib_ext == "": + raise ValueError(f"Unsupported device: {model_info.device}") + model_lib_name += lib_ext + with log_path.open("a", encoding="utf-8") as log_file: + overrides = ";".join(f"{key}={value}" for key, value in model_info.overrides.items()) + cmd = [ + sys.executable, + "-m", + "mlc_llm", + "compile", + str(model_info.model), + "--device", + model_info.device, + "--quantization", + model_info.quantization, + "--overrides", + overrides, + "--output", + os.path.join(temp_dir, model_lib_name), + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Compilation Complete!") + if not (Path(temp_dir) / model_lib_name).exists(): + logger.error( + "[%s] Model %s. Device %s. 
No compiled library found.", + red("FAILED"), + model_info.model_id, + model_info.device, + ) + succeeded = False + return succeeded + + # overwrite git repo file with the compiled library + repo_filepath = repo_dir / model_info.model_id / model_lib_name + if not repo_filepath.parent.exists(): + repo_filepath.parent.mkdir(parents=True, exist_ok=True) + # copy lib from Path(temp_dir) / model_lib_name to repo_filepath + shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath) + logger.info("Saved library %s at %s", model_lib_name, repo_filepath) + return succeeded + + +def _main( # pylint: disable=too-many-locals + spec: Dict[str, Any], +): + """Compile the model libs in the spec and save them to the binary_libs_dir.""" + failed_cases: List[Any] = [] + for task_index, task in enumerate(spec["tasks"], 1): + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model_info = { + "model_id": task["model_id"], + "model": task["model"], + } + for compile_opt in spec["default_compile_options"] + task.get("compile_options", []): + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info["quantization"] = quantization + model_info["device"] = compile_opt["device"] + model_info["overrides"] = compile_opt.get("overrides", {}) + logger.info( + "[Config] " + + bold("model_id: ") + + model_info["model_id"] + + bold(", quantization: ") + + model_info["quantization"] + + bold(", device: ") + + model_info["device"] + + bold(", overrides: ") + + json.dumps(model_info["overrides"]) + ) + + result = _run_compilation( + ModelInfo(**model_info), + repo_dir=Path(spec["binary_libs_dir"]), + ) + if not result: + failed_cases.append(model_info) + + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for case in failed_cases: + logger.info( + "model_id %s, quantization %s, device %s, overrides %s", + case["model_id"], + case["quantization"], + case["device"], + json.dumps(case["overrides"]), + ) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous library delivery") + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + spec=parsed.spec, + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9f7c1c3580..6663a0c230 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -44,6 +44,9 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + ) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) @@ -100,6 +103,7 @@ def main(argv): max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, 
speculative_mode=SpeculativeMode[parsed.speculative_mode], spec_draft_length=parsed.spec_draft_length, diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py new file mode 100644 index 0000000000..b7cfd76fa3 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -0,0 +1,66 @@ +"""The pass that attaches auxiliary TIR functions for speculative decoding to the IRModule.""" + +import tvm +from tvm import IRModule +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs") +class AttachSpecDecodeAuxFuncs: # pylint: disable=too-few-public-methods + """Attach scatter/gather TIR functions for speculative decoding to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["scatter_probs"] = _get_scatter_2d_inplace( + dtype="float32", global_symbol="scatter_probs" + ) + mod["gather_probs"] = _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs") + if "prefill_to_last_hidden_states" in mod: + hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[ + 0 + ] # pylint: disable=no-member + dtype = hidden_states_struct_info.dtype + mod["scatter_hidden_states"] = _get_scatter_2d_inplace( + dtype, global_symbol="scatter_hidden_states" + ) + mod["gather_hidden_states"] = _get_gather_2d_inplace( + dtype, global_symbol="gather_hidden_states" + ) + return mod + + +def _get_scatter_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (batch_size, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (m, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("scatter_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[indices[vb], vj] = src[vb, vj] + + return _scatter_2d + + +def _get_gather_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (m, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (batch_size, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("gather_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[vb, vj] = src[indices[vb], vj] + + return _gather_2d diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index b85a6a2cf6..3c80d2c4df 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -15,6 +15,7 @@ from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc from .attach_logit_processor import AttachLogitProcessFunc from .attach_sampler import AttachGPUSamplingFunc +from .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs from .attach_support_info import ( AttachAdditionalPrimFuncs, AttachCUDAGraphSymbolicCaptureHints, @@ -33,6 +34,7 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import
LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize +from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -103,6 +105,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), + AttachSpecDecodeAuxFuncs(), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), @@ -117,6 +120,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.backend.DispatchSortScan(), + RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py new file mode 100644 index 0000000000..82e6cf863b --- /dev/null +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -0,0 +1,193 @@ +"""A compiler pass that rewrites one-shot softmax into two-stage softmax.""" + +import math + +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relax.expr import Expr +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="RewriteTwoStageSoftmax") +class RewriteTwoStageSoftmax: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + func_name = "softmax_with_temperature" + if func_name not in self.mod: + return self.mod + gv = self.mod.get_global_var(func_name) + updated_func = self.visit_expr(self.mod[gv]) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-renamed + if call.op != tvm.ir.Op.get("relax.nn.softmax"): + return call + x = call.args[0] + if call.attrs.axis not in [-1, x.struct_info.ndim - 1]: + return call + # Currently the softmax input is 3-dim, and dtype is float32. 
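For reference, the ``scatter_probs``/``gather_probs`` helpers (and their hidden-state variants) attached by ``AttachSpecDecodeAuxFuncs`` above are plain row scatters and gathers; a NumPy equivalent of the two TIR prim funcs (illustrative only, whereas the real functions write in place into pre-allocated device buffers):

.. code:: python

    import numpy as np

    def scatter_2d(src, indices, dst):
        """dst[indices[b], :] = src[b, :], as in the scatter_2d TIR block."""
        dst[indices] = src
        return dst

    def gather_2d(src, indices, dst):
        """dst[b, :] = src[indices[b], :], as in the gather_2d TIR block."""
        dst[:] = src[indices]
        return dst

    src = np.arange(6, dtype=np.float32).reshape(3, 2)
    indices = np.array([2, 0, 1], dtype=np.int32)
    scattered = scatter_2d(src, indices, np.zeros((4, 2), dtype=np.float32))
    gathered = gather_2d(scattered, indices, np.zeros((3, 2), dtype=np.float32))
    print(np.array_equal(gathered, src))  # True: the gather undoes the scatter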
+ assert x.struct_info.ndim == 3 + assert x.struct_info.dtype == "float32" + x_shape = x.struct_info.shape + new_shape = relax.ShapeExpr([x_shape[0] * x_shape[1], x_shape[2]]) + x_reshaped = relax.call_pure_packed( + "vm.builtin.reshape", + x, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(self.target, self.chunk_size) + chunked_lse = relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[x_reshaped], + out_sinfo=relax.TensorStructInfo( + (new_shape[0], (new_shape[1] + self.chunk_size - 1) // self.chunk_size), + x.struct_info.dtype, + ), + ) + softmax = relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_lse"), + args=[x_reshaped, chunked_lse], + out_sinfo=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + return relax.call_pure_packed( + "vm.builtin.reshape", softmax, x_shape, sinfo_args=x.struct_info + ) + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + log2e = math.log2(math.exp(1)) + + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + A[v0, v1 * T.int64(chunk_size) + v2], + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e + + @T.prim_func + def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), 
dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + lse = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_lse[v0, v1]) + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0]) + for l0 in T.serial(0, batch_size): + with T.block("log"): + v0 = T.axis.remap("S", [l0]) + lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0] + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2( + A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0] + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_lse") + sch.compute_inline("log") + l0, l1, l2 = sch.get_loops("pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_lse"] diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 917e229632..1c599fa875 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -365,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: # RWKV World ConvTemplateRegistry.register_conv_template( Conversation( - name="rwkv-world", + name="rwkv_world", system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", system_message=( "Hi. I am your assistant and I will provide expert full response " diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 14e5cee321..86930fa5ea 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -152,6 +152,11 @@ The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. """.strip(), "enable_tracing_serve": """ Enable Chrome Tracing for the server. 
diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 2d0d668672..77b55c5a48 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -2,7 +2,6 @@ import dataclasses import enum -import re from io import StringIO from typing import Optional @@ -96,7 +95,7 @@ def _flashinfer(target) -> bool: return False arch_list = detect_cuda_arch_list(target) for arch in arch_list: - if int(re.findall(r"\d+", arch)[0]) < 80: + if arch < 80: logger.warning("flashinfer is not supported on CUDA arch < 80") return False return True diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 8e617fc3d2..13f0e1215f 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -5,7 +5,7 @@ import re import shutil from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model @@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes pad_token_id: int = None bos_token_id: int = None eos_token_id: int = None + # Tokenizer configuration tokenizer_files: List[str] = dataclasses.field(default_factory=list) + # The method to post-process the token table. See + # cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = None # Version control version: str = VERSION @@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None: msgpack.pack(idx2token, f) +def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]: + """Detect the token table postprocessing method from tokenizer.json that is found under + output_path. If not detected, use ByteFallback as default. + + Check the decoder field of the tokenizer. If it uses ByteFallback decoder, return + "byte_fallback". If it uses ByteLevel decoder, return "byte_level". Otherwise, use + ByteFallback as default. + + See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable. 
+ """ + output_tokenizer_path = output_path / "tokenizer.json" + if not output_tokenizer_path.exists(): + logger.warning( + "Tokenizer token table postprocessing method is not detected as tokenizer.json " + "is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default" + ) + return "byte_fallback" + + with output_tokenizer_path.open("r", encoding="utf-8") as in_file: + tokenizer_json = json.load(in_file) + + # Find all decoders in tokenizer.json + decoders = [] + + if "decoder" not in tokenizer_json: + logger.warning( + "Decoder field is not found in tokenizer.json, use ByteFallback (the same as " + "LLaMA/LLaMA2) as the token table postprocessing method by default" + ) + return "byte_fallback" + + decoders_json = tokenizer_json["decoder"] + assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json" + if decoders_json["type"] == "Sequence": + assert "decoders" in decoders_json + decoders = decoders_json["decoders"] + else: + decoders = [decoders_json] + + is_byte_level = False + is_byte_fallback = False + + for decoder in decoders: + if decoder["type"] == "ByteLevel": + is_byte_level = True + if decoder["type"] == "ByteFallback": + is_byte_fallback = True + assert not ( + is_byte_level and is_byte_fallback + ), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback" + + if is_byte_level: + return "byte_level" + if is_byte_fallback: + return "byte_fallback" + + logger.warning( + "Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use " + "ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method " + "by default" + ) + return "byte_fallback" + + def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements config: Path, model: Model, @@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b except Exception: # pylint: disable=broad-exception-caught logger.exception("%s with the exception below. Skipping", FAILED) + # 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not + # detected, use "byte_fallback" as default. + mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output) + # Step 4. Load system default value mlc_chat_config.apply_defaults() # Step 5. Dump the configuration file to output directory diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index d0cbd4690b..40fa9fdda8 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -22,6 +22,7 @@ def serve( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -44,6 +45,7 @@ def serve( max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/json_ffi/__init__.py b/python/mlc_llm/json_ffi/__init__.py new file mode 100644 index 0000000000..8a7059153d --- /dev/null +++ b/python/mlc_llm/json_ffi/__init__.py @@ -0,0 +1,8 @@ +"""JSON FFI is a pure string based interface of MLC LLM Engine. + +We build interfacing with JSON FFI for both testing purposes +and internal use. 
For most python API usage, please use MLCEngine +and MLCAsyncEngine +""" + +from .engine import JSONFFIEngine diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py new file mode 100644 index 0000000000..0c604a2ef3 --- /dev/null +++ b/python/mlc_llm/json_ffi/engine.py @@ -0,0 +1,310 @@ +# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# construction to not depend on any config and directly pass in JSON +# model defined generation config should be read from the JSONFFIEngine via Reload +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# Engine config should be passed as json str +# and backend should have good default +# only model and model_lib should be mandatory +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. 
+ models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # TODO(mlc-team) Remove the model config parsing, estimation below + # in favor of a simple direct passing of parameters into backend. + # JSONFFIEngine do not have to support automatic mode + # + # Instead, its config should default to interactive mode always + # and allow overrides of parameters through json config via reload + # + # This is to simplify the logic of users of JSONFFI + # since we won't have similar logics in android/iOS + # + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. + self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + + self.json_ffi_engine_config = create_json_ffi_engine_config( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: create_model_defined_generation_config( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + + self._ffi["init_background_engine"]( + self.json_ffi_engine_config, + self.engine_config, + device, + self.state.get_request_stream_callback(), + None, + ) + + def _background_loop(): + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
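For reference, once the engine above is constructed it is driven purely through JSON strings and the OpenAI-style ``chat_completion`` iterator. A hypothetical usage sketch (the model and library paths are placeholders, and the streamed ``delta.content`` is assumed to be a plain string):

.. code:: python

    from mlc_llm.json_ffi import JSONFFIEngine

    # Placeholder paths: an already-converted model directory and its compiled library.
    engine = JSONFFIEngine(
        model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        model_lib_path="dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
        mode="interactive",
    )
    try:
        for response in engine.chat_completion(
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
            model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
            max_tokens=64,
            stream=True,
        ):
            for choice in response.choices:
                print(choice.delta.content or "", end="", flush=True)
    finally:
        engine.terminate()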
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() diff --git 
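The JSONFFIEngine wrapper above exchanges plain JSON strings with the C++ background engine and streams back OpenAI-style ChatCompletionStreamResponse objects. A minimal usage sketch, assuming the quantized Llama-2 package used by the updated test further below has been compiled locally (the dist/ path availability is an assumption):

    from mlc_llm.json_ffi import JSONFFIEngine

    engine = JSONFFIEngine(
        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        max_total_sequence_length=1024,
    )
    try:
        # chat_completion yields ChatCompletionStreamResponse objects; each
        # choice carries a delta with the newly generated text fragment.
        for response in engine.chat_completion(
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
            model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
            max_tokens=64,
        ):
            for choice in response.choices:
                if choice.delta.content:
                    print(choice.delta.content, end="", flush=True)
    finally:
        # Stop the two background loops and join the worker threads.
        engine.terminate()

Note that model_lib_path is optional here; the updated test_json_ffi_engine.py below drops the explicit library path and relies on automatic resolution.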
a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 28c34353e2..ede9dc350f 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -28,7 +28,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes n_embd: int n_layer: int n_head: int - layer_norm_epsilon: int + layer_norm_epsilon: float n_inner: int = -1 context_window_size: int = 0 prefill_chunk_size: int = 0 diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 49386720da..81c9e9aa7f 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -40,6 +40,7 @@ class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -129,23 +130,18 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - # x.shape = (batch, seq_len, hidden_size) - # state.shape = (batch, hidden_size) - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): # x.shape = (batch, seq_len, hidden_size) batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -350,10 +346,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -367,11 +367,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -386,7 +402,6 
@@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -396,9 +411,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -406,7 +419,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -414,8 +452,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index 0e1887310d..a8faf48a6b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -40,6 +40,7 @@ class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. 
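The RWKV5 spec above, and the mirrored RWKV6 spec around this point, now declare softmax_with_temperature with a temperature of shape ("batch_size",) instead of a scalar, and the forward pass divides the logits by a per-sequence temperature via op.reshape(temperature, (B, 1, 1)). A NumPy sketch of that broadcasting, illustrative only since the real computation runs as compiled TIR/Relax:

    import numpy as np

    batch_size, vocab_size = 3, 8
    logits = np.random.randn(batch_size, 1, vocab_size).astype("float32")
    temperature = np.array([0.7, 1.0, 1.3], dtype="float32")  # one value per sequence

    # Mirrors op.softmax(logits / op.reshape(temperature, (batch_size, 1, 1)), axis=-1)
    scaled = logits / temperature.reshape(batch_size, 1, 1)   # broadcast over the vocab dim
    scaled -= scaled.max(axis=-1, keepdims=True)              # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)

    assert probs.shape == (batch_size, 1, vocab_size)
    assert np.allclose(probs.sum(axis=-1), 1.0)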
prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -126,20 +127,17 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -390,10 +388,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -407,11 +409,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -426,7 +444,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -436,9 +453,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -446,7 +461,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": 
nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -454,8 +494,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 4a058c6e03..e4cbf1c047 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -887,7 +887,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 64 + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index b5db353a3b..850312a8a7 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -6,3 +6,4 @@ from .extern import configure, enable, get_store from .ft_gemm import faster_transformer_dequantize_gemm from .position_embedding import llama_rope +from .top_p_pivot import top_p_pivot, top_p_renorm diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py index 9cdbe2be21..d1a57fc71c 100644 --- a/python/mlc_llm/op/batch_spec_verify.py +++ b/python/mlc_llm/op/batch_spec_verify.py @@ -51,7 +51,7 @@ def batch_spec_verify(vocab_size): token_tree_parent_ptr: Current parent ptr state """ - TX = 128 + TX = 1024 def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -142,7 +142,6 @@ def _func( model_prob_local[0] = model_probs[parent_ptr[0], k] draft_prob_local[0] = draft_probs[child_ptr[0], k] model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) - model_probs[parent_ptr[0], k] = model_prob_local[0] psum[0] += model_prob_local[0] with T.block("block_cross_thread"): @@ -155,13 +154,21 @@ def _func( ) T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") - # renormalize - for i in T.serial(T.ceildiv(vocab_size, TX)): - k = T.meta_var(i * TX + tx) - if k < vocab_size: - model_probs[parent_ptr[0], k] = model_probs[parent_ptr[0], k] / t0[0] - - child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + if t0[0] < 1e-7: + # accept the proposal, we move to child + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] if 
tx == 0: token_tree_parent_ptr[b] = parent_ptr[0] diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py new file mode 100644 index 0000000000..9c97959bff --- /dev/null +++ b/python/mlc_llm/op/top_p_pivot.py @@ -0,0 +1,315 @@ +"""Operators for choosing the pivot to cut-off top-p percentile """ + +import tvm +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def top_p_pivot(pN): + """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. + + A valide pivot should satisfy the following conditions: + - lsum >= top_p + - top_p > lsum - cmin * lmin + where lsum is the sum of elements that are larger or equal to the pivot, + lmin is the minimum elements that is larger or equal to the pivot, + cmin is the count of elements that are equal to lmin, + + Parameters + ---------- + prob: + The probability vector + + top_p_global: + The top-p threshold + + init_pivots: + The initial pivot candidates + + final_pivot: + The final pivot to cut-off top-p percentile + """ + TX = 1024 + K = 32 + eps_LR = 1e-7 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + def valid(lsum, lmin, cmin, top_p): + return tvm.tir.all(lsum >= top_p, top_p > lsum - cmin * lmin) + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + top_p_global: T.buffer([1], dtype="float32"), + var_init_pivots: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + + with T.block("kernel"): + pivot = T.alloc_buffer((pN,), "float32", scope="local") + top_p = _var("float32") + + L = T.alloc_buffer((1,), "float32", scope="shared") + R = T.alloc_buffer((1,), "float32", scope="shared") + L_local = _var("float32") + R_local = _var("float32") + + q = _var("float32") + lsum = T.alloc_buffer((pN,), "float32", scope="local") + lmin_broadcast = T.alloc_buffer((1), "float32", scope="shared") + lmin_broadcast_local = _var("float32") + lmin = T.alloc_buffer((pN,), "float32", scope="local") + cmin = T.alloc_buffer((pN,), "int32", scope="local") + total_sum = _var("float32") + + it = _var("int32") + es_local = _var("bool") + es = T.alloc_buffer((1,), "bool", scope="shared") + find_pivot_local = _var("bool") + find_pivot = T.alloc_buffer((1,), "bool", scope="shared") + + total_sum_reduce = _var("float32") + lsum_reduce = _var("float32") + lmin_reduce = _var("float32") + cmin_reduce = _var("int32") + + for _bx in T.thread_binding(0, B, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + + top_p[0] = top_p_global[0] + + if tx == 0: + # leader thread initializes L, R + L[0] = 1.0 - top_p[0] + R[0] = eps_LR + find_pivot[0] = False + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + for i in T.unroll(0, pN): + # pivots are in descending order + pivot[i] = init_pivots[i] + find_pivot_local[0] = False + + while T.tvm_thread_invariant( + L_local[0] - 
R_local[0] > eps_LR + and T.Not(find_pivot_local[0]) + ): + # sync before each iteration + T.tvm_storage_sync("shared") + + ### get lsum, lmin, total_sum + for pidx in T.unroll(0, pN): + lsum[pidx] = 0.0 + lmin[pidx] = 1.0 + cmin[pidx] = 0 + total_sum[0] = 0.0 + it[0] = 0 + es_local[0] = False + while it[0] < T.ceildiv(N, TX) and T.Not(es_local[0]): + idx = T.meta_var(it[0] * TX + tx) + q[0] = T.if_then_else(idx < N, prob[b, idx], 0.0) + total_sum[0] += q[0] + for pidx in T.unroll(0, pN): + if q[0] >= pivot[pidx]: + lsum[pidx] += q[0] + if lmin[pidx] > q[0]: + lmin[pidx] = q[0] + cmin[pidx] = 1 + elif lmin[pidx] == q[0]: + cmin[pidx] += 1 + it[0] += 1 + + # early stop every K iterations + if it[0] % K == 0: + # reduce total_sum over tx + # T.tvm_storage_sync("shared") + with T.block("block_cross_thread"): + T.reads(total_sum[0]) + T.writes(total_sum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), total_sum[0], True, total_sum_reduce[0], tx, dtype="handle") + # T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we can stop early + es[0] = 1 - total_sum_reduce[0] < pivot[pN - 1] + T.tvm_storage_sync("shared") + es_local[0] = es[0] + + T.tvm_storage_sync("shared") + + # reduce lsum, lmin, cmin, over tx + for pidx in T.serial(0, pN): + # reduce lsum over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lsum[pidx]) + T.writes(lsum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], True, lsum_reduce[0], tx, dtype="handle") + + # reduce lmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lmin[pidx]) + T.writes(lmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], True, lmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # broadcast lmin to all threads + lmin_broadcast[0] = lmin_reduce[0] + T.tvm_storage_sync("shared") + lmin_broadcast_local[0] = lmin_broadcast[0] + if lmin[pidx] > lmin_broadcast_local[0]: + cmin[pidx] = 0 + if tx == 0: + # only the leader thread updates lsum, lmin + lsum[pidx] = lsum_reduce[0] + lmin[pidx] = lmin_reduce[0] + + # reduce cmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(cmin[pidx]) + T.writes(cmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.int32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], True, cmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # only the leader thread updates cmin + cmin[pidx] = cmin_reduce[0] + + T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we have found the pivot, or updates L, R + it[0] = 0 + while it[0] < pN and T.Not(find_pivot_local[0]): + pidx = T.meta_var(it[0]) + if valid(lsum[pidx], lmin[pidx], cmin[pidx], top_p[0]): + find_pivot[0] = True + find_pivot_local[0] = True + # write back the pivot and lsum + final_pivot[b] = pivot[pidx] + final_lsum[b] = lsum[pidx] + elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: + R[0] = pivot[pidx] + elif lsum[pidx] < top_p[0]: + L[0] = pivot[pidx] + it[0] += 1 + + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + find_pivot_local[0] = find_pivot[0] + # new 
pivots for next iteration + # uniform spacing between L and R + for pidx in T.unroll(0, pN): + pivot[pidx] = L[0] - (pidx + 1) * (L_local[0] - R_local[0]) / (pN + 1) + + if tx == 0: + # leader thread writes back the pivot + if T.Not(find_pivot_local[0]): + final_pivot[b] = -1e5 + # fmt: on + + return _func + + +def top_p_renorm(): + """Top-p renormalization function. This function renormalizes the probability vector. + + Given the pivot, the probability vector is renormalized as follows: + - if prob >= pivot, renorm_prob = prob / lsum + - otherwise, renorm_prob = 0 + + Parameters + ---------- + prob: + The probability vector + + final_pivot: + The final pivot to cut-off top-p percentile + + final_lsum: + The sum of elements that are larger or equal to the pivot + + renorm_prob: + The renormalized probability vector + """ + TX = 1024 + CTA_COUNT = 512 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + var_renorm_prob: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + renorm_prob = T.match_buffer(var_renorm_prob, (B, N,), "float32") + + with T.block("kernel"): + pivot = _var("float32") + lsum = _var("float32") + BX = T.meta_var(T.ceildiv(CTA_COUNT, B)) + + for _by in T.thread_binding(0, B, thread="blockIdx.y"): + for _bx in T.thread_binding(0, BX, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx]) + + pivot[0] = final_pivot[by] + lsum[0] = final_lsum[by] + + for i in T.serial(T.ceildiv(N, BX * TX)): + idx = T.meta_var(i * BX * TX + bx * TX + tx) + if idx < N: + renorm_prob[by, idx] = T.if_then_else(prob[by, idx] >= pivot[0], prob[by, idx] / lsum[0], 0.0) + # fmt: on + + return _func diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index d6ce4a4fcb..4a5168f971 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -223,7 +223,7 @@ class ChatCompletionRequest(BaseModel): @classmethod def check_penalty_range(cls, penalty_value: float) -> float: """Check if the penalty value is in range [-2, 2].""" - if penalty_value < -2 or penalty_value > 2: + if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 60e4eca8c5..6b808ac37b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -128,6 +128,13 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) +class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods + """Possible kinds of KV state.""" + + ATTENTION = 0 + RNNSTATE = 1 + + class SpeculativeMode(enum.IntEnum): """The speculative mode.""" @@ -157,9 +164,6 @@ class EngineConfig(tvm.runtime.Object): additional_model_lib_paths : List[str] The path to the additional models' libraries. - device : tvm.runtime.Device - The device where the models run. 
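Referring back to the top_p_pivot and top_p_renorm operators added earlier in this change: a pivot v is valid when the entries >= v together carry at least top_p of the probability mass, while removing the smallest of those entries would drop below top_p; renormalization then keeps the entries >= v and divides by their sum (lsum). A NumPy reference of that definition, in the spirit of the verification in tests/python/op/test_top_p_pivot.py (this is not how the GPU kernel searches for the pivot, only a host-side way to compute one):

    import numpy as np

    def top_p_pivot_reference(prob, top_p):
        # Sort descending and take the shortest prefix whose mass reaches top_p;
        # the pivot is the last (smallest) probability in that prefix.
        sorted_probs = np.sort(prob)[::-1]
        cumsum = np.cumsum(sorted_probs)
        k = int(np.searchsorted(cumsum, top_p)) + 1
        pivot = sorted_probs[k - 1]
        lsum = float(prob[prob >= pivot].sum())  # includes ties with the pivot value
        return pivot, lsum

    def top_p_renorm_reference(prob, pivot, lsum):
        # Keep entries >= pivot and renormalize by their sum; zero out the rest.
        return np.where(prob >= pivot, prob / lsum, 0.0)

    prob = np.random.exponential(3, size=64).astype("float32")
    prob /= prob.sum()
    pivot, lsum = top_p_pivot_reference(prob, top_p=0.95)
    renorm = top_p_renorm_reference(prob, pivot, lsum)
    assert abs(renorm.sum() - 1.0) < 1e-5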
- kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. @@ -177,6 +181,12 @@ class EngineConfig(tvm.runtime.Object): prefill_chunk_size : int The maximum total sequence length in a prefill. + max_history_size: int + The maximum history size for RNN state to roll back. + + kv_state_kind: KVStateKind + The kind of cache. + speculative_mode : SpeculativeMode The speculative mode. @@ -190,12 +200,13 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path: str, additional_models: List[str], additional_model_lib_paths: List[str], - device: tvm.runtime.Device, kv_cache_page_size: int, max_num_sequence: int, max_total_sequence_length: int, max_single_sequence_length: int, prefill_chunk_size: int, + max_history_size: int, + kv_state_kind: KVStateKind, speculative_mode: SpeculativeMode, spec_draft_length: int, ) -> None: @@ -205,12 +216,13 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path, additional_models, additional_model_lib_paths, - device, kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, + kv_state_kind, speculative_mode, spec_draft_length, ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index d9721b4864..413c856db1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -816,6 +816,9 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. + max_history_size : Optional[int] + The maximum history for RNN state. + gpu_memory_utilization : Optional[float] A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer the maximum possible KV cache capacity.
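With max_history_size and kv_state_kind now part of EngineConfig and of the Python engine constructors, RWKV-style models are served from a rollback-able RNN state rather than a paged KV cache, so their capacity is set by max_batch_size and max_history_size instead of a total sequence length. A construction sketch mirroring the create_engine helper in the updated tests/python/serve/test_serve_engine.py (the dist/ paths are the ones used there and are assumed to be compiled locally):

    from mlc_llm.serve import MLCEngine

    # RNN-state model: _infer_kv_cache_config dispatches on "rwkv" in the model
    # path and sizes the state from max_batch_size and max_history_size.
    engine = MLCEngine(
        model="dist/rwkv-6-world-1b6-q0f16-MLC",
        model_lib_path="dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so",
        mode="server",
        max_batch_size=8,
        max_history_size=1,
    )
    # Attention models keep the previous behavior, e.g.
    # MLCEngine(model=..., model_lib_path=..., mode="server", max_total_sequence_length=4096).
    del engine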
@@ -846,6 +849,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -861,6 +865,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, @@ -1392,6 +1397,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -1407,6 +1413,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 7b2ede60b2..65b41a66ac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -20,7 +20,12 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import ( + EngineConfig, + GenerationConfig, + KVStateKind, + SpeculativeMode, +) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -121,7 +126,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments +def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments models: List[ModelInfo], device: tvm.runtime.Device, model_config_paths: List[str], @@ -240,6 +245,90 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma ) +def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, int]: + # Get single-card GPU size. 
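    # Overview of the estimate below: the usable budget is
    #   total_global_memory * gpu_memory_utilization,
    # from which parameters, temporary functions, the logit-processor workspace
    # and the model workspace are subtracted. The remainder is divided by the
    # per-history RNN state footprint, accumulated per model as
    #   max_num_sequence * hidden_size * num_layers * 2 (ATT/FFN token shift) * 2
    # + max_num_sequence * num_heads * head_size^2 * num_layers * 2 (WKV state),
    # where the trailing factor of 2 presumably corresponds to 2-byte (fp16) storage.
    # max_history_size = floor((budget - fixed costs) / rnn_state_base_bytes).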
+ gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 + param_bytes = 0.0 + temp_func_bytes = 0.0 + model_workspace_bytes = 0.0 + logit_processor_workspace_bytes = 0.0 + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): + # Read metadata for the parameter size and the temporary memory size. + cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + model_config_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + param_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + model_config = model_config_dict["model_config"] + vocab_size = model_config_dict["vocab_size"] + head_size = model_config["head_size"] + num_heads = model_config["num_heads"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + + rnn_state_base_bytes += ( + max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 + ) + + max_history_size = int( + ( + gpu_size_bytes * gpu_memory_utilization + - logit_processor_workspace_bytes + - model_workspace_bytes + - param_bytes + - temp_func_bytes + ) + / rnn_state_base_bytes + ) + if max_history_size < 1: + raise ValueError( + f"Memory required by models may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) + + return ( + param_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, + rnn_state_base_bytes, + max_history_size, + ) + + def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: """Read the model config dictionaries, and return the maximum single sequence length the models can support, the maximum prefill chunk @@ -294,7 +383,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -304,12 +393,13 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, KVStateKind, int]: """Initialize the KV cache config with user input and GPU memory usage estimation. 
The returned four integers are: - max_batch_size - max_total_sequence_length - prefill_chunk_size + - kv_state_kind - model_max_single_sequence_length """ ( @@ -323,7 +413,7 @@ def infer_args_under_mode( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int], List[float]]: + ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: logging_msg = "" # - max_batch_size if max_batch_size is None: @@ -343,7 +433,7 @@ def infer_args_under_mode( kv_aux_workspace_bytes, temp_workspace_bytes, model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length( + ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( models, device, model_config_paths, @@ -400,7 +490,12 @@ def infer_args_under_mode( # - Construct the KV cache config # - Estimate total GPU memory usage on single GPU. - return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + KVStateKind.ATTENTION, + ), [ total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, model_params_bytes, kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, @@ -462,6 +557,189 @@ def infer_args_under_mode( return *kv_cache_config, model_max_single_sequence_length +def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, KVStateKind, int]: + """Initialize the RNN state config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - kv_state_kind + - max_history_size + """ + logging_msg = "" + prefill_chunk_size = 0 + + if prefill_chunk_size is None: + prefill_chunk_size = min( + config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 + for config in model_config_dicts + ) + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + if max_batch_size is None: + max_batch_size = 1 if mode == "interactive" else 4 + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + + if mode == "local": + logging_msg += ( + "We choose small max batch size and RNN state capacity to use less GPU memory." + ) + elif mode == "interactive": + logging_msg += "We fix max batch size to 1 for interactive single sequence use." + else: + logging_msg += ( + "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." 
+ ) + logger.info('Under mode "%s", %s', mode, logging_msg) + + ( + model_param_bytes, + model_temp_bytes, + model_rnn_state_base_bytes, + model_max_history_size, + ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + if max_history_size is None: + max_history_size = model_max_history_size + else: + max_history_size = min(max_history_size, model_max_history_size) + max_total_sequence_length = 32768 + prefill_chunk_size = 0 + kind = KVStateKind.RNNSTATE + + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", + green("Estimated total single GPU memory usage"), + (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, + model_param_bytes / 1024 / 1024, + max_history_size * model_rnn_state_base_bytes / 1024 / 1024, + model_temp_bytes / 1024 / 1024, + ) + + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kind, + max_history_size, + ) + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, int, int, KVStateKind]: + """Initialize the cache config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - max_single_sequence_length + - max_history_size + - kv_state_kind + """ + if all("rwkv" not in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_single_sequence_length, + ) = _infer_kv_cache_config_for_kv_cache( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_history_size = 0 # KV cache doesn't need this + elif all("rwkv" in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_history_size, + ) = _infer_kv_cache_config_for_rnn_state( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this + else: + raise ValueError("The models should be either all KV cache models or all RNN state models.") + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) + + +def _infer_generation_config( + model_config_dicts: List[Dict[str, Any]] +) -> List[Tuple[float, float, float, float]]: + """Infer the generation config from the model config dictionaries. 
+ The returned four floats are: + - temperature + - top_p + - frequency_penalty + - presence_penalty + """ + generation_configs = [] + + for model_config in model_config_dicts: + temperature = model_config.get("temperature", 1.0) + top_p = model_config.get("top_p", 1.0) + frequency_penalty = model_config.get("frequency_penalty", 0.0) + presence_penalty = model_config.get("presence_penalty", 0.0) + generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) + + return generation_configs + + @dataclass class CallbackStreamOutput: """The output of MLCEngine._generate and AsyncMLCEngine._generate @@ -728,6 +1006,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -757,11 +1036,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -788,6 +1070,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals } self.tokenizer = Tokenizer(model_args[0][0]) self._ffi["init_background_engine"]( + device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) @@ -797,12 +1080,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 257338da3a..1be841cb08 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -98,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, @@ -128,11 +129,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -162,15 +166,17 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, 
max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), + device, request_stream_callback, self.trace_recorder, ) diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index f0247a6ef9..be0ee8af98 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # search mlc-chat-config.json under path mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json" if not mlc_chat_config_json_path.exists(): - raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.") + raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.") else: mlc_chat_config_json_path = mlc_chat_config_path diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 3cf49c43ba..5239756d9d 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -293,14 +293,20 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): return build -def detect_cuda_arch_list(target: Target) -> List[str]: +def detect_cuda_arch_list(target: Target) -> List[int]: """Detect the CUDA architecture list from the target.""" + + def convert_to_num(arch_str): + arch_num_str = "".join(filter(str.isdigit, arch_str)) + assert arch_num_str, f"'{arch_str}' does not contain any digits" + return int(arch_num_str) + assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" if MLC_MULTI_ARCH is not None: - multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")] else: assert target.arch.startswith("sm_") - multi_arch = [target.arch[3:]] + multi_arch = [convert_to_num(target.arch[3:])] multi_arch = list(set(multi_arch)) return multi_arch diff --git a/python/mlc_llm/support/max_thread_check.py b/python/mlc_llm/support/max_thread_check.py index 6c078c3bbf..6711fb5c55 100644 --- a/python/mlc_llm/support/max_thread_check.py +++ b/python/mlc_llm/support/max_thread_check.py @@ -3,7 +3,7 @@ from tvm.target import Target -def get_max_num_threads_per_block(target: Target): +def get_max_num_threads_per_block(target: Target) -> int: """ max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. We add this method since some targets have both fields and `max_threads_per_block` is larger. 
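Returning to the detect_cuda_arch_list change above: architectures are now normalized to integers by stripping non-digit characters, whether they come from target.arch (e.g. "sm_86") or from a comma-separated MLC_MULTI_ARCH override. A standalone sketch of that rule (convert_to_num is restated here because the original is a local helper inside the function):

    def convert_to_num(arch_str: str) -> int:
        # Keep only the digits, so "sm_86" -> 86; an empty result is rejected.
        arch_num_str = "".join(filter(str.isdigit, arch_str))
        assert arch_num_str, f"'{arch_str}' does not contain any digits"
        return int(arch_num_str)

    assert convert_to_num("sm_86") == 86
    assert convert_to_num(" 80 ") == 80

    # With MLC_MULTI_ARCH="70,80,86,86" the result is a de-duplicated integer list.
    multi_arch = list({convert_to_num(x) for x in "70,80,86,86".split(",")})
    assert sorted(multi_arch) == [70, 80, 86]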
diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 9b594e9784..c52571b522 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,23 +1,6 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json -import queue -import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union -import tvm - -from mlc_llm.protocol import openai_api_protocol -from mlc_llm.serve import engine_utils -from mlc_llm.serve.engine_base import ( - EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, - _parse_models, - _process_model_args, - detect_device, -) -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.json_ffi import JSONFFIEngine chat_completion_prompts = [ "What is the meaning of life?", @@ -60,240 +43,6 @@ ] -class EngineState: - sync_queue: queue.Queue - - def get_request_stream_callback(self) -> Callable[[List[str]], None]: - # ChatCompletionStreamResponse - - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: - self._sync_request_stream_callback(chat_completion_stream_responses_json_str) - - return _callback - - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: - # Put the delta outputs to the queue in the unblocking way. - self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) - - -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. 
- self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "reload", - "unload", - "reset", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) - - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - self._ffi["init_background_engine"]( - self.conv_template.model_dump_json(), - self.engine_config, - self.state.get_request_stream_callback(), - None, - ) - - def _background_loop(): - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. - self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - self._terminated = False - - def terminate(self): - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - 
openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), - n=n, - request_id=request_id, - ) - for response in chatcmpl_generator: - yield response - - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) - - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception - - def _test_reload(self): - self._ffi["reload"](self.engine_config) - - def _test_reset(self): - self._ffi["reset"]() - - def _test_unload(self): - self._ffi["unload"]() - - def run_chat_completion( engine: JSONFFIEngine, model: str, @@ -335,10 +84,8 @@ def run_chat_completion( def test_chat_completion(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) @@ -355,10 +102,8 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. 
model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py index 359fafdbd0..f35a39d71e 100644 --- a/tests/python/op/test_batch_spec_verify.py +++ b/tests/python/op/test_batch_spec_verify.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("nbatch", [32, 64]) -@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) @pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) def test_batch_spec_verify(nbatch, vocab, plist): def numpy_reference( @@ -141,6 +141,20 @@ def gen_full_binary_tree(height, base): token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 ) + time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) + print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") + print( + time_evaluator( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + ) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/op/test_top_p_pivot.py b/tests/python/op/test_top_p_pivot.py new file mode 100644 index 0000000000..7cfeb60e9c --- /dev/null +++ b/tests/python/op/test_top_p_pivot.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm + +# mypy: disable-error-code="var-annotated" + + +@pytest.mark.parametrize("batch_size", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 128]) +def test_top_p_renorm(batch_size, vocab): + top_p = 0.95 + init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32) + top_p_np = np.array([top_p]).astype(np.float32) + + p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32) + p_np /= np.sum(p_np, axis=-1, keepdims=True) + final_pivot_np = np.zeros(batch_size).astype(np.float32) + final_lsum_np = np.zeros(batch_size).astype(np.float32) + + dev = tvm.cuda(0) + var_prob = tvm.nd.array(p_np, dev) + var_init_pivots = tvm.nd.array(init_pivots_np, dev) + top_p_global = tvm.nd.array(top_p_np, dev) + var_final_pivot = tvm.nd.array(final_pivot_np, dev) + var_final_lsum = tvm.nd.array(final_lsum_np, dev) + + kernel = top_p_pivot(init_pivots_np.shape[0]) + mod = tvm.build(kernel, target="cuda") + mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum) + + final_pivot = var_final_pivot.asnumpy() + final_lsum = var_final_lsum.asnumpy() + + renorm_np = p_np.copy() + var_renorm = tvm.nd.array(renorm_np, dev) + + kernel_renorm = top_p_renorm() + mod_renorm = tvm.build(kernel_renorm, target="cuda") + mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm) + + renorm = var_renorm.asnumpy() + + def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray): + sorted_probs = np.sort(probs, axis=-1)[::-1] + num_larger_than_pivot = np.sum(sorted_probs >= pivot) + filtered_sorted_probs = sorted_probs[:num_larger_than_pivot] + min_larger_than_pivot = min(filtered_sorted_probs) + + sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0)) + sum_larger_than_pivot_exclude_min = np.sum( + np.where(filtered_sorted_probs != 
+        )
+
+        probs[probs < pivot] = 0
+        renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True)
+        try:
+            assert sum_larger_than_pivot >= top_p
+            assert sum_larger_than_pivot_exclude_min < top_p
+            assert abs(lsum - sum_larger_than_pivot) < 1e-6
+            assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6)
+        except AssertionError:
+            print("Failed")
+            print("probs:", repr(probs))
+            print("pivot:", pivot)
+            print("sorted_probs:", sorted_probs)
+            print("num_larger_than_pivot:", num_larger_than_pivot)
+            print("filtered_sorted_probs:", filtered_sorted_probs)
+            print("min_larger_than_pivot:", min_larger_than_pivot)
+            print("sum_larger_than_pivot:", sum_larger_than_pivot)
+            print("sum_larger_than_pivot_exclude_min:", sum_larger_than_pivot_exclude_min)
+            print("renom_prob:", renorm_prob)
+            print("renorm:", renorm)
+            raise
+
+    for i in range(batch_size):
+        verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/op/test_two_stage_softmax.py b/tests/python/op/test_two_stage_softmax.py
new file mode 100644
index 0000000000..1d3d55d8e3
--- /dev/null
+++ b/tests/python/op/test_two_stage_softmax.py
@@ -0,0 +1,47 @@
+import numpy as np
+import scipy.special
+import tvm
+from tvm import dlight
+
+from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func
+
+
+def test_two_stage_softmax():
+    chunk_size = 4096
+    target = tvm.target.Target("cuda")
+    f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size)
+    mod = tvm.IRModule({"chunk_lse": f_chunk_lse, "softmax_with_chunked_lse": f_softmax_with_lse})
+    with target:
+        mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod)
+
+    runtime_mod = tvm.build(mod, target=target)
+    device = tvm.cuda()
+
+    num_runs = 5
+    vocab_size = 128256
+    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
+        for _ in range(num_runs):
+            x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype(
+                "float32"
+            )
+            y_np = scipy.special.softmax(x_np, axis=-1)
+
+            x_nd = tvm.nd.array(x_np, device=device)
+            r_nd = tvm.nd.empty(
+                (batch_size, (vocab_size + chunk_size - 1) // chunk_size),
+                x_np.dtype,
+                device=device,
+            )
+            y_nd = tvm.nd.empty(x_np.shape, x_np.dtype, device=device)
+
+            runtime_mod["chunk_lse"](x_nd, r_nd)
+            runtime_mod["softmax_with_chunked_lse"](x_nd, r_nd, y_nd)
+
+            y_nd_arr = y_nd.numpy()
+            np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6)
+
+        print(f"pass batch size {batch_size}")
+
+
+if __name__ == "__main__":
+    test_two_stage_softmax()
diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py
index f965e8cc82..37d1833b14 100644
--- a/tests/python/serve/test_serve_engine.py
+++ b/tests/python/serve/test_serve_engine.py
@@ -2,6 +2,8 @@
 # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable
 from typing import List
 
+import pytest
+
 from mlc_llm.serve import GenerationConfig, MLCEngine
 
 prompts = [
@@ -17,17 +19,39 @@
     "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.",
 ]
 
+test_models = [
+    (
+        "dist/Llama-2-7b-chat-hf-q0f16-MLC",
+        "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so",
+    ),
+    (
+        "dist/rwkv-6-world-1b6-q0f16-MLC",
+        "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so",
+    ),
+]
 
-def test_engine_generate():
-    # Create engine
-    model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
-    engine = MLCEngine(
-        model=model,
-        model_lib_path=model_lib_path,
-        mode="server",
-        max_total_sequence_length=4096,
-    )
+
+def create_engine(model: str, model_lib_path: str):
+    if "rwkv" in model:
+        return MLCEngine(
+            model=model,
+            model_lib_path=model_lib_path,
+            mode="server",
+            max_batch_size=8,
+            max_history_size=1,
+        )
+    else:
+        return MLCEngine(
+            model=model,
+            model_lib_path=model_lib_path,
+            mode="server",
+            max_total_sequence_length=4096,
+        )
+
+
+@pytest.mark.parametrize("model,model_lib_path", test_models)
+def test_engine_generate(model: str, model_lib_path: str):
+    engine = create_engine(model, model_lib_path)
 
     num_requests = 10
     max_tokens = 256
@@ -57,16 +81,10 @@ def test_engine_generate():
     del engine
 
 
-def test_chat_completion():
+@pytest.mark.parametrize("model,model_lib_path", test_models)
+def test_chat_completion(model: str, model_lib_path: str):
     # Create engine
-    model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
-    engine = MLCEngine(
-        model=model,
-        model_lib_path=model_lib_path,
-        mode="server",
-        max_total_sequence_length=4096,
-    )
+    engine = create_engine(model, model_lib_path)
 
     num_requests = 2
     max_tokens = 64
@@ -101,16 +119,9 @@ def test_chat_completion():
     del engine
 
 
-def test_chat_completion_non_stream():
-    # Create engine
-    model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
-    engine = MLCEngine(
-        model=model,
-        model_lib_path=model_lib_path,
-        mode="server",
-        max_total_sequence_length=4096,
-    )
+@pytest.mark.parametrize("model,model_lib_path", test_models)
+def test_chat_completion_non_stream(model: str, model_lib_path: str):
+    engine = create_engine(model, model_lib_path)
 
     num_requests = 2
     max_tokens = 64
@@ -144,16 +155,9 @@ def test_chat_completion_non_stream():
     del engine
 
 
-def test_completion():
-    # Create engine
-    model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
-    engine = MLCEngine(
-        model=model,
-        model_lib_path=model_lib_path,
-        mode="server",
-        max_total_sequence_length=4096,
-    )
+@pytest.mark.parametrize("model,model_lib_path", test_models)
+def test_completion(model: str, model_lib_path: str):
+    engine = create_engine(model, model_lib_path)
 
     num_requests = 2
     max_tokens = 128
@@ -188,16 +192,9 @@ def test_completion():
     del engine
 
 
-def test_completion_non_stream():
-    # Create engine
-    model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
-    engine = MLCEngine(
-        model=model,
-        model_lib_path=model_lib_path,
-        mode="server",
-        max_total_sequence_length=4096,
-    )
+@pytest.mark.parametrize("model,model_lib_path", test_models)
+def test_completion_non_stream(model: str, model_lib_path: str):
+    engine = create_engine(model, model_lib_path)
 
     num_requests = 2
     max_tokens = 128
@@ -232,8 +229,9 @@ def test_completion_non_stream():
     del engine
 
 if __name__ == "__main__":
"__main__": - test_engine_generate() - test_chat_completion() - test_chat_completion_non_stream() - test_completion() - test_completion_non_stream() + for model, model_lib_path in test_models: + test_engine_generate(model, model_lib_path) + test_chat_completion(model, model_lib_path) + test_chat_completion_non_stream(model, model_lib_path) + test_completion(model, model_lib_path) + test_completion_non_stream(model, model_lib_path)