Add benchmarking minimally, comment out a few more models
eagarvey-amd committed Sep 20, 2024
1 parent e6072e1 commit 7f0d1e8
Showing 3 changed files with 52 additions and 8 deletions.
6 changes: 6 additions & 0 deletions models/turbine_models/custom_models/torchbench/README.md
@@ -1,5 +1,11 @@
# SHARK torchbench exports and benchmarks

## Overview

This directory collects scripts and utilities for running a suite of benchmarked inference tasks, demonstrating functional and performance parity between SHARK/IREE and native torch.compile workflows. It is under active development, and benchmark numbers should not be treated as the best results achievable with the current state of IREE compiler optimizations.

Eventually, we want this process to plug into the upstream torchbench workflow by exposing the IREE methodology shown here as a compile/runtime backend for the torch benchmark classes. For now, it is set up as a way for developers to gather preliminary results and achieve blanket functionality for the models listed in export.py.
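As a rough sketch of what that backend hook could look like (illustrative only, not the current implementation; the IREE lowering step is elided), `torch.compile` accepts a custom backend callable that receives the traced FX graph:

```python
import torch

def iree_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real backend would lower `gm` through turbine/IREE here and return
    # a callable that wraps the compiled artifact.
    return gm.forward  # placeholder: fall back to eager execution

model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend=iree_backend)
output = compiled(torch.randn(4, 16))
```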

### Setup

- pip install torch+rocm packages:
10 changes: 10 additions & 0 deletions models/turbine_models/custom_models/torchbench/cmd_opts.py
@@ -72,6 +72,16 @@ def is_valid_file(arg):
choices=["safetensors", "irpa", "gguf", None],
help="Externalizes model weights from the torch dialect IR and its successors",
)
    p.add_argument(
        "--run_benchmark",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Run the compiled module and record latency after export.",
    )
    p.add_argument(
        "--output_csv",
        type=str,
        default="./benchmark_results.csv",
        help="Path of the CSV file to which benchmark results are appended.",
    )

##############################################################################
# Modeling and Export Options
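A note on the `--run_benchmark` option added above: argparse's `type=bool` is a common pitfall, since `bool()` applied to any non-empty string (including "False") yields `True`. The rewrite uses `argparse.BooleanOptionalAction` (Python 3.9+), which also generates a `--no-run_benchmark` negation flag; a minimal sketch of the resulting behavior:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--run_benchmark", default=True, action=argparse.BooleanOptionalAction)
    print(p.parse_args([]).run_benchmark)                       # True
    print(p.parse_args(["--no-run_benchmark"]).run_benchmark)   # False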
44 changes: 36 additions & 8 deletions models/turbine_models/custom_models/torchbench/export.py
@@ -9,6 +9,7 @@
import gc

from iree.compiler.ir import Context
from iree import runtime as ireert
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
@@ -21,10 +22,12 @@
from safetensors import safe_open
import argparse
from turbine_models.turbine_tank import turbine_tank
from turbine_models.model_runner import vmfbRunner

from pytorch.benchmarks.dynamo.common import parse_args
from pytorch.benchmarks.dynamo.torchbench import TorchBenchmarkRunner, setup_torchbench_cwd

import csv
import time

torchbench_models_dict = {
# "BERT_pytorch": {
# "dim": 128,
@@ -84,7 +87,7 @@
"resnet50": {
"dim": 128,
},
"resnet50_32x4d": {
"resnext50_32x4d": {
"dim": 128,
},
"shufflenet_v2_x1_0": {
@@ -93,9 +96,9 @@
"squeezenet1_1": {
"dim": 512,
},
"timm_nfnet": {
"dim": 256,
},
# "timm_nfnet": {
# "dim": 256,
# },
"timm_efficientnet": {
"dim": 128,
},
@@ -163,9 +166,13 @@ def export_torchbench_model(
model_id,
f"_{static_dim}_{precision}",
)
safe_name = os.path.join("generated", safe_name)
if decomp_attn:
safe_name += "_decomp_attn"

if not os.path.exists("generated"):
os.mkdir("generated")

if input_mlir:
vmfb_path = utils.compile_to_vmfb(
input_mlir,
@@ -179,6 +186,7 @@
)
return vmfb_path, None, None  # match the (vmfb, weights, args) return signature below


_, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)

if dtype == torch.float16:
@@ -188,7 +196,8 @@
if not isinstance(forward_args, dict):
forward_args = [i.type(dtype) for i in forward_args]
for idx, i in enumerate(forward_args):
np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
np.save(
os.path.join("generated", f"{model_id}_input{idx}"), i.clone().detach().cpu())
else:
for idx, i in enumerate(forward_args.values()):
np.save(
    os.path.join("generated", f"{model_id}_input{idx}"), i.clone().detach().cpu())
@@ -199,7 +208,8 @@
if not os.path.exists(external_weights_dir):
os.mkdir(external_weights_dir)
external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")

else:
external_weight_path = None

decomp_list = [torch.ops.aten.reflection_pad2d]
if decomp_attn:
@@ -265,11 +275,26 @@ class CompiledTorchbenchModel(CompiledModule):
return_path=not exit_on_vmfb,
attn_spec=attn_spec,
)
-    return vmfb_path
+    return vmfb_path, external_weight_path, forward_args

def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_path):
    # IREE's HAL driver for AMD GPUs is "hip"; translate e.g. "rocm://0" -> "hip://0".
    if "rocm" in device:
        device = "hip" + device.split("rocm")[-1]
    mod_runner = vmfbRunner(device, vmfb_path, weights_path)
    inputs = [ireert.asdevicearray(mod_runner.config.device, i) for i in example_args]
    # Time a single forward call through the compiled module.
    start = time.time()
    results = mod_runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
    latency = time.time() - start
    # Append one row per model; write the header only for a new or empty file.
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a") as csvfile:
        fieldnames = ["model", "latency"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows([{"model": model_id, "latency": latency}])
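# Methodology note (editorial sketch, not part of this commit): timing a single
# call as above also pays one-time warmup and dispatch costs. A steadier
# measurement, reusing mod_runner and inputs from run_benchmark, warms up first
# and averages several iterations:
#
#     main = mod_runner.ctx.modules.compiled_torchbench_model["main"]
#     for _ in range(3):  # warmup calls, excluded from timing
#         main(*inputs)
#     iters = 10
#     start = time.time()
#     for _ in range(iters):
#         main(*inputs)
#     latency = (time.time() - start) / iters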


def run_main(model_id, args, tb_dir, tb_args):
print(f"exporting {model_id}")
-    mod_str = export_torchbench_model(
+    mod_str, weights_path, example_args = export_torchbench_model(
model_id,
tb_dir,
tb_args,
@@ -293,6 +318,9 @@ def run_main(model_id, args, tb_dir, tb_args):
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
elif args.run_benchmark:
run_benchmark(args.device, mod_str, weights_path, example_args, model_id, args.output_csv)

gc.collect()

if __name__ == "__main__":
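To consume the results afterward, a minimal sketch (assuming the header row that run_benchmark now writes to a fresh CSV):

    import csv

    with open("./benchmark_results.csv") as f:
        for row in csv.DictReader(f):
            print(row["model"], float(row["latency"]))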
