Skip to content

Commit

Permalink
Remove duplicate device flags from stateless_llama (#223)
Browse files Browse the repository at this point in the history
The iree-compiler python API uses the `device` field to specify the
target devices. Remove the duplicate flags.
  • Loading branch information
qedawkins authored Dec 6, 2023
1 parent 35e2238 commit 6e8eec1
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,13 @@ def forward(token0: torch.Tensor, *state0_flat):
elif device == "vulkan":
flags.extend(
[
"--iree-hal-target-backends=vulkan-spirv",
"--iree-vulkan-target-triple=" + target_triple,
"--iree-stream-resource-max-allocation-size=" + max_alloc,
]
)
elif device == "rocm":
flags.extend(
[
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=" + target_triple,
"--iree-rocm-link-bc=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
Expand All @@ -262,14 +260,13 @@ def forward(token0: torch.Tensor, *state0_flat):
elif device == "cuda":
flags.extend(
[
"--iree-hal-target-backends=cuda",
"--iree-hal-cuda-llvm-target-arch=" + target_triple,
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-vm-target-truncate-unsupported-floats",
]
)
else:
print("incorrect device: ", device)
print("Unknown device kind: ", device)
import iree.compiler as ireec

flatbuffer_blob = ireec.compile_str(
Expand Down

0 comments on commit 6e8eec1

Please sign in to comment.