
Commit: Update settings.py
Smartappli authored Aug 15, 2024
1 parent 5af3b53 commit 4494458
Showing 1 changed file with 21 additions and 21 deletions.
llama_cpp/server/settings.py

@@ -2,10 +2,10 @@

import multiprocessing
from typing import Dict, List, Literal, Optional, Union, cast
+ from typing_extensions import Self

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings
- from typing_extensions import Self

import llama_cpp

@@ -19,7 +19,7 @@ class ModelSettings(BaseSettings):
model: str = Field(
description="The path to the model to use for generating completions.",
)
- model_alias: str | None = Field(
+ model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
)
@@ -38,7 +38,7 @@ class ModelSettings(BaseSettings):
ge=0,
description="Main GPU to use.",
)
- tensor_split: list[float] | None = Field(
+ tensor_split: Optional[List[float]] = Field(
default=None,
description="Split layers across multiple GPUs in proportion.",
)
@@ -53,11 +53,11 @@ class ModelSettings(BaseSettings):
default=llama_cpp.llama_supports_mlock(),
description="Use mlock.",
)
- kv_overrides: list[str] | None = Field(
+ kv_overrides: Optional[List[str]] = Field(
default=None,
description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
)
- rpc_servers: str | None = Field(
+ rpc_servers: Optional[str] = Field(
default=None,
description="comma seperated list of rpc servers for offloading",
)
@@ -109,25 +109,25 @@ class ModelSettings(BaseSettings):
description="Last n tokens to keep for repeat penalty calculation.",
)
# LoRA Params
- lora_base: str | None = Field(
+ lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
)
- lora_path: str | None = Field(
+ lora_path: Optional[str] = Field(
default=None,
description="Path to a LoRA file to apply to the model.",
)
# Backend Params
- numa: bool | int = Field(
+ numa: Union[bool, int] = Field(
default=False,
description="Enable NUMA support.",
)
# Chat Format Params
- chat_format: str | None = Field(
+ chat_format: Optional[str] = Field(
default=None,
description="Chat format to use.",
)
- clip_model_path: str | None = Field(
+ clip_model_path: Optional[str] = Field(
default=None,
description="Path to a CLIP model to use for multi-modal chat completion.",
)
@@ -145,21 +145,21 @@ class ModelSettings(BaseSettings):
description="The size of the cache in bytes. Only used if cache is True.",
)
# Tokenizer Options
- hf_tokenizer_config_path: str | None = Field(
+ hf_tokenizer_config_path: Optional[str] = Field(
default=None,
description="The path to a HuggingFace tokenizer_config.json file.",
)
- hf_pretrained_model_name_or_path: str | None = Field(
+ hf_pretrained_model_name_or_path: Optional[str] = Field(
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Loading from HuggingFace Model Hub
- hf_model_repo_id: str | None = Field(
+ hf_model_repo_id: Optional[str] = Field(
default=None,
description="The model repo id to use for the HuggingFace tokenizer model.",
)
# Speculative Decoding
- draft_model: str | None = Field(
+ draft_model: Optional[str] = Field(
default=None,
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
)
@@ -168,11 +168,11 @@ class ModelSettings(BaseSettings):
description="Number of tokens to predict using the draft model.",
)
# KV Cache Quantization
- type_k: int | None = Field(
+ type_k: Optional[int] = Field(
default=None,
description="Type of the key cache quantization.",
)
- type_v: int | None = Field(
+ type_v: Optional[int] = Field(
default=None,
description="Type of the value cache quantization.",
)
@@ -187,7 +187,7 @@ class ModelSettings(BaseSettings):
def set_dynamic_defaults(self) -> Self:
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
- values = cast(dict[str, int], self)
+ values = cast(Dict[str, int], self)
if values.get("n_threads", 0) == -1:
values["n_threads"] = cpu_count
if values.get("n_threads_batch", 0) == -1:
@@ -201,14 +201,14 @@ class ServerSettings(BaseSettings):
# Uvicorn Settings
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
- ssl_keyfile: str | None = Field(
+ ssl_keyfile: Optional[str] = Field(
default=None, description="SSL key file for HTTPS",
)
- ssl_certfile: str | None = Field(
+ ssl_certfile: Optional[str] = Field(
default=None, description="SSL certificate file for HTTPS",
)
# FastAPI Settings
- api_key: str | None = Field(
+ api_key: Optional[str] = Field(
default=None,
description="API key for authentication. If set all requests need to be authenticated.",
)
@@ -233,4 +233,4 @@ class Settings(ServerSettings, ModelSettings):
class ConfigFileSettings(ServerSettings):
"""Configuration file format settings."""

- models: list[ModelSettings] = Field(default=[], description="Model configs")
+ models: List[ModelSettings] = Field(default=[], description="Model configs")
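
Taken together, the diff is a mechanical reversion of PEP 604/PEP 585 annotation syntax (`str | None`, `list[float]`, `dict[str, int]`) to the `typing` module equivalents (`Optional[str]`, `List[float]`, `Dict[str, int]`), plus moving the `typing_extensions.Self` import back beside the `typing` import. The commit message states no rationale, but the pattern matches the standard fix for keeping Python 3.8/3.9 support: pydantic evaluates field annotations when it builds a model class, the `X | Y` operator on types only exists at runtime on Python 3.10+, and subscripting built-in generics such as `list[float]` only works from 3.9. Below is a minimal sketch of the failure mode, assuming Python 3.8/3.9 and pydantic v2; the class and field names are illustrative, not from the repository.

# repro.py: illustrative sketch, not part of the commit
from __future__ import annotations  # annotations stored as strings, as in settings.py

from typing import Optional

from pydantic_settings import BaseSettings


class OldStyle(BaseSettings):
    # Evaluates on every Python version pydantic supports.
    model_alias: Optional[str] = None


class NewStyle(BaseSettings):
    # pydantic must evaluate the stringified annotation "str | None" to build
    # the model; on Python 3.8/3.9 that evaluation typically fails with
    #   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
    model_alias: str | None = None

The `cast(...)` change is the one spot where lazily evaluated annotations would not help: the first argument of `typing.cast` is an ordinary runtime expression, so `cast(dict[str, int], self)` raises `TypeError: 'type' object is not subscriptable` on Python 3.8, while `cast(Dict[str, int], self)` works on all supported versions.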
