Skip to content

Commit

Permalink
tests: install skopeo from source
Browse files Browse the repository at this point in the history
Signed-off-by: Sébastien Han <seb@redhat.com>
  • Loading branch information
leseb authored and jaideepr97 committed Jul 24, 2024
1 parent 500669b commit c9b7137
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 73 deletions.
25 changes: 23 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,31 @@ jobs:
if: matrix.platform != 'macos-latest'
uses: ./.github/actions/free-disk-space

- name: Install the expect and skopeo package
- name: Install the expect package
if: startsWith(matrix.platform, 'ubuntu')
run: |
sudo apt-get install -y expect skopeo
sudo apt-get install -y expect
- name: Install go for skopeo
if: startsWith(matrix.platform, 'ubuntu')
uses: actions/setup-go@v5
with:
cache: false
go-version: 1.22.x

# Build skopeo from source: the newest version packaged for Ubuntu
# is v1.4, which is very old and failed to download artifacts
# properly.
- name: install skopeo from source
if: startsWith(matrix.platform, 'ubuntu')
run: |
sudo apt-get install libgpgme-dev libassuan-dev libbtrfs-dev libdevmapper-dev pkg-config -y
git clone --depth 1 https://github.com/containers/skopeo -b v1.15.0 "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo
cd "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo && \
make bin/skopeo && \
sudo install -D -m 755 bin/skopeo /usr/bin/skopeo && \
rm -rf "$GITHUB_WORKSPACE"/src/github.com/containers/skopeo
skopeo --version
- name: Install tools on MacOS
if: startsWith(matrix.platform, 'macos')
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

### Features

* `ilab model download` now supports downloading models from OCI registries. A repository
passed via `--repository` that starts with `docker://` is treated as an OCI registry
reference.
* `ilab` now uses dedicated directories for storing config and data files. On Linux, these
will generally be the XDG directories: `~/.config/instructlab` for config,
`~/.local/share/instructlab` for data, and `~/.cache` for temporary files, including downloaded
Expand Down
135 changes: 65 additions & 70 deletions src/instructlab/model/download.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from pathlib import Path
import abc
import json
import logging
import os
import re
import subprocess

# Third Party
Expand All @@ -16,7 +16,7 @@
# First Party
from instructlab import clickext
from instructlab.configuration import DEFAULTS
from instructlab.utils import is_huggingface_repo, is_oci_repo
from instructlab.utils import _extract_SHA, _load_json, is_huggingface_repo, is_oci_repo


class Downloader(abc.ABC):
Expand All @@ -38,7 +38,7 @@ def download(self) -> None:


class HFDownloader(Downloader):
"""Class to handle downloading safetensors and GGUF models from Huggingface"""
"""Class to handle downloading safetensors and GGUF models from Hugging Face"""

