Skip to content

Commit

Permalink
Update cli.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Smartappli authored Aug 15, 2024
1 parent 2d9ee84 commit e133736
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions llama_cpp/server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,35 @@
from pydantic import BaseModel


def _get_base_type(annotation: type[Any]) -> type[Any]:
def _get_base_type(annotation: Type[Any]) -> Type[Any]:
if getattr(annotation, "__origin__", None) is Literal:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return type(annotation.__args__[0]) # type: ignore
if getattr(annotation, "__origin__", None) is Union:
elif getattr(annotation, "__origin__", None) is Union:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
non_optional_args: list[type[Any]] = [
non_optional_args: List[Type[Any]] = [
arg for arg in annotation.__args__ if arg is not type(None) # type: ignore
]
if non_optional_args:
return _get_base_type(non_optional_args[0])
elif (
getattr(annotation, "__origin__", None) is list
or getattr(annotation, "__origin__", None) is list
or getattr(annotation, "__origin__", None) is List
):
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return _get_base_type(annotation.__args__[0]) # type: ignore
return annotation


def _contains_list_type(annotation: type[Any] | None) -> bool:
def _contains_list_type(annotation: Type[Any] | None) -> bool:
origin = getattr(annotation, "__origin__", None)

if origin is list or origin is list:
if origin is list or origin is List:
return True
if origin in (Literal, Union):
elif origin in (Literal, Union):
return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore
return False
else:
return False


def _parse_bool_arg(arg: str | bytes | bool) -> bool:
Expand All @@ -47,12 +48,13 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:

if arg_str in true_values:
return True
if arg_str in false_values:
elif arg_str in false_values:
return False
raise ValueError(f"Invalid boolean argument: {arg}")
else:
raise ValueError(f"Invalid boolean argument: {arg}")


def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
"""Add arguments from a pydantic model to an argparse parser."""

for name, field in model.model_fields.items():
Expand Down Expand Up @@ -80,7 +82,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel])
)


T = TypeVar("T", bound=type[BaseModel])
T = TypeVar("T", bound=Type[BaseModel])


def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
Expand All @@ -90,5 +92,5 @@ def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
k: v
for k, v in vars(args).items()
if v is not None and k in model.model_fields
},
}
)

0 comments on commit e133736

Please sign in to comment.