Skip to content

Commit

Permalink
add a namespace package for turbine_models
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Nov 29, 2023
1 parent cd063df commit 2ee4a0f
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 11 deletions.
47 changes: 47 additions & 0 deletions python/turbine_models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# LLAMA 2 Inference

This example requires some extra dependencies. Here's an easy way to get them installed and the example running on a fresh server.

Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens

```bash
#!/bin/bash
# Sets up SHARK-Turbine on a fresh machine, logs in to Hugging Face and
# exports the stateless LLaMA example.


# if you don't insert it, you will be prompted to log in later;
# you may need to rerun this script after logging in
YOUR_HF_TOKEN="insert token for headless"

# clone and install dependencies
sudo apt install -y git
git clone https://github.com/nod-ai/SHARK-Turbine.git
# Stop here if the clone failed, so the pip installs below never run in the
# wrong directory.
cd SHARK-Turbine || exit 1
pip install -r requirements.txt
pip install -r turbine-models-requirements.txt

# do an editable install from the cloned SHARK-Turbine
pip install --editable .

# Log in with Hugging Face CLI if token setup is required
if [[ $YOUR_HF_TOKEN == hf_* ]]; then
    # Only a value starting with "hf_" is used; the placeholder above is not.
    huggingface-cli login --token "$YOUR_HF_TOKEN"
    echo "Logged in with YOUR_HF_TOKEN."
elif [ -f ~/.cache/huggingface/token ]; then
    # Read token from the file
    TOKEN_CONTENT=$(cat ~/.cache/huggingface/token)

    # Check if the token starts with "hf_"
    if [[ $TOKEN_CONTENT == hf_* ]]; then
        echo "Already logged in with a Hugging Face token."
    else
        echo "Token in file does not start with 'hf_'. Please log into huggingface to download models."
        huggingface-cli login
    fi
else
    echo "Please log into huggingface to download models."
    huggingface-cli login
fi

# Run the Python export script (forward slashes: backslashes are escape
# characters in bash and would mangle the path)
python python/turbine_models/custom_models/stateless_llama.py --compile_to=torch --external_weights=safetensors --external_weight_file=llama_f32.safetensors
```
File renamed without changes.
8 changes: 4 additions & 4 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):

def export_transformer_model(
hf_model_name,
hf_auth_token,
compile_to,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
external_weight_file=None,
quantization=None,
Expand All @@ -85,7 +85,7 @@ def export_transformer_model(
mod = AutoModelForCausalLM.from_pretrained(
hf_model_name,
torch_dtype=torch.float,
use_auth_token=hf_auth_token,
token=hf_auth_token,
)
dtype = torch.float32
if precision == "f16":
Expand All @@ -94,7 +94,7 @@ def export_transformer_model(
tokenizer = AutoTokenizer.from_pretrained(
hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
token=hf_auth_token,
)
# TODO: generate these values instead of magic numbers
HEADS = 32
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions python/turbine_models/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
72 changes: 72 additions & 0 deletions python/turbine_models/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
import os
from pathlib import Path

from setuptools import find_namespace_packages, setup


#### TURBINE MODELS SETUP ####


# Symlink-resolved directory that contains this setup.py.
TURBINE_MODELS_DIR = os.path.realpath(os.path.dirname(__file__))
# Two directory levels above the directory of this file.
TURBINE_ROOT_DIR = Path(TURBINE_MODELS_DIR).parent.parent
print(TURBINE_ROOT_DIR)
VERSION_INFO_FILE = os.path.join(TURBINE_MODELS_DIR, "version_info.json")


# Read the README that sits next to this file. The encoding is pinned to
# UTF-8 so the text is decoded the same way regardless of the platform's
# locale default.
with open(
    os.path.join(TURBINE_MODELS_DIR, "README.md"),
    "rt",
    encoding="utf-8",
) as f:
    README = f.read()


def load_version_info():
    """Return the parsed JSON content of ``VERSION_INFO_FILE``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale default, so the result is the same on every machine.

    Returns:
        The deserialized JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(VERSION_INFO_FILE, "rt", encoding="utf-8") as f:
        return json.load(f)


version_info = load_version_info()
PACKAGE_VERSION = version_info["package-version"]

# Directory used both as the root package dir and as the search root for
# namespace package discovery.
_PYTHON_SRC_DIR = f"{TURBINE_ROOT_DIR}/python"

_CLASSIFIERS = [
    "Development Status :: 3 - Alpha",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
]

_INSTALL_REQUIRES = [
    "Shark-Turbine",
    # Pinned to an exact upstream commit.
    "brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b",
    "protobuf",
    "sentencepiece",
    "transformers",
]

# NOTE(review): the target below is in the shark_turbine package, not in
# turbine_models; presumably the Shark-Turbine distribution registers it as
# well -- confirm whether it needs to be repeated here.
_ENTRY_POINTS = {
    "torch_dynamo_backends": [
        "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend",
    ],
}

# Only turbine_models and its subpackages are collected into this distribution.
_PACKAGES = find_namespace_packages(
    where=_PYTHON_SRC_DIR,
    include=["turbine_models", "turbine_models.*"],
)

setup(
    name="turbine-models",
    version=f"{PACKAGE_VERSION}",
    author="SHARK Authors",
    author_email="dan@nod.ai",
    description="SHARK Turbine Machine Learning Model Zoo",
    long_description=README,
    long_description_content_type="text/markdown",
    url="https://github.com/nod-ai/SHARK-Turbine",
    license="Apache-2.0",
    classifiers=_CLASSIFIERS,
    package_dir={"": _PYTHON_SRC_DIR},
    packages=_PACKAGES,
    entry_points=_ENTRY_POINTS,
    install_requires=_INSTALL_REQUIRES,
)
Empty file.
10 changes: 5 additions & 5 deletions python/turbine_models/tests/llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
class LLamaTest(unittest.TestCase):
def testExportTransformerModel(self):
llama.export_transformer_model(
"meta-llama/Llama-2-7b-chat-hf",
# TODO: replace with github secret
"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",
# This is a public model, so no auth required
"llSourcell/medllama2_7b",
None,
"torch",
"safetensors",
"llama_f32.safetensors",
"medllama2_f32.safetensors",
None,
"f32",
)
os.remove("llama_f32.safetensors")
os.remove("medllama2_f32.safetensors")


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions python/turbine_models/version_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"package-version": "0.0.1.dev1"
}
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import json
import os
import distutils.command.build
from pathlib import Path
import sys

from setuptools import find_namespace_packages, setup

Expand Down

0 comments on commit 2ee4a0f

Please sign in to comment.