def __init__(
self,
Expand All @@ -52,17 +52,16 @@ def __init__(
super().__init__(
repository=repository, release=release, download_dest=download_dest
)
self.repository = repository
self.release = release
self.download_dest = download_dest
self.filename = filename
self.hf_token = hf_token
self.ctx = ctx

def download(self):
"""Download the model(s) to train"""
"""
Download specified model from Hugging Face
"""
click.echo(
f"Downloading model from huggingface: {self.repository}@{self.release} to {self.download_dest}..."
f"Downloading model from Hugging Face : {self.repository}@{self.release} to {self.download_dest}..."
)

if self.hf_token == "" and "instructlab" not in self.repository:
Expand All @@ -76,7 +75,7 @@ def download(self):
if self.ctx.obj is not None:
hf_logging.set_verbosity(self.ctx.obj.config.general.log_level.upper())
files = list_repo_files(repo_id=self.repository, token=self.hf_token)
if any(".safetensors" in string for string in files):
if any(".safetensors" in fname for fname in files):
self.download_safetensors()
else:
self.download_gguf()
Expand All @@ -100,18 +99,18 @@ def download_gguf(self) -> None:

except Exception as exc:
click.secho(
f"Downloading GGUF model failed with the following HuggingFace Hub error: {exc}",
f"Downloading GGUF model failed with the following Hugging Face Hub error: {exc}",
fg="red",
)
raise click.exceptions.Exit(1)

def download_safetensors(self) -> None:
try:
if not os.path.exists(os.path.join(self.download_dest, self.repository)):
os.makedirs(
name=os.path.join(self.download_dest, self.repository),
exist_ok=True,
)
os.makedirs(
name=os.path.join(self.download_dest, self.repository),
exist_ok=True,
)

snapshot_download(
token=self.hf_token,
repo_id=self.repository,
Expand All @@ -120,7 +119,7 @@ def download_safetensors(self) -> None:
)
except Exception as exc:
click.secho(
f"Downloading safetensors model failed with the following HuggingFace Hub error: {exc}",
f"Downloading safetensors model failed with the following Hugging Face Hub error: {exc}",
fg="red",
)
raise click.exceptions.Exit(1)
Expand All @@ -136,60 +135,53 @@ def __init__(self, repository: str, release: str, download_dest: str, ctx) -> No
super().__init__(
repository=repository, release=release, download_dest=download_dest
)
self.repository = repository
self.release = release
self.download_dest = download_dest
self.ctx = ctx

def _build_oci_model_file_map(self, oci_model_path: str) -> dict:
"""
Helper function to build a mapping between blob files and what they represent
Format for the index.json file can be found here: https://github.com/opencontainers/image-spec/blob/main/image-layout.md#indexjson-file
"""
index_hash = ""
index_ref_path = f"{oci_model_path}/index.json"
try:
with open(f"{oci_model_path}/index.json", mode="r", encoding="UTF-8") as f:
index_ref = json.load(f)
match = re.search("sha256:(.*)", index_ref["manifests"][0]["digest"])
index_ref = _load_json(Path(index_ref_path))
match = None
for manifest in index_ref["manifests"]:
if (
manifest["mediaType"]
== "application/vnd.oci.image.manifest.v1+json"
):
match = _extract_SHA(manifest["digest"])

if match:
index_hash = match.group(1)
else:
click.echo(f"could not find hash for index file at: {oci_model_path}")
raise click.exceptions.Exit(1)
except FileNotFoundError as exc:
raise ValueError(f"file not found: {oci_model_path}/index.json") from exc
except json.JSONDecodeError as exc:
raise ValueError(
f"could not read JSON file: {oci_model_path}/index.json"
) from exc
raise ValueError(
f"could not find hash for index file at: {oci_model_path}"
)
except Exception as exc:
raise ValueError("unexpected error occurred: {e}") from exc
raise exc

blob_dir = f"{oci_model_path}/blobs/sha256"
try:
with open(
f"{oci_model_path}/blobs/sha256/{index_hash}",
mode="r",
encoding="UTF-8",
) as f:
index = json.load(f)
except FileNotFoundError as exc:
raise ValueError(f"file not found: {oci_model_path}/index.json") from exc
except json.JSONDecodeError as exc:
raise ValueError(
f"could not read JSON file: {oci_model_path}/index.json"
) from exc
index = _load_json(Path(f"{blob_dir}/{index_hash}"))
except Exception as exc:
raise ValueError("unexpected error occurred: {e}") from exc
raise exc

title_ref = "org.opencontainers.image.title"
oci_model_file_map = {}
try:
for layer in index["layers"]:
match = _extract_SHA(layer["digest"])

for layer in index["layers"]:
match = re.search("sha256:(.*)", layer["digest"])

if match:
blob_name = match.group(1)
oci_model_file_map[blob_name] = layer["annotations"][title_ref]
if match:
blob_name = match.group(1)
oci_model_file_map[blob_name] = layer["annotations"][title_ref]
except Exception as exc:
raise ValueError(
f"failed to build OCI model file mapping from: {blob_dir}/{index_hash}"
) from exc

return oci_model_file_map

Expand All @@ -198,8 +190,8 @@ def download(self):
f"Downloading model from OCI registry: {self.repository}@{self.release} to {self.download_dest}..."
)

os.makedirs(self.download_dest, exist_ok=True)
model_name = self.repository.split("/")[-1]
os.makedirs(os.path.join(self.download_dest, model_name), exist_ok=True)
oci_dir = f"{DEFAULTS.OCI_DIR}/{model_name}"
os.makedirs(oci_dir, exist_ok=True)

Expand All @@ -209,7 +201,10 @@ def download(self):
f"{self.repository}:{self.release}",
f"oci:{oci_dir}",
]
if self.ctx.obj.config.general.log_level == "DEBUG":
if (
self.ctx.obj is not None
and self.ctx.obj.config.general.log_level == logging.DEBUG
):
command.append("--debug")

try:
Expand All @@ -219,29 +214,29 @@ def download(self):
"skopeo not installed, but required to perform downloads from OCI registries. Exiting",
) from exc
except Exception as e:
raise ValueError(
f"CalledProcessError: command exited with non-zero code: {e}"
) from e
click.secho(
f"unexpected error: {e}",
fg="red",
)
raise click.exceptions.Exit(1)

