diff --git a/awq_ext/vllm/moe_alig_block.cu b/awq_ext/vllm/moe_alig_block.cu index 63578e5..811cf63 100644 --- a/awq_ext/vllm/moe_alig_block.cu +++ b/awq_ext/vllm/moe_alig_block.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -75,6 +76,10 @@ void moe_alig_block_size( torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { + const at::cuda::OptionalCUDAGuard device_guard_topk_ids(device_of(topk_ids)); + const at::cuda::OptionalCUDAGuard device_guard_sorted(device_of(sorted_token_ids)); + const at::cuda::OptionalCUDAGuard device_guard_experts(device_of(experts_ids)); + const at::cuda::OptionalCUDAGuard device_guard_num_tokens(device_of(num_tokens_post_pad)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); assert(num_experts <= NUM_MAX_EXPERTS); VLLM_DISPATCH_INTEGRAL_TYPES(