Add baseline comparisons
eagarvey-amd committed Oct 2, 2024
1 parent e3ced56 commit 896c221
Showing 2 changed files with 64 additions and 32 deletions.
76 changes: 55 additions & 21 deletions models/turbine_models/custom_models/torchbench/export.py
@@ -36,19 +36,19 @@
 torchbench_models_dict = {
     # "BERT_pytorch": {
     #     "dim": 128,
-    # },
+    # }, # Dynamo Export Issue
     # "Background_Matting": {
     #     "dim": 16,
-    # },
-    # "LearningToPaint": {
-    #     "dim": 1024,
-    # },
+    # }, # Transpose Bubbling Pattern Failed
+    "LearningToPaint": {
+        "dim": 1024,
+    },
     "alexnet": {
         "dim": 1024,
     },
-    # "densenet121": {
-    #     "dim": 64,
-    # },
+    "densenet121": {
+        "dim": 64,
+    },
     # "hf_Albert": {"dim": 32, "buffer_prefix": "albert"},
     # "hf_Bart": {
     #     "dim": 16,
@@ -131,17 +131,28 @@ def get_runner(tb_dir, tb_args):
     return runner


-def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args):
+def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=False):
     runner = get_runner(tb_dir, tb_args)
-    return runner.load_model(
+    _, model_name, model, forward_args, _ = runner.load_model(
         "cuda:0",
         model_id,
         batch_size=batch_size,
     )
+    match get_baseline:
+        case True:
+            start_t = time.time()
+            res = runner.forward_pass(model, forward_args, collect_outputs=True)
+            baseline = time.time() - start_t
+            return model_name, model, forward_args, res, baseline
+        case False:
+            return model_name, model, forward_args


+'''
+Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking.
+'''
 @torch.no_grad()
-def export_torchbench_model(
+def benchmark_torchbench_model(
     model_id,
     tb_dir,
     tb_args,
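
get_model_and_inputs now has two return shapes, selected by the new get_baseline flag, so callers must unpack accordingly. A sketch (model and batch size are illustrative):

# With get_baseline=True the eager model runs once and is timed, yielding a
# golden output and a baseline latency in seconds; with the default False,
# only the model and its example inputs come back.
name, model, fwd_args, golden, baseline = get_model_and_inputs(
    "alexnet", 64, tb_dir, tb_args, get_baseline=True
)
name, model, fwd_args = get_model_and_inputs("alexnet", 64, tb_dir, tb_args)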
@@ -159,6 +170,7 @@ def export_torchbench_model(
     input_mlir=None,
     weights_only=False,
     upload_ir=False,
+    compare_vs_eager=False,
 ):
     static_dim = torchbench_models_dict[model_id]["dim"]
     dtype = torch.float16 if precision == "fp16" else torch.float32
@@ -187,9 +199,16 @@
         )
         return vmfb_path

-    _, model_name, model, forward_args, _ = get_model_and_inputs(
-        model_id, batch_size, tb_dir, tb_args
-    )
+    if compare_vs_eager:
+        model_name, model, forward_args, golden, baseline = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args, get_baseline=True
+        )
+    else:
+        model_name, model, forward_args = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args
+        )
+        golden = None
+        baseline = None

     if dtype == torch.float16:
         model = model.half()
@@ -275,7 +294,8 @@ class CompiledTorchbenchModel(CompiledModule):
     inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")

     module = CompiledModule.get_mlir_module(inst)
-
+    model.to("cpu")
+    del model
     if compile_to != "vmfb":
         return str(module)
     else:
@@ -288,17 +308,21 @@ class CompiledTorchbenchModel(CompiledModule):
             return_path=not exit_on_vmfb,
             attn_spec=attn_spec,
         )
-    return vmfb_path, external_weight_path, forward_args
+    return vmfb_path, external_weight_path, forward_args, golden, baseline


 def _run_iter(runner, inputs):
     start = time.time()
     res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
     return res, time.time() - start

+def do_compare(shark_results, shark_latency, golden_results, golden_latency):
+    numerics_pass_fail = np.allclose(shark_results.to_host(), golden_results.clone().cpu().numpy(), rtol=1e-4, atol=1e-4)
+    speedup = golden_latency / shark_latency
+    return speedup, numerics_pass_fail
+
 def run_benchmark(
-    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters
+    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters, golden=None, baseline=None,
 ):
     if "rocm" in device:
         device = "hip" + device.split("rocm")[-1]
@@ -311,16 +335,23 @@
     avg_latency = sum(iter_latencies) / len(iter_latencies)
     it_per_sec = 1 / avg_latency

+    if golden is not None and baseline is not None:
+        speedup, numerics_pass_fail = do_compare(results, avg_latency, golden, baseline)
+    else:
+        speedup, numerics_pass_fail = ("N/A", "N/A")
+
     needs_header = True
     if os.path.exists(csv_path):
         needs_header = False
     with open(csv_path, "a") as csvfile:
-        fieldnames = ["model", "avg_latency", "avg_iter_per_sec"]
+        fieldnames = ["model", "avg_latency", "avg_iter_per_sec", "speedup_over_eager", "numerics"]
         data = [
             {
                 "model": model_id,
                 "avg_latency": avg_latency,
                 "avg_iter_per_sec": it_per_sec,
+                "speedup_over_eager": speedup,
+                "numerics": numerics_pass_fail,
             }
         ]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
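
With the two new columns, a populated results file would look roughly like this (values invented for illustration; runs benchmarked without an eager baseline carry N/A):

model,avg_latency,avg_iter_per_sec,speedup_over_eager,numerics
alexnet,0.0051,196.08,3.92,True
densenet121,0.0183,54.64,N/A,N/A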
@@ -346,7 +377,7 @@ def torch_to_iree(iree_runner, example_args):

 def run_main(model_id, args, tb_dir, tb_args):
     print(f"exporting {model_id}")
-    mod_str, weights_path, example_args = export_torchbench_model(
+    mod_str, weights_path, example_args, golden, baseline = benchmark_torchbench_model(
         model_id,
         tb_dir,
         tb_args,
@@ -361,6 +392,7 @@ def run_main(model_id, args, tb_dir, tb_args):
         decomp_attn=args.decomp_attn,
         attn_spec=args.attn_spec,
         input_mlir=args.input_mlir,
+        compare_vs_eager=args.compare_vs_torch,
     )
     if args.compile_to in ["torch", "mlir"]:
         safe_name = utils.create_safe_name(
@@ -379,6 +411,8 @@
             model_id,
             args.output_csv,
             args.num_iters,
+            golden,
+            baseline,
         )

     gc.collect()
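
The end-to-end wiring is thus: run_main threads golden and baseline from benchmark_torchbench_model into run_benchmark, which folds them into the CSV row. A presumed call path (the args fields are inferred from the reads above; the argparse definitions sit outside this diff):

# Assumption: args is the script's argparse namespace; only fields read in
# the hunks above are shown.
args.compare_vs_torch = True  # enables the eager baseline pass
args.num_iters = 10
run_main("alexnet", args, tb_dir, tb_args)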
20 changes: 9 additions & 11 deletions models/turbine_models/custom_models/torchbench/utils.py
@@ -12,22 +12,20 @@
 MI_flags = {
     "all": [
         "--iree-global-opt-propagate-transposes=true",
-        "--iree-opt-const-eval=false",
-        "--iree-llvmgpu-enable-prefetch=true",
-        "--iree-execution-model=async-external",
-        "--iree-dispatch-creation-enable-aggressive-fusion",
+        "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
+        "--iree-dispatch-creation-enable-aggressive-fusion=true",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-opt-outer-dim-concat=true",
-        "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-vm-target-truncate-unsupported-floats",
+        "--iree-llvmgpu-enable-prefetch=true",
         "--iree-opt-data-tiling=false",
         "--iree-codegen-gpu-native-math-precision=true",
+        "--iree-codegen-llvmgpu-use-vector-distribution",
         "--iree-hip-waves-per-eu=2",
+        "--iree-execution-model=async-external",
     ],
     "preprocess_default": [
         "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-pad-to-intrinsics)",
     ],
-    "preprocess_transpose": [
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)",
-        "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
-    ],
 }
 GFX11_flags = {
     "all": [