Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/release_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Builds and pushes the selected model to Replicate on every push to main.
name: release_model

env:
  # Full Replicate destination (user/model); used by `cog push` below.
  model_name: "replicate/llama-7b"

on:
  push:
    branches:
      - main

jobs:
  release:
    # GPU runner is required because `cog run` builds/executes the model image.
    runs-on: [self-hosted, gpu-a100]
    strategy:
      matrix:
        model: ["llama-7b"]
    steps:
      - uses: actions/checkout@v3

      - name: Install Cog
        run: |
          sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)"
          sudo chmod +x /usr/local/bin/cog

      - name: Select model
        # Renders cog/config files for the matrix model before pushing.
        run: |
          cog run python select_model.py --model_name ${{ matrix.model }} --model_path ${{ secrets.LLAMA_7B_PATH }}

      - name: Log in to Replicate
        env:
          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
        run: |
          echo "$REPLICATE_API_TOKEN" | cog login --token-stdin

      - name: Push to Replicate
        run: |
          cog push r8.im/${{ env.model_name }}
33 changes: 33 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Runs the pytest suite against the selected model on every PR to main.
name: run_tests

on:
  pull_request:
    branches:
      - main

jobs:
  test:
    # GPU runner is required because `cog run` executes the model container.
    runs-on: [self-hosted, gpu-a100]
    strategy:
      matrix:
        model: ["llama-7b"]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          # Quoted to avoid YAML float coercion (e.g. 3.10 -> 3.1).
          python-version: "3.8"

      - name: Install Python dependencies
        run: pip install pytest

      - name: Install Cog
        run: |
          sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)"
          sudo chmod +x /usr/local/bin/cog

      - name: Select model
        # Renders cog/config files for the matrix model before testing.
        run: |
          cog run python select_model.py --model_name ${{ matrix.model }} --model_path ${{ secrets.LLAMA_7B_PATH }}

      - name: Run tests
        run: pytest -vv test/
12 changes: 7 additions & 5 deletions select_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# Per-model template parameters. "model_path" is a placeholder ("SET_ME")
# that write_configs() overwrites with the --model_path CLI argument.
CONFIGS = {
    "llama-7b": {
        "cog_yaml_parameters": {"predictor": "predict.py:Predictor"},
        "config_py_parameters": {"model_path": "SET_ME", "config_location": "llama_weights/llama-7b"},
    },
    "llama-13b": {
        "cog_yaml_parameters": {"predictor": "predict.py:Predictor"},
        "config_py_parameters": {"model_path": "SET_ME", "config_location": "llama_weights/llama-13b"},
    },
}

Expand All @@ -36,15 +36,17 @@ def write_one_config(template_fpath: str, fname_out: str, config: dict):
os.chmod(fname_out, new_permissions)


def write_configs(model_name, model_path):
    """Render config.py from its template for the given model.

    Args:
        model_name: key into CONFIGS (e.g. "llama-7b").
        model_path: weights location (local or cloud storage); injected into
            the rendered config as ``model_path``.

    Raises:
        KeyError: if model_name is not a key of CONFIGS.
    """
    master_config = CONFIGS[model_name]
    # cog.yaml rendering is currently disabled; left for reference.
    #write_one_config("templates/cog_template.yaml", "cog.yaml", master_config['cog_yaml_parameters'])
    # Copy before overriding so the shared CONFIGS dict is not mutated.
    cfg = dict(master_config["config_py_parameters"])
    cfg['model_path'] = model_path
    write_one_config("templates/config_template.py", "config.py", cfg)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--model_name", default="llama-7b", help="name of the llama model you want to configure cog for")
    parser.add_argument("--model_path", default="llama-7b", help="path to llama model, in cloud storage or locally")
    args = parser.parse_args()

    # Bug fix: model_path was parsed but never forwarded, so the two-argument
    # write_configs(model_name, model_path) raised TypeError at runtime.
    write_configs(args.model_name, args.model_path)
4 changes: 2 additions & 2 deletions templates/config_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from subclass import YieldingLlama

# Template placeholders ({{...}}) are substituted by select_model.py.
DEFAULT_MODEL_NAME = "{{model_path}}"  # path from which we pull weights when there's no COG_WEIGHTS environment variable
TOKENIZER_NAME = "llama_weights/tokenizer"
CONFIG_LOCATION = "{{config_location}}"

Expand Down Expand Up @@ -60,7 +60,7 @@ def load_tensorizer(
weights = str(weights)
local_weights = "/src/llama_tensors"
print("Deserializing weights...")
if 'http' in weights:
if 'http' in weights or 'gs' in weights:
pull_gcp_file(weights, local_weights)
else:
local_weights = weights
Expand Down