Skip to content

Commit bd5263b

Browse files
committed
Use a CUDAGuard when running Torch models
1 parent 9c3be22 commit bd5263b

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

source/neuropod/backends/torchscript/torch_backend.cc

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
 #include "neuropod/backends/torchscript/type_utils.hh"
 #include "neuropod/internal/tensor_types.hh"
 
+#include <c10/cuda/CUDAGuard.h>
 #include <caffe2/core/macros.h>
 
 #include <iostream>
@@ -291,6 +292,14 @@ std::unique_ptr<NeuropodValueMap> TorchNeuropodBackend::infer_internal(const Neu
 {
     torch::NoGradGuard guard;
 
+    // Make sure we're running on the correct device
+    std::unique_ptr<at::cuda::CUDAGuard> device_guard;
+    const auto model_device = get_torch_device(DeviceType::GPU);
+    if (model_device.is_cuda())
+    {
+        device_guard = stdx::make_unique<at::cuda::CUDAGuard>(model_device);
+    }
+
     // Get inference schema
     const auto &method = model_->get_method("forward");
     const auto &schema = SCHEMA(method);

0 commit comments

Comments (0)