Skip to content

Commit 61c956c

Browse files
committed
Add GDS
2 parents 6512503 + 6e15ad6 commit 61c956c

File tree

15 files changed

+623
-11
lines changed

15 files changed

+623
-11
lines changed

ucm/store/device/cuda/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@ target_compile_options(storedevice PRIVATE
88
--diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
99
-Wall -fPIC
1010
)
11+
add_library(Cuda::cudart UNKNOWN IMPORTED)
12+
set_target_properties(Cuda::cudart PROPERTIES
13+
INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
14+
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
15+
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcufile.so"
16+
)
17+
target_link_libraries(storedevice PUBLIC Cuda::cudart)

ucm/store/device/cuda/cuda_device.cu

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424
#include <cuda_runtime.h>
2525
#include "ibuffered_device.h"
2626
#include "logger/logger.h"
27+
#include <cufile.h>
28+
#include <mutex>
29+
#include <fcntl.h>
30+
#include <unistd.h>
31+
#include <cerrno>
32+
#include <cstring>
33+
#include <unordered_map>
34+
#include <cstdlib>
35+
#include "infra/template/handle_recorder.h"
2736

2837
#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
2938
#define CUDA_TRANS_BLOCK_NUMBER (32)
@@ -90,6 +99,25 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9099

91100
namespace UC {
92101

102+
static Status CreateCuFileHandle(int fd, CUfileHandle_t& cuFileHandle)
103+
{
104+
if (fd < 0) {
105+
UC_ERROR("Invalid file descriptor: {}", fd);
106+
return Status::Error();
107+
}
108+
109+
CUfileDescr_t cfDescr{};
110+
cfDescr.handle.fd = fd;
111+
cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
112+
CUfileError_t err = cuFileHandleRegister(&cuFileHandle, &cfDescr);
113+
if (err.err != CU_FILE_SUCCESS) {
114+
UC_ERROR("Failed to register cuFile handle for fd {}: error {}",
115+
fd, static_cast<int>(err.err));
116+
return Status::Error();
117+
}
118+
119+
return Status::OK();
120+
}
93121
template <typename Api, typename... Args>
94122
Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api,
95123
Args&&... args)
@@ -133,12 +161,23 @@ class CudaDevice : public IBufferedDevice {
133161
return nullptr;
134162
}
135163
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
164+
static std::once_flag gdsOnce_;
136165

137166
public:
167+
static void InitGdsOnce();
138168
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
139169
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr}
140170
{
141171
}
172+
~CudaDevice() {
173+
HandlePool<int, CUfileHandle_t>::Instance().ClearAll([](CUfileHandle_t h) {
174+
cuFileHandleDeregister(h);
175+
});
176+
177+
if (stream_ != nullptr) {
178+
cudaStreamDestroy((cudaStream_t)stream_);
179+
}
180+
}
142181
Status Setup() override
143182
{
144183
auto status = Status::OK();
@@ -165,6 +204,52 @@ public:
165204
{
166205
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
167206
}
207+
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
208+
{
209+
CUfileHandle_t cuFileHandle = nullptr;
210+
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
211+
[fd](CUfileHandle_t& handle) -> Status {
212+
return CreateCuFileHandle(fd, handle);
213+
});
214+
if (status.Failure()) {
215+
return status;
216+
}
217+
ssize_t bytesRead = cuFileRead(cuFileHandle, address, length, fileOffset, devOffset);
218+
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
219+
if (h != nullptr) {
220+
cuFileHandleDeregister(h);
221+
}
222+
});
223+
224+
if (bytesRead < 0 || (size_t)bytesRead != length) {
225+
UC_ERROR("cuFileRead failed for fd {}: expected {}, got {}", fd, length, bytesRead);
226+
return Status::Error();
227+
}
228+
return Status::OK();
229+
}
230+
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
231+
{
232+
CUfileHandle_t cuFileHandle = nullptr;
233+
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
234+
[fd](CUfileHandle_t& handle) -> Status {
235+
return CreateCuFileHandle(fd, handle);
236+
});
237+
if (status.Failure()) {
238+
return status;
239+
}
240+
ssize_t bytesWrite = cuFileWrite(cuFileHandle, address, length, fileOffset, devOffset);
241+
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
242+
if (h != nullptr) {
243+
cuFileHandleDeregister(h);
244+
}
245+
});
246+
247+
if (bytesWrite < 0 || (size_t)bytesWrite != length) {
248+
UC_ERROR("cuFileWrite failed for fd {}: expected {}, got {}", fd, length, bytesWrite);
249+
return Status::Error();
250+
}
251+
return Status::OK();
252+
}
168253
Status AppendCallback(std::function<void(bool)> cb) override
169254
{
170255
auto* c = new (std::nothrow) Closure(cb);
@@ -226,6 +311,13 @@ private:
226311
cudaStream_t stream_;
227312
};
228313

