From 467b7a12c88befd36f6fc978c9173e69be2a5f14 Mon Sep 17 00:00:00 2001
From: joostinyi <63941848+joostinyi@users.noreply.github.com>
Date: Mon, 28 Oct 2024 14:16:08 -0700
Subject: [PATCH] Add qwen config and input config simplification (#1190)

* remove max input/output in favor of max_seq_len
* add qwen
* bump version
* revert bump
* update config in test
* fix tests
* bump version
* bump briton image
* remove unused config
* better client side error msg + validation
* fix test
---
 pyproject.toml                                       |  2 +-
 truss/config/trt_llm.py                              | 11 +++++------
 truss/constants.py                                   |  2 +-
 .../contexts/image_builder/serving_image_builder.py  | 11 -----------
 truss/test_data/test_trt_llm_truss/config.yaml       |  3 +--
 truss/tests/conftest.py                              |  3 +--
 truss/tests/test_config.py                           |  9 +++++++--
 truss/trt_llm/validation.py                          |  3 ++-
 truss/truss_config.py                                | 13 ++++++++++---
 9 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 4410500fa..40c3a6953 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.45rc009"
+version = "0.9.45rc018"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/config/trt_llm.py b/truss/config/trt_llm.py
index 877053c28..73c633d2e 100644
--- a/truss/config/trt_llm.py
+++ b/truss/config/trt_llm.py
@@ -23,6 +23,7 @@ class TrussTRTLLMModel(str, Enum):
     MISTRAL = "mistral"
     DEEPSEEK = "deepseek"
     WHISPER = "whisper"
+    QWEN = "qwen"
 
 
 class TrussTRTLLMQuantizationType(str, Enum):
@@ -30,7 +31,7 @@ class TrussTRTLLMQuantizationType(str, Enum):
     WEIGHTS_ONLY_INT8 = "weights_int8"
     WEIGHTS_KV_INT8 = "weights_kv_int8"
     WEIGHTS_ONLY_INT4 = "weights_int4"
-    WEIGHTS_KV_INT4 = "weights_kv_int4"
+    WEIGHTS_INT4_KV_INT8 = "weights_int4_kv_int8"
     SMOOTH_QUANT = "smooth_quant"
     FP8 = "fp8"
     FP8_KV = "fp8_kv"
@@ -58,10 +59,9 @@ class CheckpointRepository(BaseModel):
 
 class TrussTRTLLMBuildConfiguration(BaseModel):
     base_model: TrussTRTLLMModel
-    max_input_len: int
-    max_output_len: int
-    max_batch_size: int
-    max_num_tokens: Optional[int] = None
+    max_seq_len: int
+    max_batch_size: Optional[int] = 256
+    max_num_tokens: Optional[int] = 8192
     max_beam_width: int = 1
     max_prompt_embedding_table_size: int = 0
     checkpoint_repository: CheckpointRepository
@@ -75,7 +75,6 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
     plugin_configuration: TrussTRTLLMPluginConfiguration = (
         TrussTRTLLMPluginConfiguration()
     )
-    use_fused_mlp: bool = False
     kv_cache_free_gpu_mem_fraction: float = 0.9
     num_builder_gpus: Optional[int] = None
     enable_chunked_context: bool = False
diff --git a/truss/constants.py b/truss/constants.py
index 093c251c7..5fda6e7f1 100644
--- a/truss/constants.py
+++ b/truss/constants.py
@@ -106,7 +106,7 @@
 
 REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"
 
-TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.11"
+TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0"
 TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
 BASE_TRTLLM_REQUIREMENTS = [
     "grpcio==1.62.3",
diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py
index 36d4de2ac..0426c57b9 100644
--- a/truss/contexts/image_builder/serving_image_builder.py
+++ b/truss/contexts/image_builder/serving_image_builder.py
@@ -412,17 +412,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
             DEFAULT_BUNDLED_PACKAGES_DIR,
         )
 
-        tensor_parallel_count = (
-            config.trt_llm.build.tensor_parallel_count  # type: ignore[union-attr]
-            if config.trt_llm.build is not None
-            else config.trt_llm.serve.tensor_parallel_count  # type: ignore[union-attr]
-        )
-
-        if tensor_parallel_count != config.resources.accelerator.count:
-            raise ValueError(
-                "Tensor parallelism and GPU count must be the same for TRT-LLM"
-            )
-
         config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY
 
         if not is_audio_model:
diff --git a/truss/test_data/test_trt_llm_truss/config.yaml b/truss/test_data/test_trt_llm_truss/config.yaml
index 5e4bd3f4e..03f300cb4 100644
--- a/truss/test_data/test_trt_llm_truss/config.yaml
+++ b/truss/test_data/test_trt_llm_truss/config.yaml
@@ -4,10 +4,9 @@ resources:
   use_gpu: True
 trt_llm:
   build:
-    max_input_len: 1000
+    max_seq_len: 1000
     max_batch_size: 1
     max_beam_width: 1
-    max_output_len: 1000
     base_model: llama
     checkpoint_repository:
       repo: TinyLlama/TinyLlama-1.1B-Chat-v1.0
diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py
index 37fdf7e5c..ce9c6a99e 100644
--- a/truss/tests/conftest.py
+++ b/truss/tests/conftest.py
@@ -389,8 +389,7 @@ def modify_handle(h: TrussHandle):
         content["trt_llm"] = {
             "build": {
                 "base_model": "llama",
-                "max_input_len": 1024,
-                "max_output_len": 1024,
+                "max_seq_len": 2048,
                 "max_batch_size": 512,
                 "checkpoint_repository": {
                     "source": "HF",
diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py
index ffc4d6ae0..6529475e7 100644
--- a/truss/tests/test_config.py
+++ b/truss/tests/test_config.py
@@ -49,11 +49,16 @@ def default_config() -> Dict[str, Any]:
 @pytest.fixture
 def trtllm_config(default_config) -> Dict[str, Any]:
     trtllm_config = default_config
+    trtllm_config["resources"] = {
+        "accelerator": Accelerator.L4.value,
+        "cpu": "1",
+        "memory": "24Gi",
+        "use_gpu": True,
+    }
     trtllm_config["trt_llm"] = {
         "build": {
             "base_model": "llama",
-            "max_input_len": 1024,
-            "max_output_len": 1024,
+            "max_seq_len": 2048,
             "max_batch_size": 512,
             "checkpoint_repository": {
                 "source": "HF",
diff --git a/truss/trt_llm/validation.py b/truss/trt_llm/validation.py
index 85837df17..d4bd2c659 100644
--- a/truss/trt_llm/validation.py
+++ b/truss/trt_llm/validation.py
@@ -31,7 +31,8 @@ def _verify_has_class_init_arg(source: str, class_name: str, arg_name: str):
         raise ValidationError(
             (
                 "Model class `__init__` method is required to have `trt_llm` as an argument. Please add that argument.\n "
-                "Or if you want to use the automatically generated model class then remove the `model.py` file."
+                "Or if you want to use the automatically generated model class then remove the `model.py` file.\n "
+                "Refer to https://docs.baseten.co/performance/engine-builder-customization for details on engine object usage."
             )
         )
 
diff --git a/truss/truss_config.py b/truss/truss_config.py
index e7a53b739..d1da2582d 100644
--- a/truss/truss_config.py
+++ b/truss/truss_config.py
@@ -649,7 +649,7 @@ def to_dict(self, verbose: bool = True):
     def clone(self):
         return TrussConfig.from_dict(self.to_dict())
 
-    def _validate_quant_format_and_accelerator_for_trt_llm_builder(self) -> None:
+    def _validate_accelerator_for_trt_llm_builder(self) -> None:
         if self.trt_llm and self.trt_llm.build:
             if (
                 self.trt_llm.build.quantization_type
@@ -665,9 +665,16 @@ def _validate_quant_format_and_accelerator_for_trt_llm_builder(self) -> None:
             ] and self.resources.accelerator.accelerator not in [
                 Accelerator.H100,
                 Accelerator.H100_40GB,
+                Accelerator.L4,
             ]:
                 raise ValueError(
-                    "FP8 quantization is only supported on H100 accelerators"
+                    "FP8 quantization is only supported on L4 and H100 accelerators"
                 )
+            tensor_parallel_count = self.trt_llm.build.tensor_parallel_count
+
+            if tensor_parallel_count != self.resources.accelerator.count:
+                raise ValueError(
+                    "Tensor parallelism and GPU count must be the same for TRT-LLM"
+                )
 
     def validate(self):
@@ -692,7 +699,7 @@ def validate(self):
             raise ValueError(
                 "Please ensure that only one of `requirements` and `requirements_file` is specified"
             )
-        self._validate_quant_format_and_accelerator_for_trt_llm_builder()
+        self._validate_accelerator_for_trt_llm_builder()
 
 
 def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:
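
For illustration, here is a minimal sketch of a trt_llm config under the new schema, mirroring the updated test fixtures in conftest.py and test_config.py. The Qwen checkpoint repo and sequence length are illustrative, and the parse at the end assumes the build-config fields not shown in this diff keep their defaults:

# Sketch only: mirrors the updated test fixtures in this patch.
# The HF repo below is illustrative, not part of this diff.
from truss.config.trt_llm import TrussTRTLLMBuildConfiguration

trt_llm_config = {
    "build": {
        "base_model": "qwen",  # new TrussTRTLLMModel member added above
        "max_seq_len": 2048,   # single knob replacing max_input_len / max_output_len
        "checkpoint_repository": {
            "source": "HF",
            "repo": "Qwen/Qwen2-7B-Instruct",  # illustrative HF repo
        },
    }
}

build = TrussTRTLLMBuildConfiguration(**trt_llm_config["build"])
# max_batch_size and max_num_tokens are now optional, defaulting to 256 and 8192.
print(build.max_batch_size, build.max_num_tokens)

Note also that the tensor-parallelism/GPU-count check now runs during TrussConfig validation (truss_config.py) rather than at image build time, so a mismatch surfaces as a client-side ValueError before any build starts.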