diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py
index 7356d3ca9..0d3d749d2 100644
--- a/python/turbine_models/custom_models/stateless_llama.py
+++ b/python/turbine_models/custom_models/stateless_llama.py
@@ -242,7 +242,6 @@ def forward(token0: torch.Tensor, *state0_flat):
         elif device == "vulkan":
             flags.extend(
                 [
-                    "--iree-hal-target-backends=vulkan-spirv",
                     "--iree-vulkan-target-triple=" + target_triple,
                     "--iree-stream-resource-max-allocation-size=" + max_alloc,
                 ]
@@ -250,7 +249,6 @@ def forward(token0: torch.Tensor, *state0_flat):
         elif device == "rocm":
             flags.extend(
                 [
-                    "--iree-hal-target-backends=rocm",
                     "--iree-rocm-target-chip=" + target_triple,
                     "--iree-rocm-link-bc=true",
                     "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
@@ -262,14 +260,13 @@ def forward(token0: torch.Tensor, *state0_flat):
         elif device == "cuda":
             flags.extend(
                 [
-                    "--iree-hal-target-backends=cuda",
                     "--iree-hal-cuda-llvm-target-arch=" + target_triple,
                     "--iree-vm-bytecode-module-strip-source-map=true",
                     "--iree-vm-target-truncate-unsupported-floats",
                 ]
             )
         else:
-            print("incorrect device: ", device)
+            print("Unknown device kind: ", device)
         import iree.compiler as ireec

         flatbuffer_blob = ireec.compile_str(
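
Note: the diff removes the per-device "--iree-hal-target-backends=..." flag from each branch, and the hunk ends before the arguments of ireec.compile_str() are visible, so the diff itself does not show how the backend is now selected. The sketch below is a minimal, illustrative assumption of how the backend name could instead be passed through the target_backends keyword of iree.compiler.compile_str (a real parameter of that API), with the remaining per-device flags forwarded via extra_args. BACKEND_MAP and compile_module are hypothetical names introduced here for illustration, not code from this PR.

# Illustrative sketch only; not part of the diff above.
import iree.compiler as ireec

# Hypothetical mapping from the script's device argument to an IREE HAL
# target backend name.
BACKEND_MAP = {
    "cpu": "llvm-cpu",
    "vulkan": "vulkan-spirv",
    "rocm": "rocm",
    "cuda": "cuda",
}

def compile_module(module_str: str, device: str, flags: list[str]) -> bytes:
    """Compile an MLIR module string to an IREE VM flatbuffer.

    `flags` carries the remaining per-device options (target triple, chip,
    allocation limits, ...); the backend itself is selected with the
    `target_backends` keyword instead of `--iree-hal-target-backends=...`.
    """
    if device not in BACKEND_MAP:
        raise ValueError(f"Unknown device kind: {device}")
    return ireec.compile_str(
        module_str,
        target_backends=[BACKEND_MAP[device]],
        extra_args=flags,
    )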