314+
void DeviceFactory::Setup(bool useDirect)
315+
{
316+
if (useDirect) {
317+
CudaDevice::InitGdsOnce();
318+
}
319+
}
320+
229321
std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize,
230322
const size_t bufferNumber)
231323
{
@@ -237,5 +329,17 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
237329
return nullptr;
238330
}
239331
}
332+
std::once_flag CudaDevice::gdsOnce_{};
333+
void CudaDevice::InitGdsOnce()
334+
{
335+
std::call_once(gdsOnce_, [] (){
336+
CUfileError_t ret = cuFileDriverOpen();
337+
if (ret.err == CU_FILE_SUCCESS) {
338+
UC_INFO("GDS driver initialized successfully");
339+
} else {
340+
UC_ERROR("GDS driver initialized unsuccessfully");
341+
}
342+
});
343+
}
240344

241345
} // namespace UC

ucm/store/device/idevice.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class IDevice {
4949
const size_t count) = 0;
5050
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
5151
const size_t count) = 0;
52+
virtual Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
53+
virtual Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
5254

5355
protected:
5456
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;
@@ -59,6 +61,7 @@ class IDevice {
5961

6062
class DeviceFactory {
6163
public:
64+
static void Setup(bool useDirect = false);
6265
static std::unique_ptr<IDevice> Make(const int32_t deviceId, const size_t bufferSize,
6366
const size_t bufferNumber);
6467
};

ucm/store/device/simu/simu_device.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ class SimuDevice : public IBufferedDevice {
7272
this->backend_.Push([=] { std::copy(src, src + count, dst); });
7373
return Status::OK();
7474
}
75+
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
76+
{
77+
return Status::OK();
78+
}
79+
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
80+
{
81+
return Status::OK();
82+
}
7583
Status AppendCallback(std::function<void(bool)> cb) override
7684
{
7785
this->backend_.Push([=] { cb(true); });
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef UC_INFRA_HANDLE_POOL_H
2+
#define UC_INFRA_HANDLE_POOL_H
3+
4+
#include <functional>
5+
#include "status/status.h"
6+
#include "hashmap.h"
7+
8+
namespace UC {
9+
10+
template <typename KeyType, typename HandleType>
11+
class HandlePool {
12+
private:
13+
struct PoolEntry {
14+
HandleType handle;
15+
uint64_t refCount;
16+
};
17+
using PoolMap = HashMap<KeyType, PoolEntry, std::hash<KeyType>, 10>;
18+
PoolMap pool_;
19+
HandlePool() = default;
20+
HandlePool(const HandlePool&) = delete;
21+
HandlePool& operator=(const HandlePool&) = delete;
22+
23+
public:
24+
static HandlePool& Instance()
25+
{
26+
static HandlePool instance;
27+
return instance;
28+
}
29+
30+
Status Get(const KeyType& key, HandleType& handle,
31+
std::function<Status(HandleType&)> instantiate)
32+
{
33+
auto result = pool_.GetOrCreate(key, [&instantiate](PoolEntry& entry) -> bool {
34+
HandleType h{};
35+
36+
auto status = instantiate(h);
37+
if (status.Failure()) {
38+
return false;
39+
}
40+
41+
entry.handle = h;
42+
entry.refCount = 1;
43+
return true;
44+
});
45+
46+
if (!result.has_value()) {
47+
return Status::Error();
48+
}
49+
50+
auto& entry = result.value().get();
51+
entry.refCount++;
52+
handle = entry.handle;
53+
return Status::OK();
54+
}
55+
56+
void Put(const KeyType& key,
57+
std::function<void(HandleType)> cleanup)
58+
{
59+
pool_.Upsert(key, [&cleanup](PoolEntry& entry) -> bool {
60+
entry.refCount--;
61+
if (entry.refCount > 0) {
62+
return false;
63+
}
64+
cleanup(entry.handle);
65+
return true;
66+
});
67+
}
68+
69+
void ClearAll(std::function<void(HandleType)> cleanup)
70+
{
71+
pool_.ForEach([&cleanup](const KeyType& key, PoolEntry& entry) {
72+
(void)key;
73+
cleanup(entry.handle);
74+
});
75+
pool_.Clear();
76+
}
77+
};
78+
79+
} // namespace UC
80+
81+
#endif
82+

0 commit comments

Comments
 (0)