file_map = self._build_oci_model_file_map(oci_dir)

for _, _, files in os.walk(f"{oci_dir}/blobs/sha256/"):
blob_dir = f"{oci_dir}/blobs/sha256/"
for _, _, files in os.walk(blob_dir):
for name in files:
if name not in file_map:
continue
dest = file_map[name]
if not os.path.exists(os.path.join(self.download_dest, model_name)):
os.makedirs(
os.path.join(self.download_dest, model_name), exist_ok=True
)
dest_model_path = os.path.join(self.download_dest, model_name, dest)
# unlink any existing version of the file
if os.path.exists(os.path.join(self.download_dest, model_name, dest)):
os.unlink(os.path.join(self.download_dest, model_name, dest))
if os.path.exists(dest_model_path):
os.unlink(dest_model_path)

# create hard link to files in cache, to avoid redownloading if model has been downloaded before
os.link(
os.path.join(f"{oci_dir}/blobs/sha256/", name),
os.path.join(self.download_dest, model_name, dest),
# create symlink to files in cache, to avoid redownloading if model has been downloaded before
os.symlink(
os.path.join(blob_dir, name),
dest_model_path,
)


Expand All @@ -250,13 +245,13 @@ def download(self):
"--repository",
default=DEFAULTS.MERLINITE_GGUF_REPO, # TODO: add to config.yaml
show_default=True,
help="HuggingFace or OCI repository of the model to download.",
help="Hugging Face or OCI repository of the model to download.",
)
@click.option(
"--release",
default="main", # TODO: add to config.yaml
show_default=True,
help="The revision of the model to download - e.g. a branch, tag, or commit hash for Huggingface repositories and tag or commit has for OCI repositories.",
help="The revision of the model to download - e.g. a branch, tag, or commit hash for Hugging Face repositories and tag or commit hash for OCI repositories.",
)
@click.option(
"--filename",
Expand Down Expand Up @@ -296,7 +291,7 @@ def download(ctx, repository, release, filename, model_dir, hf_token):
)
else:
click.secho(
f"repository {repository} matches neither Huggingface, nor OCI registry format. Please supply a valid repository",
f"repository {repository} matches neither Hugging Face nor OCI registry format. Please supply a valid repository",
fg="red",
)
raise click.exceptions.Exit(1)
Expand Down
21 changes: 20 additions & 1 deletion src/instructlab/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def is_oci_repo(repo_url: str) -> bool:
Checks if a provided repository follows the OCI registry URL syntax
"""

# TO DO: flesh this out and make it a more robust check
# TODO: flesh this out and make it a more robust check
oci_url_prefix = "docker://"
return repo_url.startswith(oci_url_prefix)

Expand All @@ -622,3 +622,22 @@ def is_huggingface_repo(repo_name: str) -> bool:
# repo name should be of the format <owner>/<model>
pattern = r"^[\w.-]+\/[\w.-]+$"
return re.match(pattern, repo_name) is not None


def _load_json(file_path: Path):
    """
    Read the file at ``file_path`` as UTF-8 text and return its parsed JSON content.

    Every failure is re-raised as a ``ValueError`` chained to the original
    exception, so callers only need to handle a single exception type.

    Args:
        file_path: location of the JSON document to load.

    Returns:
        The Python object produced by ``json.load`` for the file's content.

    Raises:
        ValueError: if the file does not exist, does not contain valid JSON,
            or cannot be read for any other reason.
    """
    try:
        with open(file_path, encoding="UTF-8") as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise ValueError(f"file not found: {file_path}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"could not read JSON file: {file_path}") from e
    except Exception as e:
        # Deliberate catch-all (permission errors, directories, bad encodings,
        # ...): name the file and the underlying cause so the message is
        # actionable instead of ending in a dangling colon.
        raise ValueError(
            f"unexpected error occurred while reading {file_path}: {e}"
        ) from e


def _extract_SHA(SHAstr: str):
    """
    Search ``SHAstr`` for a ``sha256:<value>`` digest.

    Returns the ``re.Match`` for the first occurrence, with the text following
    ``sha256:`` (up to the end of the line) available as ``group(1)``, or
    ``None`` when the string contains no ``sha256:`` marker.
    """
    digest_pattern = "sha256:(.*)"
    found = re.search(digest_pattern, SHAstr)
    return found

0 comments on commit c9b7137

Please sign in to comment.