Commit cc7657d

Merge branch 'main' into ean-sd-fp16

2 parents 8ba19cc + 6e3adb3

7 files changed: +151 -55 lines changed

.github/workflows/test_sdxl.yml

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+name: SDXL Models Nightly
+
+on:
+  schedule:
+    - cron: '30 6 * * *'
+
+jobs:
+  test-sdxl-models:
+    strategy:
+      matrix:
+        version: [3.11]
+        os: [nodai-amdgpu-w7900-x86-64]
+
+    runs-on: ${{matrix.os}}
+    steps:
+      - name: "Setting up Python"
+        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
+        with:
+          python-version: ${{matrix.version}}
+
+      - name: "Checkout Code"
+        uses: actions/checkout@v2
+        with:
+          ref: ean-sd-fp16
+
+      - name: Sync source deps
+        # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile
+        run: |
+          python -m pip install --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --index-url https://download.pytorch.org/whl/cpu \
+            -r core/pytorch-cpu-requirements.txt \
+            -r core/torchvision-requirements.txt
+          pip install --upgrade -r core/requirements.txt
+          pip install -e core[testing,torch-cpu-nightly]
+          pip install --upgrade -r models/requirements.txt
+          pip install -e models
+
+      - name: Show current free memory
+        run: |
+          free -mh
+
+      - name: Run sdxl tests
+        run: |
+          pip install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+          pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
+          pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
+          pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device rocm --iree_target_triple gfx90a
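For local debugging, a minimal sketch of reproducing the CPU leg of this nightly job outside CI, assuming the repository checkout is the working directory and the dependencies from the "Sync source deps" step are already installed (the Vulkan and ROCm legs differ only in the flag values):

    # Minimal local-reproduction sketch of the CPU leg of the nightly job above.
    # Assumes the deps installed by "Sync source deps" and the repo root as cwd.
    import sys
    import pytest

    exit_code = pytest.main(
        [
            "models/turbine_models/tests/sdxl_test.py",
            "--device", "cpu",
            "--rt_device", "local-task",
            "--iree_target_triple", "x86_64-linux-gnu",
        ]
    )
    sys.exit(exit_code)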

README.md

Lines changed: 7 additions & 7 deletions

@@ -10,8 +10,8 @@ is intended to be a general purpose model compilation and execution tool.
 Turbine provides three primary tools:
 
 * *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment
-ready artifacts. This operates via both a [simple one-shot export API](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/aot/exporter.py)
-for simple models and an underlying [advanced API](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/aot/compiled_module.py) for complicated models
+ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py))
+for simple models and an underlying [advanced API](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/shark_turbine/aot/compiled_module.py) for complicated models
 and accessing the full features of the runtime.
 * *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device
 is available for more native, interactive use within a PyTorch session.
@@ -39,7 +39,7 @@ please reach out to us on the `#turbine` channel of the
 
 ```
 pip install shark-turbine
-# Or editable: pip install -e core
+# Or for editable: see instructions under developers
 ```
 
 The above does install some unecessary cuda/cudnn packages for cpu use. To avoid this you
@@ -62,11 +62,11 @@ compiler, these should be compilable via IREE with `--iree-input-type=torch` for
 end to end execution. Dynamic shape support in torch-mlir is a work in progress,
 and not everything works at head with release binaries at present.
 
-* [AOT MLP With Static Shapes](https://github.com/nod-ai/SHARK-Turbine/blob/main/examples/aot_mlp/mlp_export_simple.py)
-* [AOT MLP with a dynamic batch size](https://github.com/nod-ai/SHARK-Turbine/blob/main/examples/aot_mlp/mlp_export_dynamic.py)
-* [AOT llama2](https://github.com/nod-ai/SHARK-Turbine/blob/main/examples/llama2_inference/llama2.ipynb):
+* [AOT MLP With Static Shapes](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/aot_mlp/mlp_export_simple.py)
+* [AOT MLP with a dynamic batch size](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/aot_mlp/mlp_export_dynamic.py)
+* [AOT llama2](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/llama2_inference/llama2.ipynb):
 Dynamic sequence length custom compiled module with state management internal to the model.
-* [Eager MNIST with `torch.compile`](https://github.com/nod-ai/SHARK-Turbine/blob/main/examples/eager_mlp/mlp_eager_simple.py)
+* [Eager MNIST with `torch.compile`](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/eager_mlp/mlp_eager_simple.py)
 
 ## Developers

core/iree-requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -1,2 +1,2 @@
-iree-compiler>=20240306.822
-iree-runtime>=20240306.822
+iree-compiler==20240311.828
+iree-runtime==20240311.828

models/turbine_models/custom_models/llm_runner.py

Lines changed: 14 additions & 20 deletions

@@ -168,12 +168,14 @@ def run_llm(
     streaming_llm=False,
     chat_mode=False,
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+    tokenizer=None,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
+    if tokenizer == None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_name,
+            use_fast=False,
+            token=hf_auth_token,
+        )
     llm = SharkLLM(
         device=device,
         vmfb_path=vmfb_path,
@@ -204,43 +206,35 @@ def run_torch_llm(
     prompt,
     streaming_llm=False,
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+    model=None,
+    tokenizer=None,
 ):
-    from turbine_models.model_builder import HFTransformerBuilder
-    from transformers import AutoModelForCausalLM
-
-    model_builder = HFTransformerBuilder(
-        example_input=None,
-        hf_id=hf_model_name,
-        auto_model=AutoModelForCausalLM,
-        hf_auth_token=hf_auth_token,
-        auto_tokenizer=AutoTokenizer,
-    )
     if streaming_llm is True:
-        enable_llama_pos_shift_attention(model_builder.model)
+        enable_llama_pos_shift_attention(model)
 
     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)
 
     prompt = append_user_prompt(chat_sys_prompt, prompt)
-    initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
+    initial_input = tokenizer(prompt, return_tensors="pt")
     example_input_id = initial_input.input_ids
 
-    model_results = model_builder.model.forward(example_input_id)
+    model_results = model.forward(example_input_id)
     model_token = get_token_from_logits(model_results.logits)
 
     pkv = model_results.past_key_values
 
     torch_results = []
    torch_results.append(int(model_token))
    while model_token != 2:
-        model_results = model_builder.model.forward(
+        model_results = model.forward(
            torch.unsqueeze(model_token, 0), past_key_values=pkv
        )
        model_token = get_token_from_logits(model_results.logits)
        pkv = model_results.past_key_values
        torch_results.append(int(model_token[0]))
 
-    return model_builder.tokenizer.decode(torch_results)
+    return tokenizer.decode(torch_results)
 
 
 if __name__ == "__main__":
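The net effect is that both entry points accept pre-built Hugging Face objects instead of reconstructing them on every call. A hedged caller sketch; the import path and prompt text are illustrative, and only the run_torch_llm keywords visible in this commit's test changes are relied on:

    # Sketch: build the tokenizer/model once and inject them, so run_torch_llm no
    # longer constructs an HFTransformerBuilder (or re-downloads weights) per call.
    # Import path and prompt text are assumptions for illustration.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from turbine_models.custom_models import llm_runner

    hf_model_name = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float)

    torch_str = llm_runner.run_torch_llm(
        hf_model_name,
        None,                # hf_auth_token
        "hi what are you?",  # prompt
        model=model,
        tokenizer=tokenizer,
    )
    print(torch_str)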

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 16 additions & 12 deletions

@@ -121,18 +121,21 @@ def export_transformer_model(
     streaming_llm=False,
     vmfb_path=None,
     upload_ir=False,
+    mod=None,
+    tokenizer=None,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
-
-    mod = AutoModelForCausalLM.from_pretrained(
-        hf_model_name,
-        torch_dtype=torch.float,
-        token=hf_auth_token,
-    )
+    if tokenizer == None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_name,
+            use_fast=False,
+            token=hf_auth_token,
+        )
+    if mod == None:
+        mod = AutoModelForCausalLM.from_pretrained(
+            hf_model_name,
+            torch_dtype=torch.float,
+            token=hf_auth_token,
+        )
     schema_json = generate_schema(mod.config.num_hidden_layers)
     state_schema = pytree.treespec_loads(schema_json)
     if streaming_llm:
@@ -165,7 +168,8 @@ def export_transformer_model(
         for name in mod_params:
             mapper["params." + name] = name
         if external_weight_file:
-            safetensors.torch.save_file(mod_params, external_weight_file)
+            if os.path.exists(external_weight_file) == False:
+                safetensors.torch.save_file(mod_params, external_weight_file)
 
     elif external_weights == "gguf":
         tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
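The second hunk makes the external weight export idempotent: the safetensors file is written only when it is not already on disk, so repeated exports in one session reuse the first file. A standalone sketch of the same guard with toy tensors; the file name and contents are placeholders, not values from this commit:

    # Standalone illustration of the save-once guard added above: skip the write
    # when the external weight file already exists. File name and tensors are
    # placeholders for illustration only.
    import os

    import safetensors.torch
    import torch

    external_weight_file = "example_params.safetensors"
    mod_params = {"layer.weight": torch.zeros(2, 2)}

    if not os.path.exists(external_weight_file):
        safetensors.torch.save_file(mod_params, external_weight_file)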

models/turbine_models/model_builder.py

Lines changed: 2 additions & 1 deletion

@@ -30,6 +30,7 @@ def __init__(
         model=None,
         model_type: str = None,
         compile_to_vmfb: bool = None,
+        tokenizer=None,
     ) -> None:
         self.example_input = example_input
         self.hf_id = hf_id
@@ -38,7 +39,7 @@ def __init__(
         self.auto_config = auto_config
         self.hf_auth_token = hf_auth_token
         self.model = model
-        self.tokenizer = None
+        self.tokenizer = tokenizer
         self.upload_ir = upload_ir
         self.model_type = model_type
         self.compile_to_vmfb = compile_to_vmfb
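With the constructor storing the caller's tokenizer, an HFTransformerBuilder can now be seeded with an object that already exists. A hedged construction sketch; the keyword names are taken from this diff and from the HFTransformerBuilder call removed from llm_runner.py, and whether __init__ eagerly builds the model is not shown in this commit:

    # Sketch: seed the builder with an existing tokenizer rather than relying on
    # the old hard-coded `self.tokenizer = None`. Keywords mirror those visible in
    # this commit; any other builder behavior is assumed unchanged.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from turbine_models.model_builder import HFTransformerBuilder

    hf_id = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"
    tokenizer = AutoTokenizer.from_pretrained(hf_id, use_fast=False)

    builder = HFTransformerBuilder(
        example_input=None,
        hf_id=hf_id,
        auto_model=AutoModelForCausalLM,
        hf_auth_token=None,
        tokenizer=tokenizer,  # new in this commit: stored as self.tokenizer
    )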

models/turbine_models/tests/stateless_llama_test.py

Lines changed: 60 additions & 13 deletions

@@ -9,6 +9,11 @@
 import os
 import unittest
 import difflib
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import torch
+from accelerate import init_empty_weights
+from transformers.modeling_utils import load_sharded_checkpoint
+import tempfile
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from shark_turbine.aot import *
@@ -18,18 +23,6 @@
     gen_external_params,
 )
 
-quantization = "unquantized"
-precision = "f32"
-gen_external_params(
-    hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-    quantization=quantization,
-    hf_auth_token=None,
-    precision=precision,
-)
-DEFAULT_PROMPT = """<s>[INST] <<SYS>>
-Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
-"""
-
 
 def check_output_string(reference, output):
     # Calculate and print diff
@@ -43,7 +36,45 @@ def check_output_string(reference, output):
     assert reference == output, "".join(diff)
 
 
+quantization = "unquantized"
+precision = "f32"
+
+DEFAULT_PROMPT = """<s>[INST] <<SYS>>
+Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
+"""
+
+
 class StatelessLlamaChecks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        gen_external_params(
+            hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            quantization=quantization,
+            hf_auth_token=None,
+            precision=precision,
+        )
+
+        cls.tokenizer = AutoTokenizer.from_pretrained(
+            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            use_fast=False,
+        )
+
+        # The model is first created on the Meta device (with empty weights) and the state dict
+        # is then loaded inside it (shard by shard in the case of a sharded checkpoint).
+        # This avoids using twice the size of model with creating whole model with random weights,
+        # then loading pretrained weights.
+        cls.mod = AutoModelForCausalLM.from_pretrained(
+            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            torch_dtype=torch.float,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.tokenizer = None
+        cls.mod = None
+
     def test_vmfb_comparison(self):
         """
         Test that the vmfb model produces the same output as the torch model
@@ -66,6 +97,8 @@ def test_vmfb_comparison(self):
             device="llvm-cpu",
             target_triple="host",
             upload_ir=upload_ir_var == "upload",
+            mod=self.mod,
+            tokenizer=self.tokenizer,
         )
 
         torch_str_cache_path = (
@@ -77,7 +110,11 @@ def test_vmfb_comparison(self):
                 torch_str = f.read()
         else:
             torch_str = llm_runner.run_torch_llm(
-                "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT
+                "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+                None,
+                self.DEFAULT_PROMPT,
+                model=self.mod,
+                tokenizer=self.tokenizer,
             )
 
             with open(torch_str_cache_path, "w") as f:
@@ -90,6 +127,7 @@ def test_vmfb_comparison(self):
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
+            tokenizer=self.tokenizer,
         )
         check_output_string(torch_str, turbine_str)
 
@@ -109,6 +147,8 @@ def test_streaming_vmfb_comparison(self):
             target_triple="host",
             streaming_llm=True,
             vmfb_path="streaming_llama.vmfb",
+            mod=self.mod,
+            tokenizer=self.tokenizer,
         )
 
         torch_str_cache_path = (
@@ -124,6 +164,8 @@ def test_streaming_vmfb_comparison(self):
                 None,
                 DEFAULT_PROMPT,
                 streaming_llm=True,
+                model=self.mod,
+                tokenizer=self.tokenizer,
             )
 
             with open(torch_str_cache_path, "w") as f:
@@ -137,6 +179,7 @@ def test_streaming_vmfb_comparison(self):
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
             streaming_llm=True,
+            tokenizer=self.tokenizer,
         )
         check_output_string(torch_str, turbine_str)
 
@@ -145,12 +188,16 @@ def test_rerotated_torch_comparison(self):
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             DEFAULT_PROMPT,
+            model=self.mod,
+            tokenizer=self.tokenizer,
         )
         rotated_torch_str = llm_runner.run_torch_llm(
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             DEFAULT_PROMPT,
             streaming_llm=True,
+            model=self.mod,
+            tokenizer=self.tokenizer,
        )
        check_output_string(torch_str, rotated_torch_str)
 
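The test refactor above moves the expensive model and tokenizer construction out of import time and into class-level fixtures shared by every test method. A minimal, self-contained illustration of that unittest pattern, using toy objects only (no Hugging Face download involved):

    # Toy illustration of the setUpClass/tearDownClass fixture pattern used above:
    # build expensive objects once per class, share them across tests, then release.
    import unittest


    class SharedFixtureExample(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            cls.expensive_resource = {"loaded": True}  # stand-in for model/tokenizer

        @classmethod
        def tearDownClass(cls):
            cls.expensive_resource = None

        def test_uses_shared_resource(self):
            self.assertTrue(self.expensive_resource["loaded"])


    if __name__ == "__main__":
        unittest.main()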
