Skip to content

Commit 9a029be

Browse files
authored
Client side validation for non fp8 kv cache and fp8 context fmha (#1302)
* Add client-side validation: using fp8 context FMHA requires an fp8 KV cache dtype
* Enable chunked context by default
1 parent e9f9231 commit 9a029be

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

truss/base/trt_llm_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class TrussSpecDecMode(str, Enum):
8181

8282
class TrussTRTLLMRuntimeConfiguration(BaseModel):
8383
kv_cache_free_gpu_mem_fraction: float = 0.9
84-
enable_chunked_context: bool = False
84+
enable_chunked_context: bool = True
8585
batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
8686
TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
8787
)
@@ -135,6 +135,11 @@ def _validate_kv_cache_flags(self):
135135
and not self.plugin_configuration.use_paged_context_fmha
136136
):
137137
raise ValueError("Using fp8 context fmha requires paged context fmha")
138+
if (
139+
self.plugin_configuration.use_fp8_context_fmha
140+
and not self.quantization_type == TrussTRTLLMQuantizationType.FP8_KV
141+
):
142+
raise ValueError("Using fp8 context fmha requires fp8 kv cache dtype")
138143
return self
139144

140145
def _validate_speculator_config(self):

truss/tests/test_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,22 @@ def test_plugin_paged_fp8_context_fmha_check(trtllm_config):
465465
TrussConfig.from_dict(trtllm_config)
466466

467467

468+
def test_fp8_context_fmha_check_kv_dtype(trtllm_config):
469+
trtllm_config["trt_llm"]["build"]["plugin_configuration"] = {
470+
"paged_kv_cache": True,
471+
"use_paged_context_fmha": True,
472+
"use_fp8_context_fmha": True,
473+
}
474+
trtllm_config["trt_llm"]["build"]["quantization_type"] = (
475+
TrussTRTLLMQuantizationType.FP8_KV.value
476+
)
477+
TrussConfig.from_dict(trtllm_config)
478+
479+
del trtllm_config["trt_llm"]["build"]["quantization_type"]
480+
with pytest.raises(ValueError):
481+
TrussConfig.from_dict(trtllm_config)
482+
483+
468484
@pytest.mark.parametrize("verbose, expect_equal", [(False, True), (True, False)])
469485
def test_to_dict_trtllm(
470486
verbose,

0 commit comments

Comments
 (0)