source/neuropod/backends/torchscript (1 file changed, +9 -0 lines)

```diff
 #include "neuropod/backends/torchscript/type_utils.hh"
 #include "neuropod/internal/tensor_types.hh"

+#include <c10/cuda/CUDAGuard.h>
 #include <caffe2/core/macros.h>

 #include <iostream>
@@ -291,6 +292,14 @@ std::unique_ptr<NeuropodValueMap> TorchNeuropodBackend::infer_internal(const Neu
 {
     torch::NoGradGuard guard;

+    // Make sure we're running on the correct device
+    std::unique_ptr<at::cuda::CUDAGuard> device_guard;
+    const auto model_device = get_torch_device(DeviceType::GPU);
+    if (model_device.is_cuda())
+    {
+        device_guard = stdx::make_unique<at::cuda::CUDAGuard>(model_device);
+    }
+
     // Get inference schema
     const auto &method = model_->get_method("forward");
     const auto &schema = SCHEMA(method);
```
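For readers unfamiliar with the guard pattern used in this diff, below is a minimal sketch (not Neuropod code) of how an `at::cuda::CUDAGuard` scopes CUDA work to a particular device. The `run_on_device` function and its toy computation are hypothetical; `get_torch_device` and `stdx::make_unique` from the diff are Neuropod helpers and are not reproduced here.

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

#include <memory>

// Sketch: run some tensor work on a specific device. The previously active CUDA
// device is restored automatically when the guard is destroyed.
torch::Tensor run_on_device(const torch::Tensor &input, const torch::Device &model_device)
{
    // Disable autograd bookkeeping for inference, mirroring the diff above.
    torch::NoGradGuard no_grad;

    // A CUDAGuard can only be constructed with a CUDA device, so it is held in a
    // unique_ptr and only created when the model actually lives on a GPU.
    std::unique_ptr<at::cuda::CUDAGuard> device_guard;
    if (model_device.is_cuda())
    {
        device_guard = std::make_unique<at::cuda::CUDAGuard>(model_device);
    }

    // While device_guard is alive, CUDA work issued here targets model_device.
    return input.to(model_device) * 2;
}
```

The guard is scoped to the inference call so that kernel launches and allocations inside `forward` land on the device the model was loaded onto, rather than whatever device happened to be current on the calling thread.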