
Commit 2189675

Adds instructions for exporting and benchmarking an example model
1 parent 174e73d commit 2189675

3 files changed: +43 additions, −42 deletions

models/turbine_models/custom_models/torchbench/README.md

Lines changed: 9 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 - pip install torch+rocm packages:
 ```shell
-pip install --pre torch==2.5.0.dev20240801+rocm6.1 torchvision==0.20.0.dev20240801+rocm6.1 torchaudio==2.4.0.dev20240801%2Brocm6.1 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
+pip install torch==2.5.0.dev20240801+rocm6.1 torchvision==0.20.0.dev20240801+rocm6.1 torchaudio==2.4.0.dev20240801+rocm6.1 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
 
 ```
 - Workaround amdsmi error in pre-release pytorch+rocm:
@@ -33,4 +33,12 @@ cd ..
 
 ```shell
 python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/
+```
+
+### Example (hf_Albert)
+
+```shell
+python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/ --model_id=hf_Albert
+
+iree-benchmark-module --module=hf_Albert_32_fp16_gfx942.vmfb --input=@input0.npy --parameters=model=./torchbench_weights/hf_Albert_fp16.irpa --device=hip://0 --device_allocator=caching --function=main --benchmark_repetitions=10
 ```
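
Note: the `input0.npy` consumed by `--input=@input0.npy` is produced by the `np.save` loop this commit adds to export.py (see the export.py diff below). Before benchmarking, it can help to confirm the expected artifacts exist and look plausible; a minimal sanity-check sketch, assuming the hf_Albert export above was run from the current directory:

```python
# Sanity-check the artifacts from the hf_Albert example before benchmarking.
# Assumes export.py was run from the current directory as shown above.
import os

import numpy as np

for path in [
    "hf_Albert_32_fp16_gfx942.vmfb",
    "./torchbench_weights/hf_Albert_fp16.irpa",
    "input0.npy",
]:
    assert os.path.exists(path), f"missing artifact: {path}"

# input0.npy holds the first forward argument dumped by export.py; its
# shape and dtype should match what the compiled module expects.
arr = np.load("input0.npy")
print(arr.shape, arr.dtype)
```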

models/turbine_models/custom_models/torchbench/cmd_opts.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def is_valid_file(arg):
 p.add_argument(
     "--external_weights",
     type=str,
-    default=None,
+    default="irpa",
     choices=["safetensors", "irpa", "gguf", None],
     help="Externalizes model weights from the torch dialect IR and its successors",
 )
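
With the default flipped from None to "irpa", callers that omit --external_weights now get weights externalized to an .irpa parameter archive instead of inlined into the module. A standalone sketch of just this option (the repo's parser defines many more arguments):

```python
# Standalone sketch of the changed option, not the repo's full parser.
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--external_weights",
    type=str,
    default="irpa",
    choices=["safetensors", "irpa", "gguf", None],
    help="Externalizes model weights from the torch dialect IR and its successors",
)

print(p.parse_args([]).external_weights)  # irpa (was None before this commit)
print(p.parse_args(["--external_weights=safetensors"]).external_weights)  # safetensors
```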

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 33 additions & 40 deletions
@@ -48,22 +48,21 @@
         "dim": 32,
         "buffer_prefix": "albert"
     },
-    "hf_Bart": {
-        "dim": 16,
-        "buffer_prefix": "bart"
-    },
-    "hf_Bert": {
-        "dim": 16,
-        "buffer_prefix": "bert"
-    },
-    "hf_GPT2": {
-        "dim": 16,
-        "buffer_prefix": "gpt2"
-    },
-    "hf_T5": {
-        "dim": 4,
-        "buffer_prefix": "t5"
-    },
+    # "hf_Bart": {
+    #     "dim": 16,
+    # },
+    # "hf_Bert": {
+    #     "dim": 16,
+    #     "buffer_prefix": "bert"
+    # },
+    # "hf_GPT2": {
+    #     "dim": 16,
+    #     "buffer_prefix": "gpt2"
+    # },
+    # "hf_T5": {
+    #     "dim": 4,
+    #     "buffer_prefix": "t5"
+    # },
     "mnasnet1_0": {
         "dim": 256,
     },
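
How these entries are consumed is only partly visible in this commit: a later hunk reads buffer_prefix via torchbench_models_dict[model_id]["buffer_prefix"], and the README's hf_Albert_32_fp16_gfx942.vmfb filename suggests dim (32 for hf_Albert) feeds the artifact name. A hypothetical lookup sketch under those assumptions:

```python
# Hypothetical sketch of how an entry appears to be consumed; the exact
# logic lives in parts of export.py outside this diff.
torchbench_models_dict = {
    "hf_Albert": {"dim": 32, "buffer_prefix": "albert"},
    "mnasnet1_0": {"dim": 256},
}

model_id, precision, target = "hf_Albert", "fp16", "gfx942"
settings = torchbench_models_dict[model_id]
print(f"{model_id}_{settings['dim']}_{precision}_{target}.vmfb")

# Not every entry carries a buffer_prefix (e.g. mnasnet1_0), so guard lookups.
prefix = settings.get("buffer_prefix")
```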
@@ -182,30 +181,21 @@ def export_torchbench_model(
 
     _, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)
 
+    for idx, i in enumerate(forward_args.values()):
+        np.save(f"input{idx}", i.clone().detach().cpu())
     if dtype == torch.float16:
         model = model.half()
     model.to("cuda:0")
 
     if not isinstance(forward_args, dict):
         forward_args = [i.type(dtype) for i in forward_args]
-    elif "hf" in model_id:
-        forward_args["head_mask"] = torch.zeros(model.config.num_hidden_layers, device="cuda:0")
 
     mapper = {}
     if (external_weights_dir is not None):
         if not os.path.exists(external_weights_dir):
             os.mkdir(external_weights_dir)
-        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.{external_weights}")
-        if os.path.exists(external_weight_path):
-            print("External weights for this module already exist at {external_weight_path}. Will not overwrite.")
-        utils.save_external_weights(
-            mapper,
-            model,
-            external_weights,
-            external_weight_path,
-        )
-        if weights_only:
-            return external_weight_path
+        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")
+
 
     decomp_list = [torch.ops.aten.reflection_pad2d]
     if decomp_attn == True:
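
The added np.save loop relies on Python dicts preserving insertion order, so input0.npy, input1.npy, ... line up with the order in which get_model_and_inputs built forward_args — which is what the README's --input=@input0.npy flags assume. A standalone sketch of the pattern:

```python
# Standalone sketch of the input-dumping pattern added above.
import numpy as np
import torch

forward_args = {
    "input_ids": torch.ones(32, 128, dtype=torch.int64),
    "attention_mask": torch.ones(32, 128, dtype=torch.int64),
}

# Dicts preserve insertion order, so input0.npy, input1.npy, ... match
# the positional order of the model's forward arguments.
for idx, i in enumerate(forward_args.values()):
    # .clone().detach().cpu() makes a host-side copy that is safe to
    # hand to numpy even when the original tensor lives on the GPU.
    np.save(f"input{idx}", i.clone().detach().cpu())
```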
@@ -225,18 +215,20 @@ def __init__(self, model):
             self.mod = model
 
         def forward(self, inp):
-            return self.mod(**inp, return_dict=False)
-    # In transformers, the position ids buffer is registered as non-persistent,
-    # which makes it fail to globalize in the FX import.
-    # Add them manually to the state dict here.
-
-    prefix = torchbench_models_dict[model_id]["buffer_prefix"]
-    getattr(model, prefix).embeddings.register_buffer(
-        "position_ids",
-        getattr(model, prefix).embeddings.position_ids,
-        persistent=True,
-    )
+            return self.mod(**inp)
+
+    if "Bart" not in model_id:
+        # In some transformers models, the position ids buffer is registered as non-persistent,
+        # which makes it fail to globalize in the FX import.
+        # Add them manually to the state dict here.
 
+        prefix = torchbench_models_dict[model_id]["buffer_prefix"]
+        getattr(model, prefix).embeddings.register_buffer(
+            "position_ids",
+            getattr(model, prefix).embeddings.position_ids,
+            persistent=True,
+        )
+    breakpoint()
     fxb = FxProgramsBuilder(HF_M(model))
     @fxb.export_program(args=(forward_args,))
     def _forward(module: HF_M(model), inputs):
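
The comment in this hunk describes a real transformers quirk: position_ids is registered as a non-persistent buffer, so it never appears in state_dict() and cannot be globalized during FX import. A standalone illustration (not repo code) of why re-registering with persistent=True fixes that:

```python
# Non-persistent buffers are excluded from state_dict(); re-registering
# the same tensor with persistent=True puts it back, which is exactly
# what the buffer workaround above does for transformers embeddings.
import torch


class Embeddings(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "position_ids", torch.arange(128).unsqueeze(0), persistent=False
        )


m = Embeddings()
print("position_ids" in m.state_dict())  # False: non-persistent

m.register_buffer("position_ids", m.position_ids, persistent=True)
print("position_ids" in m.state_dict())  # True: now globalizable
```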
@@ -252,6 +244,7 @@ class CompiledTorchbenchModel(CompiledModule):
 
     if external_weights:
         externalize_module_parameters(model)
+        save_module_parameters(external_weight_path, model)
 
     inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")
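
externalize_module_parameters marks the weights as externally supplied, so the compiled .vmfb references them as named parameters instead of inlining them; the added save_module_parameters call then writes those weights to the .irpa archive that iree-benchmark-module loads via --parameters=model=<...>.irpa. A minimal sketch, assuming both helpers come from shark_turbine.aot (their import is outside this diff):

```python
# Minimal sketch of externalize-then-save; assumes shark_turbine.aot
# provides these helpers (the import is not shown in this commit).
import torch
from shark_turbine.aot import externalize_module_parameters, save_module_parameters

model = torch.nn.Linear(4, 4)

# Mark parameters as externally supplied rather than inlined.
externalize_module_parameters(model)
# Write them to an IREE parameter archive, usable with
# iree-benchmark-module --parameters=model=linear_fp32.irpa
save_module_parameters("linear_fp32.irpa", model)
```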