diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d79dca2cbb..b87fcdd3d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,10 +68,31 @@ jobs: if: matrix.platform != 'macos-latest' uses: ./.github/actions/free-disk-space - - name: Install the expect and skopeo package + - name: Install the expect package if: startsWith(matrix.platform, 'ubuntu') run: | - sudo apt-get install -y expect skopeo + sudo apt-get install -y expect + + - name: Install go for skopeo + if: startsWith(matrix.platform, 'ubuntu') + uses: actions/setup-go@v5 + with: + cache: false + go-version: 1.22.x + + # Building from source because the latest version of skopeo + # available on Ubuntu is v1.4 which is very old and + # was running into issues downloading artifacts properly + - name: install skopeo from source + if: startsWith(matrix.platform, 'ubuntu') + run: | + sudo apt-get install libgpgme-dev libassuan-dev libbtrfs-dev libdevmapper-dev pkg-config -y + git clone --depth 1 https://github.com/containers/skopeo -b v1.15.0 "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo + cd "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo && \ + make bin/skopeo && \ + sudo install -D -m 755 bin/skopeo /usr/bin/skopeo && \ + rm -rf "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo + skopeo --version - name: Install tools on MacOS if: startsWith(matrix.platform, 'macos') diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bd0835b18..9a34c21f13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ### Features +* `ilab model download` now supports downloading models from OCI registries. Repositories + that are prefixed by "docker://" and specified against `--repository` are treated as OCI + registries. * `ilab` now uses dedicated directories for storing config and data files. On Linux, these will generally be the XDG directories: `~/.config/instructlab` for config, `~/.local/share/instructlab` for data, and `~/.cache` for temporary files, including downloaded diff --git a/src/instructlab/model/download.py b/src/instructlab/model/download.py index 76a0241525..1ac9cf55ff 100644 --- a/src/instructlab/model/download.py +++ b/src/instructlab/model/download.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from pathlib import Path import abc -import json +import logging import os -import re import subprocess # Third Party @@ -16,7 +16,7 @@ # First Party from instructlab import clickext from instructlab.configuration import DEFAULTS -from instructlab.utils import is_huggingface_repo, is_oci_repo +from instructlab.utils import _extract_SHA, _load_json, is_huggingface_repo, is_oci_repo class Downloader(abc.ABC): @@ -38,7 +38,7 @@ def download(self) -> None: class HFDownloader(Downloader): - """Class to handle downloading safetensors and GGUF models from Huggingface""" + """Class to handle downloading safetensors and GGUF models from Hugging Face""" def __init__( self, @@ -52,17 +52,16 @@ def __init__( super().__init__( repository=repository, release=release, download_dest=download_dest ) - self.repository = repository - self.release = release - self.download_dest = download_dest self.filename = filename self.hf_token = hf_token self.ctx = ctx def download(self): - """Download the model(s) to train""" + """ + Download specified model from Hugging Face + """ click.echo( - f"Downloading model from huggingface: {self.repository}@{self.release} to {self.download_dest}..." + f"Downloading model from Hugging Face : {self.repository}@{self.release} to {self.download_dest}..." ) if self.hf_token == "" and "instructlab" not in self.repository: @@ -76,7 +75,7 @@ def download(self): if self.ctx.obj is not None: hf_logging.set_verbosity(self.ctx.obj.config.general.log_level.upper()) files = list_repo_files(repo_id=self.repository, token=self.hf_token) - if any(".safetensors" in string for string in files): + if any(".safetensors" in fname for fname in files): self.download_safetensors() else: self.download_gguf() @@ -100,18 +99,18 @@ def download_gguf(self) -> None: except Exception as exc: click.secho( - f"Downloading GGUF model failed with the following HuggingFace Hub error: {exc}", + f"Downloading GGUF model failed with the following Hugging Face Hub error: {exc}", fg="red", ) raise click.exceptions.Exit(1) def download_safetensors(self) -> None: try: - if not os.path.exists(os.path.join(self.download_dest, self.repository)): - os.makedirs( - name=os.path.join(self.download_dest, self.repository), - exist_ok=True, - ) + os.makedirs( + name=os.path.join(self.download_dest, self.repository), + exist_ok=True, + ) + snapshot_download( token=self.hf_token, repo_id=self.repository, @@ -120,7 +119,7 @@ def download_safetensors(self) -> None: ) except Exception as exc: click.secho( - f"Downloading safetensors model failed with the following HuggingFace Hub error: {exc}", + f"Downloading safetensors model failed with the following Hugging Face Hub error: {exc}", fg="red", ) raise click.exceptions.Exit(1) @@ -136,60 +135,53 @@ def __init__(self, repository: str, release: str, download_dest: str, ctx) -> No super().__init__( repository=repository, release=release, download_dest=download_dest ) - self.repository = repository - self.release = release - self.download_dest = download_dest self.ctx = ctx def _build_oci_model_file_map(self, oci_model_path: str) -> dict: """ Helper function to build a mapping between blob files and what they represent + Format for the index.json file can be found here: https://github.com/opencontainers/image-spec/blob/main/image-layout.md#indexjson-file """ index_hash = "" + index_ref_path = f"{oci_model_path}/index.json" try: - with open(f"{oci_model_path}/index.json", mode="r", encoding="UTF-8") as f: - index_ref = json.load(f) - match = re.search("sha256:(.*)", index_ref["manifests"][0]["digest"]) + index_ref = _load_json(Path(index_ref_path)) + match = None + for manifest in index_ref["manifests"]: + if ( + manifest["mediaType"] + == "application/vnd.oci.image.manifest.v1+json" + ): + match = _extract_SHA(manifest["digest"]) if match: index_hash = match.group(1) else: - click.echo(f"could not find hash for index file at: {oci_model_path}") - raise click.exceptions.Exit(1) - except FileNotFoundError as exc: - raise ValueError(f"file not found: {oci_model_path}/index.json") from exc - except json.JSONDecodeError as exc: - raise ValueError( - f"could not read JSON file: {oci_model_path}/index.json" - ) from exc + raise ValueError( + f"could not find hash for index file at: {oci_model_path}" + ) except Exception as exc: - raise ValueError("unexpected error occurred: {e}") from exc + raise exc + blob_dir = f"{oci_model_path}/blobs/sha256" try: - with open( - f"{oci_model_path}/blobs/sha256/{index_hash}", - mode="r", - encoding="UTF-8", - ) as f: - index = json.load(f) - except FileNotFoundError as exc: - raise ValueError(f"file not found: {oci_model_path}/index.json") from exc - except json.JSONDecodeError as exc: - raise ValueError( - f"could not read JSON file: {oci_model_path}/index.json" - ) from exc + index = _load_json(Path(f"{blob_dir}/{index_hash}")) except Exception as exc: - raise ValueError("unexpected error occurred: {e}") from exc + raise exc title_ref = "org.opencontainers.image.title" oci_model_file_map = {} + try: + for layer in index["layers"]: + match = _extract_SHA(layer["digest"]) - for layer in index["layers"]: - match = re.search("sha256:(.*)", layer["digest"]) - - if match: - blob_name = match.group(1) - oci_model_file_map[blob_name] = layer["annotations"][title_ref] + if match: + blob_name = match.group(1) + oci_model_file_map[blob_name] = layer["annotations"][title_ref] + except Exception as exc: + raise ValueError( + f"failed to build OCI model file mapping from: {blob_dir}/{index_hash}" + ) from exc return oci_model_file_map @@ -198,8 +190,8 @@ def download(self): f"Downloading model from OCI registry: {self.repository}@{self.release} to {self.download_dest}..." ) - os.makedirs(self.download_dest, exist_ok=True) model_name = self.repository.split("/")[-1] + os.makedirs(os.path.join(self.download_dest, model_name), exist_ok=True) oci_dir = f"{DEFAULTS.OCI_DIR}/{model_name}" os.makedirs(oci_dir, exist_ok=True) @@ -209,7 +201,10 @@ def download(self): f"{self.repository}:{self.release}", f"oci:{oci_dir}", ] - if self.ctx.obj.config.general.log_level == "DEBUG": + if ( + self.ctx.obj is not None + and self.ctx.obj.config.general.log_level == logging.DEBUG + ): command.append("--debug") try: @@ -219,29 +214,29 @@ def download(self): "skopeo not installed, but required to perform downloads from OCI registries. Exiting", ) from exc except Exception as e: - raise ValueError( - f"CalledProcessError: command exited with non-zero code: {e}" - ) from e + click.secho( + f"unexpected error: {e}", + fg="red", + ) + raise click.exceptions.Exit(1) file_map = self._build_oci_model_file_map(oci_dir) - for _, _, files in os.walk(f"{oci_dir}/blobs/sha256/"): + blob_dir = f"{oci_dir}/blobs/sha256/" + for _, _, files in os.walk(blob_dir): for name in files: if name not in file_map: continue dest = file_map[name] - if not os.path.exists(os.path.join(self.download_dest, model_name)): - os.makedirs( - os.path.join(self.download_dest, model_name), exist_ok=True - ) + dest_model_path = os.path.join(self.download_dest, model_name, dest) # unlink any existing version of the file - if os.path.exists(os.path.join(self.download_dest, model_name, dest)): - os.unlink(os.path.join(self.download_dest, model_name, dest)) + if os.path.exists(dest_model_path): + os.unlink(dest_model_path) - # create hard link to files in cache, to avoid redownloading if model has been downloaded before - os.link( - os.path.join(f"{oci_dir}/blobs/sha256/", name), - os.path.join(self.download_dest, model_name, dest), + # create symlink to files in cache, to avoid redownloading if model has been downloaded before + os.symlink( + os.path.join(blob_dir, name), + dest_model_path, ) @@ -250,13 +245,13 @@ def download(self): "--repository", default=DEFAULTS.MERLINITE_GGUF_REPO, # TODO: add to config.yaml show_default=True, - help="HuggingFace or OCI repository of the model to download.", + help="Hugging Face or OCI repository of the model to download.", ) @click.option( "--release", default="main", # TODO: add to config.yaml show_default=True, - help="The revision of the model to download - e.g. a branch, tag, or commit hash for Huggingface repositories and tag or commit has for OCI repositories.", + help="The revision of the model to download - e.g. a branch, tag, or commit hash for Hugging Face repositories and tag or commit hash for OCI repositories.", ) @click.option( "--filename", @@ -296,7 +291,7 @@ def download(ctx, repository, release, filename, model_dir, hf_token): ) else: click.secho( - f"repository {repository} matches neither Huggingface, nor OCI registry format. Please supply a valid repository", + f"repository {repository} matches neither Hugging Face nor OCI registry format. Please supply a valid repository", fg="red", ) raise click.exceptions.Exit(1) diff --git a/src/instructlab/utils.py b/src/instructlab/utils.py index 5498c197f5..46abea135a 100644 --- a/src/instructlab/utils.py +++ b/src/instructlab/utils.py @@ -612,7 +612,7 @@ def is_oci_repo(repo_url: str) -> bool: Checks if a provided repository follows the OCI registry URL syntax """ - # TO DO: flesh this out and make it a more robust check + # TODO: flesh this out and make it a more robust check oci_url_prefix = "docker://" return repo_url.startswith(oci_url_prefix) @@ -622,3 +622,22 @@ def is_huggingface_repo(repo_name: str) -> bool: # repo name should be of the format / pattern = r"^[\w.-]+\/[\w.-]+$" return re.match(pattern, repo_name) is not None + + +def _load_json(file_path: Path): + try: + with open( + file_path, + encoding="UTF-8", + ) as f: + return json.load(f) + except FileNotFoundError as e: + raise ValueError(f"file not found: {file_path}") from e + except json.JSONDecodeError as e: + raise ValueError(f"could not read JSON file: {file_path}") from e + except Exception as e: + raise ValueError("unexpected error occurred:") from e + + +def _extract_SHA(SHAstr: str): + return re.search("sha256:(.*)", SHAstr)