From 500669bceec7103dfb94dd6a4e1f1e0e146c1195 Mon Sep 17 00:00:00 2001 From: Jaideep Rao Date: Wed, 17 Jul 2024 01:55:09 -0400 Subject: [PATCH] feat: allow ilab model download to pull from OCI registries Check if repository supplied during ilab model download meets the URL structure of an OCI registry. If so, leverage skopeo to copy image layers into cache, and apply mapping logic to move the model files into the permanent models directory. This PR also refactors existing ilab model download code to introduce base and implementation classes to support multiple downloading backends (HF vs OCI). Assumes skopeo is installed on the system and uses OCI v1.1 Signed-off-by: Jaideep Rao --- .github/workflows/test.yml | 7 +- scripts/functional-tests.sh | 27 +++ src/instructlab/configuration.py | 6 + src/instructlab/model/download.py | 286 ++++++++++++++++++++++++++---- src/instructlab/utils.py | 17 ++ 5 files changed, 310 insertions(+), 33 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 549ba7bf3..d79dca2cb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,7 @@ on: env: LC_ALL: en_US.UTF-8 + REGISTRY_AUTH_FILE: /run/containers/1001/auth.json defaults: run: @@ -67,15 +68,15 @@ jobs: if: matrix.platform != 'macos-latest' uses: ./.github/actions/free-disk-space - - name: Install the expect package + - name: Install the expect and skopeo package if: startsWith(matrix.platform, 'ubuntu') run: | - sudo apt-get install -y expect + sudo apt-get install -y expect skopeo - name: Install tools on MacOS if: startsWith(matrix.platform, 'macos') run: | - brew install expect coreutils bash + brew install expect coreutils bash skopeo - name: Setup Python ${{ matrix.python }} uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 diff --git a/scripts/functional-tests.sh b/scripts/functional-tests.sh index d8c174c08..1f6949b4d 100755 --- a/scripts/functional-tests.sh +++ b/scripts/functional-tests.sh @@ -161,6 +161,31 @@ fi # download the latest version of the ilab ilab model download +test_oci_model_download_with_vllm_backend(){ + # Enable globstar for recursive globbing + shopt -s globstar + + # Run the ilab model download command with REGISTRY_AUTH_FILE + REGISTRY_AUTH_FILE=$HOME/auth.json ilab model download --repository docker://quay.io/ai-lab/models/granite-7b-lab --release latest --model-dir models/instructlab + + patterns=( + "models/instructlab/config.json" + "models/instructlab/tokenizer.json" + "models/instructlab/tokenizer_config.json" + "models/instructlab/*.safetensors" + ) + + for pattern in "${patterns[@]}" + do + matching_files=("$pattern") + if [ ${#matching_files[@]} -eq 0 ] + then + echo "No files found matching pattern: $pattern" + exit 1 + fi + done +} + # check that ilab model serve is working test_bind_port(){ local formatted_script @@ -546,6 +571,8 @@ test_server_chat_template() { # MAIN # ######## # call cleanup in-between each test so they can run without conflicting with the server/chat process +test_oci_model_download_with_vllm_backend +cleanup test_bind_port cleanup test_ctx_size diff --git a/src/instructlab/configuration.py b/src/instructlab/configuration.py index b95ca2bdd..67ce73dc6 100644 --- a/src/instructlab/configuration.py +++ b/src/instructlab/configuration.py @@ -43,6 +43,7 @@ class STORAGE_DIR_NAMES: ILAB = "instructlab" DATASETS = "datasets" CHECKPOINTS = "checkpoints" + OCI = "oci" MODELS = "models" TAXONOMY = "taxonomy" INTERNAL = ( @@ -98,6 +99,10 @@ def _reset(self): def CHECKPOINTS_DIR(self) -> str: return path.join(self._data_dir, STORAGE_DIR_NAMES.CHECKPOINTS) + @property + def OCI_DIR(self) -> str: + return path.join(self._cache_home, STORAGE_DIR_NAMES.OCI) + @property def DATASETS_DIR(self) -> str: return path.join(self._data_dir, STORAGE_DIR_NAMES.DATASETS) @@ -545,6 +550,7 @@ def ensure_storage_directories_exist(): DEFAULTS._data_dir, DEFAULTS.CHATLOGS_DIR, DEFAULTS.CHECKPOINTS_DIR, + DEFAULTS.OCI_DIR, DEFAULTS.DATASETS_DIR, DEFAULTS.EVAL_DATA_DIR, DEFAULTS.INTERNAL_DIR, diff --git a/src/instructlab/model/download.py b/src/instructlab/model/download.py index c39e47062..76a024152 100644 --- a/src/instructlab/model/download.py +++ b/src/instructlab/model/download.py @@ -1,7 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +import abc +import json import os +import re +import subprocess # Third Party from huggingface_hub import hf_hub_download, list_repo_files @@ -12,6 +16,233 @@ # First Party from instructlab import clickext from instructlab.configuration import DEFAULTS +from instructlab.utils import is_huggingface_repo, is_oci_repo + + +class Downloader(abc.ABC): + """Base class for a downloading backend""" + + def __init__( + self, + repository: str, + release: str, + download_dest: str, + ) -> None: + self.repository = repository + self.release = release + self.download_dest = download_dest + + @abc.abstractmethod + def download(self) -> None: + """Downloads model from specified repo/release and stores it into download_dest""" + + +class HFDownloader(Downloader): + """Class to handle downloading safetensors and GGUF models from Huggingface""" + + def __init__( + self, + repository: str, + release: str, + download_dest: str, + filename: str, + hf_token: str, + ctx, + ) -> None: + super().__init__( + repository=repository, release=release, download_dest=download_dest + ) + self.repository = repository + self.release = release + self.download_dest = download_dest + self.filename = filename + self.hf_token = hf_token + self.ctx = ctx + + def download(self): + """Download the model(s) to train""" + click.echo( + f"Downloading model from huggingface: {self.repository}@{self.release} to {self.download_dest}..." + ) + + if self.hf_token == "" and "instructlab" not in self.repository: + raise ValueError( + """HF_TOKEN var needs to be set in your environment to download HF Model. + Alternatively, the token can be passed with --hf-token flag. + The HF Token is used to authenticate your identity to the Hugging Face Hub.""" + ) + + try: + if self.ctx.obj is not None: + hf_logging.set_verbosity(self.ctx.obj.config.general.log_level.upper()) + files = list_repo_files(repo_id=self.repository, token=self.hf_token) + if any(".safetensors" in string for string in files): + self.download_safetensors() + else: + self.download_gguf() + + except Exception as exc: + click.secho( + f"Downloading model failed with the following Hugging Face Hub error: {exc}", + fg="red", + ) + raise click.exceptions.Exit(1) + + def download_gguf(self) -> None: + try: + hf_hub_download( + token=self.hf_token, + repo_id=self.repository, + revision=self.release, + filename=self.filename, + local_dir=self.download_dest, + ) + + except Exception as exc: + click.secho( + f"Downloading GGUF model failed with the following HuggingFace Hub error: {exc}", + fg="red", + ) + raise click.exceptions.Exit(1) + + def download_safetensors(self) -> None: + try: + if not os.path.exists(os.path.join(self.download_dest, self.repository)): + os.makedirs( + name=os.path.join(self.download_dest, self.repository), + exist_ok=True, + ) + snapshot_download( + token=self.hf_token, + repo_id=self.repository, + revision=self.release, + local_dir=os.path.join(self.download_dest, self.repository), + ) + except Exception as exc: + click.secho( + f"Downloading safetensors model failed with the following HuggingFace Hub error: {exc}", + fg="red", + ) + raise click.exceptions.Exit(1) + + +class OCIDownloader(Downloader): + """ + Class to handle downloading safetensors models from OCI Registries + We are leveraging OCI v1.1 for this functionality + """ + + def __init__(self, repository: str, release: str, download_dest: str, ctx) -> None: + super().__init__( + repository=repository, release=release, download_dest=download_dest + ) + self.repository = repository + self.release = release + self.download_dest = download_dest + self.ctx = ctx + + def _build_oci_model_file_map(self, oci_model_path: str) -> dict: + """ + Helper function to build a mapping between blob files and what they represent + """ + index_hash = "" + try: + with open(f"{oci_model_path}/index.json", mode="r", encoding="UTF-8") as f: + index_ref = json.load(f) + match = re.search("sha256:(.*)", index_ref["manifests"][0]["digest"]) + + if match: + index_hash = match.group(1) + else: + click.echo(f"could not find hash for index file at: {oci_model_path}") + raise click.exceptions.Exit(1) + except FileNotFoundError as exc: + raise ValueError(f"file not found: {oci_model_path}/index.json") from exc + except json.JSONDecodeError as exc: + raise ValueError( + f"could not read JSON file: {oci_model_path}/index.json" + ) from exc + except Exception as exc: + raise ValueError("unexpected error occurred: {e}") from exc + + try: + with open( + f"{oci_model_path}/blobs/sha256/{index_hash}", + mode="r", + encoding="UTF-8", + ) as f: + index = json.load(f) + except FileNotFoundError as exc: + raise ValueError(f"file not found: {oci_model_path}/index.json") from exc + except json.JSONDecodeError as exc: + raise ValueError( + f"could not read JSON file: {oci_model_path}/index.json" + ) from exc + except Exception as exc: + raise ValueError("unexpected error occurred: {e}") from exc + + title_ref = "org.opencontainers.image.title" + oci_model_file_map = {} + + for layer in index["layers"]: + match = re.search("sha256:(.*)", layer["digest"]) + + if match: + blob_name = match.group(1) + oci_model_file_map[blob_name] = layer["annotations"][title_ref] + + return oci_model_file_map + + def download(self): + click.echo( + f"Downloading model from OCI registry: {self.repository}@{self.release} to {self.download_dest}..." + ) + + os.makedirs(self.download_dest, exist_ok=True) + model_name = self.repository.split("/")[-1] + oci_dir = f"{DEFAULTS.OCI_DIR}/{model_name}" + os.makedirs(oci_dir, exist_ok=True) + + command = [ + "skopeo", + "copy", + f"{self.repository}:{self.release}", + f"oci:{oci_dir}", + ] + if self.ctx.obj.config.general.log_level == "DEBUG": + command.append("--debug") + + try: + subprocess.run(command, check=True) + except FileNotFoundError as exc: + raise FileNotFoundError( + "skopeo not installed, but required to perform downloads from OCI registries. Exiting", + ) from exc + except Exception as e: + raise ValueError( + f"CalledProcessError: command exited with non-zero code: {e}" + ) from e + + file_map = self._build_oci_model_file_map(oci_dir) + + for _, _, files in os.walk(f"{oci_dir}/blobs/sha256/"): + for name in files: + if name not in file_map: + continue + dest = file_map[name] + if not os.path.exists(os.path.join(self.download_dest, model_name)): + os.makedirs( + os.path.join(self.download_dest, model_name), exist_ok=True + ) + # unlink any existing version of the file + if os.path.exists(os.path.join(self.download_dest, model_name, dest)): + os.unlink(os.path.join(self.download_dest, model_name, dest)) + + # create hard link to files in cache, to avoid redownloading if model has been downloaded before + os.link( + os.path.join(f"{oci_dir}/blobs/sha256/", name), + os.path.join(self.download_dest, model_name, dest), + ) @click.command() @@ -19,13 +250,13 @@ "--repository", default=DEFAULTS.MERLINITE_GGUF_REPO, # TODO: add to config.yaml show_default=True, - help="Hugging Face repository of the model to download.", + help="HuggingFace or OCI repository of the model to download.", ) @click.option( "--release", default="main", # TODO: add to config.yaml show_default=True, - help="The git revision of the model to download - e.g. a branch, tag, or commit hash.", + help="The revision of the model to download - e.g. a branch, tag, or commit hash for Huggingface repositories and tag or commit has for OCI repositories.", ) @click.option( "--filename", @@ -48,38 +279,33 @@ @click.pass_context @clickext.display_params def download(ctx, repository, release, filename, model_dir, hf_token): - """Download the model(s) to train""" - click.echo(f"Downloading model from {repository}@{release} to {model_dir}...") - if hf_token == "" and "instructlab" not in repository: - raise ValueError( - """HF_TOKEN var needs to be set in your environment to download HF Model. - Alternatively, the token can be passed with --hf-token flag. - The HF Token is used to authenticate your identity to the Hugging Face Hub.""" + downloader = None + + if is_oci_repo(repository): + downloader = OCIDownloader( + repository=repository, release=release, download_dest=model_dir, ctx=ctx + ) + elif is_huggingface_repo(repository): + downloader = HFDownloader( + repository=repository, + release=release, + download_dest=model_dir, + filename=filename, + hf_token=hf_token, + ctx=ctx, + ) + else: + click.secho( + f"repository {repository} matches neither Huggingface, nor OCI registry format. Please supply a valid repository", + fg="red", ) + raise click.exceptions.Exit(1) + try: - if ctx.obj is not None: - hf_logging.set_verbosity(ctx.obj.config.general.log_level.upper()) - files = list_repo_files(repo_id=repository, token=hf_token) - if any(".safetensors" in string for string in files): - if not os.path.exists(os.path.join(model_dir, repository)): - os.makedirs(name=os.path.join(model_dir, repository), exist_ok=True) - snapshot_download( - token=hf_token, - repo_id=repository, - revision=release, - local_dir=os.path.join(model_dir, repository), - ) - else: - hf_hub_download( - token=hf_token, - repo_id=repository, - revision=release, - filename=filename, - local_dir=model_dir, - ) + downloader.download() except Exception as exc: click.secho( - f"Downloading model failed with the following Hugging Face Hub error: {exc}", + f"Downloading model failed with the following error: {exc}", fg="red", ) raise click.exceptions.Exit(1) diff --git a/src/instructlab/utils.py b/src/instructlab/utils.py index a17f2c45f..5498c197f 100644 --- a/src/instructlab/utils.py +++ b/src/instructlab/utils.py @@ -605,3 +605,20 @@ def ensure_legacy_dataset( return dataset # type: ignore return convert_messages_to_legacy_dataset(dataset) # type: ignore + + +def is_oci_repo(repo_url: str) -> bool: + """ + Checks if a provided repository follows the OCI registry URL syntax + """ + + # TO DO: flesh this out and make it a more robust check + oci_url_prefix = "docker://" + return repo_url.startswith(oci_url_prefix) + + +def is_huggingface_repo(repo_name: str) -> bool: + # allow alphanumerics, underscores, hyphens and periods in huggingface repo names + # repo name should be of the format / + pattern = r"^[\w.-]+\/[\w.-]+$" + return re.match(pattern, repo_name) is not None