Skip to content

Commit

Permalink
add validation for trt-llm HF repo (#1147)
Browse files Browse the repository at this point in the history
* add validation for trt-llm HF repo

* use constant, and update validationerror message
  • Loading branch information
dsingal0 authored Sep 17, 2024
1 parent 18d3395 commit a121e69
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
12 changes: 12 additions & 0 deletions truss/config/trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from enum import Enum
from typing import Optional

from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import validate_repo_id
from pydantic import BaseModel, validator
from rich.console import Console

Expand Down Expand Up @@ -97,6 +99,8 @@ def __init__(self, **data):
super().__init__(**data)
self._validate_minimum_required_configuration()
self._validate_kv_cache_flags()
if self.build.checkpoint_repository.source == CheckpointSource.HF:
self._validate_hf_repo_id()

# In pydantic v2 this would be `@model_validator(mode="after")` and
# the __init__ override can be removed.
Expand Down Expand Up @@ -131,6 +135,14 @@ def _validate_kv_cache_flags(self):
raise ValueError("Using fp8 context fmha requires paged context fmha")
return self

def _validate_hf_repo_id(self):
try:
validate_repo_id(self.build.checkpoint_repository.repo)
except HFValidationError as e:
raise ValueError(
f"HuggingFace repository validation failed: {str(e)}"
) from e

@property
def requires_build(self):
if self.build is not None:
Expand Down
18 changes: 18 additions & 0 deletions truss/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,24 @@ def test_plugin_paged_context_fmha_check(trtllm_config):
TrussConfig.from_dict(trtllm_config)


@pytest.mark.parametrize(
"repo",
[
"./llama-3.1-8b",
"../my-model-is-in-parent-directory",
"~/.huggingface/my--model--cache/model",
"foo.git",
"datasets/foo/bar",
".repo_id" "other..repo..id",
],
)
def test_invalid_hf_repo(trtllm_config, repo):
trtllm_config["trt_llm"]["build"]["checkpoint_repository"]["source"] = "HF"
trtllm_config["trt_llm"]["build"]["checkpoint_repository"]["repo"] = repo
with pytest.raises(ValueError):
TrussConfig.from_dict(trtllm_config)


def test_plugin_paged_fp8_context_fmha_check(trtllm_config):
trtllm_config["trt_llm"]["build"]["plugin_configuration"] = {
"paged_kv_cache": False,
Expand Down

0 comments on commit a121e69

Please sign in to comment.