Skip to content

Commit 9e04c49

Browse files
committed
Revert "Introduction of the raft::device_resources_snmg type (#2487)"
This reverts commit fb6bfe6.
1 parent 8299f17 commit 9e04c49

File tree

5 files changed

+223
-236
lines changed

5 files changed

+223
-236
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <raft/core/device_resources.hpp>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <nccl.h>

#include <memory>
#include <numeric>
#include <string>
#include <vector>
23+
24+
/**
 * @brief Error checking macro for NCCL runtime API functions.
 *
 * Invokes a NCCL runtime API function call; if the call does not return ncclSuccess,
 * throws a raft::logic_error detailing the NCCL error that occurred.
 *
 * The body is wrapped in do { ... } while (0) WITHOUT a trailing semicolon so the
 * macro expands to a single statement and composes safely with if/else:
 *   if (cond) RAFT_NCCL_TRY(call); else other();
 * (The original ended in `while (0);`, which injects an extra empty statement and
 * breaks that pattern.)
 */
#define RAFT_NCCL_TRY(call)                        \
  do {                                             \
    ncclResult_t const status = (call);            \
    if (ncclSuccess != status) {                   \
      std::string msg{};                           \
      SET_ERROR_MSG(msg,                           \
                    "NCCL error encountered at: ", \
                    "call='%s', Reason=%d:%s",     \
                    #call,                         \
                    status,                        \
                    ncclGetErrorString(status));   \
      throw raft::logic_error(msg);                \
    }                                              \
  } while (0)
44+
45+
namespace raft::comms {
// Forward declaration — defined elsewhere in raft's comms layer. Used below in
// nccl_clique_init() to attach a NCCL communicator (for the given rank out of
// num_ranks) to a resources handle. Declared here presumably to avoid including
// the full comms headers — confirm against raft/comms.
void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank);
}
48+
49+
namespace raft::comms {
50+
51+
struct nccl_clique {
52+
using pool_mr = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
53+
54+
/**
55+
* Instantiates a NCCL clique with all available GPUs
56+
*
57+
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool
58+
*
59+
*/
60+
nccl_clique(int percent_of_free_memory = 80)
61+
: root_rank_(0),
62+
percent_of_free_memory_(percent_of_free_memory),
63+
per_device_pools_(0),
64+
device_resources_(0)
65+
{
66+
cudaGetDeviceCount(&num_ranks_);
67+
device_ids_.resize(num_ranks_);
68+
std::iota(device_ids_.begin(), device_ids_.end(), 0);
69+
nccl_comms_.resize(num_ranks_);
70+
nccl_clique_init();
71+
}
72+
73+
/**
74+
* Instantiates a NCCL clique
75+
*
76+
* Usage example:
77+
* @code{.cpp}
78+
* int n_devices;
79+
* cudaGetDeviceCount(&n_devices);
80+
* std::vector<int> device_ids(n_devices);
81+
* std::iota(device_ids.begin(), device_ids.end(), 0);
82+
* cuvs::neighbors::mg::nccl_clique& clique(device_ids); // first device is the root rank
83+
* @endcode
84+
*
85+
* @param[in] device_ids list of device IDs to be used to initiate the clique
86+
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool
87+
*
88+
*/
89+
nccl_clique(const std::vector<int>& device_ids, int percent_of_free_memory = 80)
90+
: root_rank_(0),
91+
num_ranks_(device_ids.size()),
92+
percent_of_free_memory_(percent_of_free_memory),
93+
device_ids_(device_ids),
94+
nccl_comms_(device_ids.size()),
95+
per_device_pools_(0),
96+
device_resources_(0)
97+
{
98+
nccl_clique_init();
99+
}
100+
101+
void nccl_clique_init()
102+
{
103+
RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data()));
104+
105+
for (int rank = 0; rank < num_ranks_; rank++) {
106+
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank]));
107+
108+
// create a pool memory resource for each device
109+
auto old_mr = rmm::mr::get_current_device_resource();
110+
per_device_pools_.push_back(std::make_unique<pool_mr>(
111+
old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory_)));
112+
rmm::cuda_device_id id(device_ids_[rank]);
113+
rmm::mr::set_per_device_resource(id, per_device_pools_.back().get());
114+
115+
// create a device resource handle for each device
116+
device_resources_.emplace_back();
117+
118+
// add NCCL communications to the device resource handle
119+
raft::comms::build_comms_nccl_only(
120+
&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank);
121+
}
122+
123+
for (int rank = 0; rank < num_ranks_; rank++) {
124+
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank]));
125+
raft::resource::sync_stream(device_resources_[rank]);
126+
}
127+
}
128+
129+
const raft::device_resources& set_current_device_to_root_rank() const
130+
{
131+
int root_device_id = device_ids_[root_rank_];
132+
RAFT_CUDA_TRY(cudaSetDevice(root_device_id));
133+
return device_resources_[root_rank_];
134+
}
135+
136+
~nccl_clique()
137+
{
138+
#pragma omp parallel for // necessary to avoid hangs
139+
for (int rank = 0; rank < num_ranks_; rank++) {
140+
cudaSetDevice(device_ids_[rank]);
141+
ncclCommDestroy(nccl_comms_[rank]);
142+
rmm::cuda_device_id id(device_ids_[rank]);
143+
rmm::mr::set_per_device_resource(id, nullptr);
144+
}
145+
}
146+
147+
int root_rank_;
148+
int num_ranks_;
149+
int percent_of_free_memory_;
150+
std::vector<int> device_ids_;
151+
std::vector<ncclComm_t> nccl_comms_;
152+
std::vector<std::shared_ptr<pool_mr>> per_device_pools_;
153+
std::vector<raft::device_resources> device_resources_;
154+
};
155+
156+
} // namespace raft::comms

cpp/include/raft/core/device_resources_snmg.hpp

Lines changed: 0 additions & 217 deletions
This file was deleted.

0 commit comments

Comments
 (0)