Skip to content

Commit

Permalink
Not raising error when valid tokenizer not found (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
kooyunmo authored Aug 3, 2023
1 parent a4583fc commit fe6008b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
1 change: 0 additions & 1 deletion docs/docs/sdk/resource/checkpoint.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -632,4 +632,3 @@ Convert Hugging Face model checkpoint to PeriFlow format.
|------|-------------|
| `NotFoundError` | Raised when `model_name_or_path` is not found. Also raised when `tokenizer_output_dir` is not found. |
| `CheckpointConversionError` | Raised when given model architecture from checkpoint is not supported to convert. |
| `TokenizerNotFoundError` | Raised when `tokenizer_output_dir` is not `None` and the model does not have the PeriFlow-compatible tokenizer implementation, which is equivalent to Hugging Face 'fast' tokenizer. Refer to [this link](https://huggingface.co/docs/transformers/main_classes/tokenizer#tokenizer) to get more info. |
3 changes: 1 addition & 2 deletions periflow/cli/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
InvalidPathError,
NotFoundError,
NotSupportedError,
TokenizerNotFoundError,
)
from periflow.formatter import (
JSONFormatter,
Expand Down Expand Up @@ -708,7 +707,7 @@ def convert(
cache_dir=cache_dir,
dry_run=dry_run,
)
except (NotFoundError, CheckpointConversionError, TokenizerNotFoundError) as exc:
except (NotFoundError, CheckpointConversionError) as exc:
secho_error_and_exit(str(exc))

msg = (
Expand Down
5 changes: 3 additions & 2 deletions periflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
from periflow.utils.format import secho_error_and_exit
from periflow.utils.request import DEFAULT_REQ_TIMEOUT
from periflow.utils.url import get_training_uri
from periflow.utils.validate import validate_cli_version
from periflow.utils.validate import validate_package_version
from periflow.utils.version import get_installed_version

app = typer.Typer(
help="Supercharge Generative AI Serving 🚀",
no_args_is_help=True,
context_settings={"help_option_names": ["-h", "--help"]},
add_completion=False,
callback=validate_cli_version,
callback=validate_package_version,
pretty_exceptions_enable=False,
)

app.add_typer(credential.app, name="credential", help="Manage credentials")
Expand Down
19 changes: 13 additions & 6 deletions periflow/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,22 @@ def save_tokenizer(
if not os.path.isdir(save_dir):
raise NotFoundError(f"Directory '{save_dir}' is not found.")

tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True,
)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True,
)
except OSError as exc:
raise TokenizerNotFoundError(str(exc)) from exc

if not tokenizer.is_fast:
raise TokenizerNotFoundError
raise TokenizerNotFoundError(
"This model does not support PeriFlow-compatible tokenizer"
)

saved_file_paths = tokenizer.save_pretrained(save_directory=save_dir)

tokenizer_json_path = None
for path in saved_file_paths:
if "tokenizer.json" == os.path.basename(path):
Expand Down
17 changes: 10 additions & 7 deletions periflow/sdk/resource/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

"""PeriFlow Checkpoint SDK."""

# pylint: disable=line-too-long, arguments-differ, too-many-arguments, too-many-locals, redefined-builtin
# pylint: disable=line-too-long, arguments-differ, too-many-arguments, too-many-statements, too-many-locals, redefined-builtin

from __future__ import annotations

Expand All @@ -24,6 +24,7 @@
NotFoundError,
NotSupportedCheckpointError,
PeriFlowInternalError,
TokenizerNotFoundError,
)
from periflow.logging import logger
from periflow.schema.resource.v1.checkpoint import V1Checkpoint
Expand Down Expand Up @@ -816,7 +817,6 @@ def convert(
Raises:
NotFoundError: Raised when `model_name_or_path` is not found. Also raised when `tokenizer_output_dir` is not found.
CheckpointConversionError: Raised when given model architecture from checkpoint is not supported to convert.
TokenizerNotFoundError: Raised when `tokenizer_output_dir` is not `None` and the model does not have the PeriFlow-compatible tokenizer implementation, which is equivalent to Hugging Face 'fast' tokenizer. Refer to [this link](https://huggingface.co/docs/transformers/main_classes/tokenizer#tokenizer) to get more info.
"""
try:
Expand Down Expand Up @@ -923,8 +923,11 @@ def convert(
yaml.dump(attr, file, sort_keys=False)

if tokenizer_output_dir is not None:
save_tokenizer(
model_name_or_path=model_name_or_path,
cache_dir=cache_dir,
save_dir=tokenizer_output_dir,
)
try:
save_tokenizer(
model_name_or_path=model_name_or_path,
cache_dir=cache_dir,
save_dir=tokenizer_output_dir,
)
except TokenizerNotFoundError as exc:
logger.warn(str(exc))
6 changes: 3 additions & 3 deletions periflow/utils/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ def validate_cloud_storage_type(val: StorageType) -> None:
)


def validate_cli_version() -> None:
def validate_package_version() -> None:
    """Validate the installed package version."""
installed_version = get_installed_version()
if not is_latest_version(installed_version):
latest_version = get_latest_version()
secho_error_and_exit(
f"CLI version({installed_version}) is deprecated. "
f"Package version({installed_version}) is deprecated. "
f"Please install the latest version({latest_version}) with "
f"'pip install {PERIFLOW_PACKAGE_NAME}=={latest_version} -U --no-cache-dir'."
f"'pip install {PERIFLOW_PACKAGE_NAME}=={latest_version} -U'."
)


Expand Down

0 comments on commit fe6008b

Please sign in to comment.