From 6e8eec13e4bb1655606c925a63a13ff07e32f582 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 6 Dec 2023 16:57:12 -0500 Subject: [PATCH] Remove duplicate device flags from stateless_llama (#223) The iree-compiler python API uses the `device` field to specify the target devices. Remove the duplicate flags. --- python/turbine_models/custom_models/stateless_llama.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 7356d3ca9..0d3d749d2 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -242,7 +242,6 @@ def forward(token0: torch.Tensor, *state0_flat): elif device == "vulkan": flags.extend( [ - "--iree-hal-target-backends=vulkan-spirv", "--iree-vulkan-target-triple=" + target_triple, "--iree-stream-resource-max-allocation-size=" + max_alloc, ] @@ -250,7 +249,6 @@ def forward(token0: torch.Tensor, *state0_flat): elif device == "rocm": flags.extend( [ - "--iree-hal-target-backends=rocm", "--iree-rocm-target-chip=" + target_triple, "--iree-rocm-link-bc=true", "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", @@ -262,14 +260,13 @@ def forward(token0: torch.Tensor, *state0_flat): elif device == "cuda": flags.extend( [ - "--iree-hal-target-backends=cuda", "--iree-hal-cuda-llvm-target-arch=" + target_triple, "--iree-vm-bytecode-module-strip-source-map=true", "--iree-vm-target-truncate-unsupported-floats", ] ) else: - print("incorrect device: ", device) + print("Unknown device kind: ", device) import iree.compiler as ireec flatbuffer_blob = ireec.compile_str(