Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Xida Ren committed Dec 6, 2023
1 parent dd1c771 commit 7999ddb
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/test_llama_end_to_end.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Test Llama End to End

on:
workflow_dispatch:
pull_request:
push:
branches:
- refactor-stateless_llama.py

jobs:
test:
strategy:
matrix:
version: [3.11]
os: [xida-cpu-0]

runs-on: ${{matrix.os}}
steps:
- name: "Setting up Python"
uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
uses: actions/checkout@v2

- name: Sync source deps
run: |
python -m pip install --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r pytorch-cpu-requirements.txt \
-r torchvision-requirements.txt
pip install --upgrade -r requirements.txt
pip install -e .[testing]
- name: Run tests
run: |
pytest tests/custom_models/stateless_llama
47 changes: 47 additions & 0 deletions tests/custom_models/stateless_llama/vmfb_comparison_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import subprocess
import os
import glob

def delete_files(pattern):
# Delete files matching the given pattern
for file in glob.glob(pattern):
os.remove(file)

@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown():
# Setup: Delete existing files
delete_files('*.safetensors')
delete_files('*.ggml')
delete_files('*.vmfb')
delete_files('*.mlir')
# Yield to the test execution
yield
# Teardown: Delete files after tests
delete_files('*.safetensors')
delete_files('*.ggml')
delete_files('*.vmfb')
delete_files('*.mlir')

@pytest.fixture
def setup_environment():
# Change to the SHARK-Turbine directory
os.chdir(os.path.expanduser('~/SHARK-Turbine'))
# Ensure that any failure in the commands causes the test to stop
subprocess.run('set -e', shell=True, check=True)

def run_command(command):
# Run the command and check for successful execution
subprocess.run(command, shell=True, check=True)

def test_generate_vmfb(setup_environment):
command = 'python python/turbine_models/custom_models/stateless_llama_export_old.py --compile_to=vmfb --hf_model_name="llSourcell/medllama2_7b" --precision=f16 --quantization=int4 --external_weights=safetensors'
run_command(command)

def test_generate_quantized_safetensors(setup_environment):
command = 'python python/turbine_models/gen_external_params/gen_external_params.py --hf_model_name="llSourcell/medllama2_7b" --precision=f16 --quantization=int4'
run_command(command)

def test_run_vmfb_vs_torch_model(setup_environment):
command = 'python python/turbine_models/custom_models/stateless_llama.py --run_vmfb --hf_model_name="llSourcell/medllama2_7b" --vmfb_path=medllama2_7b.vmfb --external_weight_file=medllama2_7b_f16_int4.safetensors'
run_command(command)

0 comments on commit 7999ddb

Please sign in to comment.