
Commit

Merge branch 'mlc-ai:main' into main
MBaltz authored Apr 30, 2024
2 parents ba474f9 + ca7cdcc commit b1e3e4f
Showing 80 changed files with 3,037 additions and 829 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 97 files
+107 −0 3rdparty/mscclpp/include/common.h
+323 −0 3rdparty/mscclpp/include/msccl.cuh
+494 −0 3rdparty/mscclpp/include/msccl.h
+10 −3 CMakeLists.txt
+5 −0 cmake/config.cmake
+2 −0 cmake/modules/LibInfo.cmake
+50 −0 cmake/modules/contrib/MSCCLPP.cmake
+1 −0 cmake/modules/contrib/Mrvl.cmake
+2 −5 cmake/utils/Utils.cmake
+28 −0 docker/Dockerfile.demo_mrvl
+53 −36 docs/how_to/deploy/mrvl.rst
+117 −73 include/tvm/relax/expr.h
+4 −1 include/tvm/relax/transform.h
+6 −2 include/tvm/runtime/memory/memory_manager.h
+2 −1 include/tvm/script/ir_builder/tir/ir.h
+172 −0 python/tvm/contrib/mrvl.py
+2 −1 python/tvm/contrib/rocm.py
+27 −8 python/tvm/dlight/gpu/general_reduction.py
+33 −24 python/tvm/relax/backend/dispatch_sort_scan.py
+5 −0 python/tvm/relax/transform/transform.py
+82 −48 python/tvm/relax/utils.py
+35 −45 python/tvm/relay/op/contrib/mrvl.py
+12 −0 python/tvm/relay/op/strategy/arm_cpu.py
+5 −2 python/tvm/script/ir_builder/tir/ir.py
+98 −1 python/tvm/topi/arm_cpu/arm_utils.py
+39 −4 python/tvm/topi/arm_cpu/conv2d.py
+59 −75 python/tvm/topi/arm_cpu/conv2d_gemm.py
+6 −1 python/tvm/topi/nn/conv2d.py
+3 −7 src/arith/analyzer.cc
+3 −0 src/arith/const_int_bound.cc
+6 −0 src/arith/int_set.cc
+20 −0 src/arith/rewrite_simplify.cc
+14 −0 src/arith/scalable_expression.cc
+13 −0 src/arith/scalable_expression.h
+3 −6 src/contrib/msc/core/ir/graph_builder.cc
+6 −14 src/contrib/msc/core/transform/set_expr_layout.cc
+12 −20 src/relax/analysis/well_formed.cc
+1 −1 src/relax/backend/contrib/utils.cc
+23 −6 src/relax/ir/dataflow_matcher.cc
+8 −0 src/relax/ir/expr.cc
+82 −21 src/relax/op/op_common.h
+86 −26 src/relax/op/tensor/binary.cc
+2 −5 src/relax/training/utils.cc
+58 −25 src/relax/transform/fuse_ops.cc
+1 −3 src/relax/transform/fuse_tir.cc
+0 −2 src/relax/transform/gradient.cc
+2 −1 src/relax/transform/utils.h
+137 −94 src/relay/backend/contrib/mrvl/codegen.cc
+0 −1 src/relay/backend/contrib/mrvl/compiler_attr.cc
+3 −1 src/relay/backend/utils.cc
+78 −0 src/runtime/contrib/mrvl/mrvl_base64.h
+32 −6 src/runtime/contrib/mrvl/mrvl_runtime.cc
+175 −0 src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc
+45 −0 src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h
+184 −0 src/runtime/contrib/mscclpp/allreduce.cu
+4 −3 src/runtime/cuda/cuda_device_api.cc
+1 −1 src/runtime/disco/bcast_session.cc
+3 −6 src/runtime/relax_vm/builtin.cc
+5 −5 src/script/ir_builder/tir/ir.cc
+2 −2 src/script/printer/relax/binding.cc
+1 −2 src/script/printer/relax/function.cc
+4 −3 src/script/printer/relax/tir.cc
+12 −2 src/support/libinfo.cc
+8 −7 src/te/operation/create_primfunc.cc
+11 −2 src/tir/transforms/ir_utils.cc
+7 −0 src/tir/transforms/storage_rewrite.cc
+11 −2 src/tir/transforms/vectorize_loop.cc
+25 −0 tests/python/arith/test_arith_rewrite_simplify.py
+35 −0 tests/python/codegen/test_target_codegen_aarch64.py
+49 −1 tests/python/contrib/test_mrvl/infrastructure.py
+31 −18 tests/python/contrib/test_mrvl/test_mrvl.py
+149 −0 tests/python/dlight/test_gpu_general_reduction.py
+38 −0 tests/python/relax/test_backend_dispatch_sort_scan.py
+2 −8 tests/python/relax/test_codegen_cublas.py
+1 −8 tests/python/relax/test_codegen_cudnn.py
+1 −8 tests/python/relax/test_codegen_cutlass.py
+11 −2 tests/python/relax/test_codegen_tensorrt.py
+1 −1 tests/python/relax/test_contrib_vllm.py
+1 −1 tests/python/relax/test_expr_functor.py
+86 −20 tests/python/relax/test_op_binary.py
+4 −4 tests/python/relax/test_op_nn_convolution.py
+2 −2 tests/python/relax/test_op_search.py
+6 −4 tests/python/relax/test_transform_codegen_pass.py
+533 −1 tests/python/relax/test_transform_legalize_ops_binary.py
+29 −2 tests/python/relax/test_vm_builtin.py
+52 −0 tests/python/runtime/test_runtime_device_api.py
+43 −43 tests/python/tir-schedule/test_tir_schedule_split_fuse.py
+2 −2 tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+23 −2 tests/python/tir-transform/test_tir_transform_split_host_device.py
+97 −55 tests/python/tir-transform/test_tir_transform_vectorize.py
+35 −12 tests/python/topi/test_topi_conv2d_nhwc.py
+15 −0 tests/python/tvmscript/test_tvmscript_parser_tir.py
+0 −24 tests/scripts/release/PRERELEASE_NOTES.md
+4 −0 tests/scripts/release/make_notes.py
+17 −0 web/emcc/wasm_runtime.cc
+2 −2 web/package-lock.json
+26 −6 web/src/runtime.ts
1 change: 1 addition & 0 deletions android/library/prepare_libs.sh
@@ -27,6 +27,7 @@ cmake .. \
-DMLC_LLM_INSTALL_STATIC_LIB=ON \
-DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \
-DUSE_OPENCL=ON \
-DUSE_OPENCL_ENABLE_HOST_PTR=ON \
-DUSE_CUSTOM_LOGGING=ON \

cmake --build . --target tvm4j_runtime_packed --config release
46 changes: 45 additions & 1 deletion cpp/json_ffi/conv_template.cc → cpp/json_ffi/config.cc
@@ -1,4 +1,6 @@
#include "conv_template.h"
#include "config.h"

#include <tvm/runtime/registry.h>

#include "../metadata/json_parser.h"

@@ -8,6 +10,29 @@ namespace json_ffi {

using namespace mlc::llm;

/****************** Model-defined generation config ******************/

TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);

ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p,
double frequency_penalty,
double presence_penalty) {
ObjectPtr<ModelDefinedGenerationConfigNode> n = make_object<ModelDefinedGenerationConfigNode>();
n->temperature = temperature;
n->top_p = top_p;
n->frequency_penalty = frequency_penalty;
n->presence_penalty = presence_penalty;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig")
.set_body_typed([](double temperature, double top_p, double frequency_penalty,
double presence_penalty) {
return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty);
});

/****************** Conversation template ******************/

std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
{MessagePlaceholders::SYSTEM, "{system_message}"},
{MessagePlaceholders::USER, "{user_message}"},
@@ -308,6 +333,25 @@ std::optional<Conversation> Conversation::FromJSON(const std::string& json_str,
}
return Conversation::FromJSON(json_obj.value(), err);
}

/****************** JSON FFI engine config ******************/

TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode);

JSONFFIEngineConfig::JSONFFIEngineConfig(
String conv_template, Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
ObjectPtr<JSONFFIEngineConfigNode> n = make_object<JSONFFIEngineConfigNode>();
n->conv_template = conv_template;
n->model_generation_cfgs = model_generation_cfgs;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
.set_body_typed([](String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs));
});

} // namespace json_ffi
} // namespace llm
} // namespace mlc
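
For orientation, here is a minimal sketch of constructing the two new objects from C++. The model key and the conversation-template string are placeholder values; only the constructors and registrations shown above are assumed (per-model defaults would typically come from the model's chat config).

    #include <tvm/runtime/container/map.h>
    #include <tvm/runtime/container/string.h>

    using tvm::runtime::Map;
    using tvm::runtime::String;
    using namespace mlc::llm::json_ffi;

    void ExampleBuildConfigs() {
      // Per-model generation defaults (values illustrative).
      ModelDefinedGenerationConfig gen_cfg(/*temperature=*/0.7, /*top_p=*/0.95,
                                           /*frequency_penalty=*/0.0,
                                           /*presence_penalty=*/0.0);
      Map<String, ModelDefinedGenerationConfig> model_cfgs;
      model_cfgs.Set("demo-model", gen_cfg);  // "demo-model" is hypothetical
      // First argument is the conversation template serialized as JSON.
      JSONFFIEngineConfig engine_cfg("{ ...conversation template JSON... }",
                                     model_cfgs);
      (void)engine_cfg;
    }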
55 changes: 53 additions & 2 deletions cpp/json_ffi/conv_template.h → cpp/json_ffi/config.h
@@ -1,5 +1,9 @@
#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
#ifndef MLC_LLM_JSON_FFI_CONFIG_H
#define MLC_LLM_JSON_FFI_CONFIG_H

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>

#include <iostream>
#include <map>
@@ -18,6 +22,32 @@ namespace mlc {
namespace llm {
namespace json_ffi {

/****************** Model-defined generation config ******************/

class ModelDefinedGenerationConfigNode : public Object {
public:
double temperature;
double top_p;
double frequency_penalty;
double presence_penalty;

static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object);
};

class ModelDefinedGenerationConfig : public ObjectRef {
public:
explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty,
double presence_penalty);

TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef,
ModelDefinedGenerationConfigNode);
};

/****************** Conversation template ******************/

enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };

MessagePlaceholders messagePlaceholderFromString(const std::string& role);
@@ -114,6 +144,27 @@ struct Conversation {
static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
};

/****************** JSON FFI engine config ******************/

class JSONFFIEngineConfigNode : public Object {
public:
String conv_template;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;

static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object);
};

class JSONFFIEngineConfig : public ObjectRef {
public:
explicit JSONFFIEngineConfig(String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);

TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
};

} // namespace json_ffi
} // namespace llm
} // namespace mlc
14 changes: 8 additions & 6 deletions cpp/json_ffi/json_ffi_engine.cc
@@ -83,8 +83,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
Array<Data> inputs = inputs_obj.value();

// generation_cfg
Optional<GenerationConfig> generation_cfg =
GenerationConfig::FromJSON(request_json_str, &err_, conv_template);
Optional<GenerationConfig> generation_cfg = GenerationConfig::Create(
request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]);
if (!generation_cfg.defined()) {
return false;
}
@@ -122,14 +122,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
TVM_MODULE_VTABLE_END();

void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config,
Optional<PackedFunc> request_stream_callback,
void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
std::optional<Conversation> conv_template = Conversation::FromJSON(conv_template_str, &err_);
std::optional<Conversation> conv_template =
Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
if (!conv_template.has_value()) {
LOG(FATAL) << "Invalid conversation template JSON: " << err_;
}
this->conv_template_ = conv_template.value();
this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs;

// Todo(mlc-team): decouple InitBackgroundEngine into two functions
// by removing `engine_config` from arguments, after properly handling
@@ -148,7 +150,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
};

request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
this->engine_->InitBackgroundEngine(std::move(request_stream_callback),
this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback),
std::move(trace_recorder));
this->engine_->Reload(std::move(engine_config));
}
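
A sketch of driving the updated entry point through the module interface follows. The registered name "init_background_engine" is an assumption inferred from the vtable style above (the diff elides that entry), and the null callbacks are placeholders, not what a real host would pass.

    #include <tvm/runtime/module.h>
    #include <tvm/runtime/packed_func.h>

    // Sketch: the new argument order is (json_ffi_config, engine_config,
    // device, request_stream_callback, trace_recorder).
    void InitEngineSketch(tvm::runtime::Module engine,
                          mlc::llm::json_ffi::JSONFFIEngineConfig json_cfg,
                          mlc::llm::serve::EngineConfig engine_cfg,
                          DLDevice device) {
      // Entry-point name is assumed; nullptr stands in for the optional
      // stream callback and trace recorder.
      engine.GetFunction("init_background_engine")(json_cfg, engine_cfg,
                                                   device, nullptr, nullptr);
    }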
3 changes: 2 additions & 1 deletion cpp/json_ffi/json_ffi_engine.h
@@ -12,7 +12,7 @@

#include "../serve/threaded_engine.h"
#include "../streamer.h"
#include "conv_template.h"
#include "config.h"
#include "openai_api_protocol.h"

namespace mlc {
@@ -49,6 +49,7 @@ class JSONFFIEngine {
PackedFunc request_stream_callback_;
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
Conversation conv_template_;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
};

} // namespace json_ffi
10 changes: 5 additions & 5 deletions cpp/json_ffi/openai_api_protocol.h
@@ -13,7 +13,7 @@
#include <unordered_map>
#include <vector>

#include "conv_template.h"
#include "config.h"
#include "picojson.h"

namespace mlc {
@@ -90,8 +90,8 @@ class ChatCompletionRequest {
public:
std::vector<ChatCompletionMessage> messages;
std::string model;
double frequency_penalty = 0.0;
double presence_penalty = 0.0;
std::optional<double> frequency_penalty = std::nullopt;
std::optional<double> presence_penalty = std::nullopt;
bool logprobs = false;
int top_logprobs = 0;
std::optional<std::unordered_map<int, double>> logit_bias = std::nullopt;
@@ -100,8 +100,8 @@
std::optional<int> seed = std::nullopt;
std::optional<std::vector<std::string>> stop = std::nullopt;
bool stream = false;
double temperature = 1.0;
double top_p = 1.0;
std::optional<double> temperature = std::nullopt;
std::optional<double> top_p = std::nullopt;
std::optional<std::vector<ChatTool>> tools = std::nullopt;
std::optional<std::string> tool_choice = std::nullopt;
std::optional<std::string> user = std::nullopt;
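
The switch from hard-coded defaults (1.0 and 0.0) to std::optional lets an omitted request field defer to the model-defined value. A minimal sketch of that resolution, assuming both objects are in hand:

    // Sketch: a nullopt request field falls back to the model-defined default.
    double ResolveTemperature(const ChatCompletionRequest& request,
                              const ModelDefinedGenerationConfig& model_defaults) {
      return request.temperature.value_or(model_defaults->temperature);
    }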
16 changes: 16 additions & 0 deletions cpp/metadata/json_parser.h
@@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) {
return it->second.get<ValueType>();
}

template <typename ValueType>
inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key,
const ValueType& default_value) {
auto it = json.find(key);
if (it == json.end()) {
return default_value;
}

if (it->second.is<picojson::null>()) {
return default_value;
}

CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type";
return it->second.get<ValueType>();
}

template <typename ValueType>
inline ValueType Lookup(const picojson::array& json, int index) {
CHECK(index < json.size()) << "IndexError: json::array index out of range";
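
A short usage sketch for the new helper; the object and keys are illustrative:

    void LookupOrDefaultExample() {
      picojson::object obj;
      obj["top_p"] = picojson::value(0.9);
      obj["seed"] = picojson::value();  // explicit null
      // The fallback applies when the key is absent or null; the CHECK fires
      // only on a genuine type mismatch.
      double top_p = json::LookupOrDefault<double>(obj, "top_p", 1.0);        // 0.9
      double temp  = json::LookupOrDefault<double>(obj, "temperature", 0.7);  // 0.7 (absent)
      double seed  = json::LookupOrDefault<double>(obj, "seed", -1.0);        // -1.0 (null)
      (void)top_p; (void)temp; (void)seed;
    }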
96 changes: 76 additions & 20 deletions cpp/serve/config.cc
@@ -161,15 +161,26 @@ GenerationConfig::GenerationConfig(String config_json_str) {
data_ = std::move(n);
}

Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_str, std::string* err,
const Conversation& conv_template) {
std::optional<picojson::object> json_obj = json::LoadJSONFromString(json_str, err);
if (!err->empty() || !json_obj.has_value()) {
Optional<GenerationConfig> GenerationConfig::Create(
const std::string& json_str, std::string* err, const Conversation& conv_template,
const ModelDefinedGenerationConfig& model_defined_gen_config) {
std::optional<picojson::object> optional_json_obj = json::LoadJSONFromString(json_str, err);
if (!err->empty() || !optional_json_obj.has_value()) {
return NullOpt;
}
picojson::object& json_obj = optional_json_obj.value();
ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();

// TODO(mlc-team): Pass the parameters from `json_obj` to `n`.
n->temperature =
json::LookupOrDefault<double>(json_obj, "temperature", model_defined_gen_config->temperature);
n->top_p = json::LookupOrDefault<double>(json_obj, "top_p", model_defined_gen_config->top_p);
n->frequency_penalty = json::LookupOrDefault<double>(json_obj, "frequency_penalty",
model_defined_gen_config->frequency_penalty);
n->presence_penalty = json::LookupOrDefault<double>(json_obj, "presence_penalty",
model_defined_gen_config->presence_penalty);
n->logprobs = json::LookupOrDefault<bool>(json_obj, "logprobs", false);
n->top_logprobs = static_cast<int>(json::LookupOrDefault<double>(json_obj, "top_logprobs", 0));
n->ignore_eos = json::LookupOrDefault<bool>(json_obj, "ignore_eos", false);

// Copy stop str from conversation template to generation config
for (auto& stop_str : conv_template.stop_str) {
@@ -179,9 +190,6 @@ Optional<GenerationConfig> GenerationConfig::FromJSON(const std::string& json_str,
n->stop_token_ids.push_back(stop_token_id);
}

if (!err->empty()) {
return NullOpt;
}
GenerationConfig gen_config;
gen_config.data_ = std::move(n);
return gen_config;
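
The net effect of the new Create path: a field present in the request JSON wins, and anything the request omits falls back to the model-defined config. A sketch, assuming a conversation template and model config are already in hand (namespaces elided for brevity):

    // Sketch: request-level values override model-defined defaults.
    void CreatePrecedenceExample(const Conversation& conv_template,
                                 const ModelDefinedGenerationConfig& model_cfg) {
      std::string err;
      Optional<GenerationConfig> cfg = GenerationConfig::Create(
          R"({"top_p": 0.5})", &err, conv_template, model_cfg);
      // cfg.value()->top_p == 0.5 (taken from the request JSON);
      // cfg.value()->temperature == model_cfg->temperature (request omits it).
    }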
@@ -236,37 +244,85 @@ String GenerationConfigNode::AsJSONString() const {
TVM_REGISTER_OBJECT_TYPE(EngineConfigNode);

EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence,
int max_total_sequence_length, int max_single_sequence_length,
int prefill_chunk_size, SpeculativeMode speculative_mode,
int spec_draft_length) {
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size,
int max_history_size, KVStateKind kv_state_kind,
SpeculativeMode speculative_mode, int spec_draft_length) {
ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
n->model = std::move(model);
n->model_lib_path = std::move(model_lib_path);
n->additional_models = std::move(additional_models);
n->additional_model_lib_paths = std::move(additional_model_lib_paths);
n->device = device;
n->kv_cache_page_size = kv_cache_page_size;
n->max_num_sequence = max_num_sequence;
n->max_total_sequence_length = max_total_sequence_length;
n->max_single_sequence_length = max_single_sequence_length;
n->prefill_chunk_size = prefill_chunk_size;
n->max_history_size = max_history_size;
n->kv_state_kind = kv_state_kind;
n->spec_draft_length = spec_draft_length;
n->speculative_mode = speculative_mode;
data_ = std::move(n);
}

EngineConfig EngineConfig::FromJSONString(const std::string& json_str) {
picojson::value config_json;
std::string err = picojson::parse(config_json, json_str);
if (!err.empty()) {
LOG(FATAL) << err;
}

// Get json fields.
picojson::object config = config_json.get<picojson::object>();
String model = json::Lookup<std::string>(config, "model");
String model_lib_path = json::Lookup<std::string>(config, "model_lib_path");
std::vector<String> additional_models;
std::vector<String> additional_model_lib_paths;
int kv_cache_page_size = json::Lookup<int64_t>(config, "kv_cache_page_size");
int max_num_sequence = json::Lookup<int64_t>(config, "max_num_sequence");
int max_total_sequence_length = json::Lookup<int64_t>(config, "max_total_sequence_length");
int max_single_sequence_length = json::Lookup<int64_t>(config, "max_single_sequence_length");
int prefill_chunk_size = json::Lookup<int64_t>(config, "prefill_chunk_size");
int max_history_size = json::Lookup<int64_t>(config, "max_history_size");
KVStateKind kv_state_kind =
static_cast<KVStateKind>(json::Lookup<int64_t>(config, "kv_state_kind"));
SpeculativeMode speculative_mode =
static_cast<SpeculativeMode>(json::Lookup<int64_t>(config, "speculative_mode"));
int spec_draft_length = json::Lookup<int64_t>(config, "spec_draft_length");

picojson::array additional_models_arr =
json::Lookup<picojson::array>(config, "additional_models");
picojson::array additional_model_lib_paths_arr =
json::Lookup<picojson::array>(config, "additional_model_lib_paths");
CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size())
<< "The number of additional model lib paths does not match the number of additional models";
int num_additional_models = additional_models_arr.size();
additional_models.reserve(num_additional_models);
additional_model_lib_paths.reserve(num_additional_models);
for (int i = 0; i < num_additional_models; ++i) {
additional_models.push_back(json::Lookup<std::string>(additional_models_arr, i));
additional_model_lib_paths.push_back(
json::Lookup<std::string>(additional_model_lib_paths_arr, i));
}

return EngineConfig(std::move(model), std::move(model_lib_path), additional_models,
additional_model_lib_paths, kv_cache_page_size, max_num_sequence,
max_total_sequence_length, max_single_sequence_length, prefill_chunk_size,
max_history_size, kv_state_kind, speculative_mode, spec_draft_length);
}

TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
.set_body_typed([](String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size, int speculative_mode,
int spec_draft_length) {
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
int kv_state_kind, int speculative_mode, int spec_draft_length) {
return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
std::move(additional_model_lib_paths), device, kv_cache_page_size,
std::move(additional_model_lib_paths), kv_cache_page_size,
max_num_sequence, max_total_sequence_length, max_single_sequence_length,
prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length);
prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
SpeculativeMode(speculative_mode), spec_draft_length);
});

} // namespace serve
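
For reference, a sketch of the JSON shape FromJSONString expects; every key mirrors a Lookup call above, the values are placeholders, and the enum fields travel as integer codes:

    void FromJSONStringExample() {
      // Sketch only: placeholder values, keys taken from the Lookup calls above.
      std::string json = R"({
        "model": "path/to/model",
        "model_lib_path": "path/to/model_lib.so",
        "additional_models": [],
        "additional_model_lib_paths": [],
        "kv_cache_page_size": 16,
        "max_num_sequence": 4,
        "max_total_sequence_length": 4096,
        "max_single_sequence_length": 4096,
        "prefill_chunk_size": 512,
        "max_history_size": 0,
        "kv_state_kind": 0,
        "speculative_mode": 0,
        "spec_draft_length": 4
      })";
      EngineConfig config = EngineConfig::FromJSONString(json);
      (void)config;
    }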