Skip to content

Commit 167a5dc

Browse files
committed
GPUDirect Storage (GDS)
1 parent 6512503 commit 167a5dc

File tree

17 files changed

+643
-6
lines changed

17 files changed

+643
-6
lines changed

ucm/store/device/cuda/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@ target_compile_options(storedevice PRIVATE
88
--diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
99
-Wall -fPIC
1010
)
11+
# Imported target for the CUDA runtime library.
add_library(Cuda::cudart UNKNOWN IMPORTED)
set_target_properties(Cuda::cudart PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
    IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
)
# cuFile (GPUDirect Storage) gets its own imported target. A target holds a single
# IMPORTED_LOCATION, so setting the property twice on Cuda::cudart kept only the last
# value and silently dropped libcudart.so.
add_library(Cuda::cufile UNKNOWN IMPORTED)
set_target_properties(Cuda::cufile PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
    IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcufile.so"
)
target_link_libraries(storedevice PUBLIC Cuda::cudart Cuda::cufile)

ucm/store/device/cuda/cuda_device.cu

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424
#include <cuda_runtime.h>
2525
#include "ibuffered_device.h"
2626
#include "logger/logger.h"
27+
#include <cufile.h>
28+
#include <mutex>
29+
#include <fcntl.h>
30+
#include <unistd.h>
31+
#include <cerrno>
32+
#include <cstring>
33+
#include <unordered_map>
34+
#include <cstdlib>
35+
#include "infra/template/handle_recorder.h"
2736

2837
#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
2938
#define CUDA_TRANS_BLOCK_NUMBER (32)
@@ -90,6 +99,25 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9099

91100
namespace UC {
92101

102+
/**
 * Registers an already-open file descriptor with cuFile.
 *
 * @param fd            File descriptor to register; negative values are rejected up front.
 * @param cuFileHandle  Receives the handle written by cuFileHandleRegister.
 * @return Status::OK() on success; Status::Error() when fd is negative or registration fails.
 */
static Status CreateCuFileHandle(int fd, CUfileHandle_t& cuFileHandle)
{
    if (fd < 0) {
        UC_ERROR("Invalid file descriptor: {}", fd);
        return Status::Error();
    }

    // Describe the descriptor as an opaque fd before handing it to cuFile.
    CUfileDescr_t descriptor{};
    descriptor.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
    descriptor.handle.fd = fd;

    const CUfileError_t registerResult = cuFileHandleRegister(&cuFileHandle, &descriptor);
    if (registerResult.err == CU_FILE_SUCCESS) {
        return Status::OK();
    }

    UC_ERROR("Failed to register cuFile handle for fd {}: error {}",
             fd, static_cast<int>(registerResult.err));
    return Status::Error();
}
93121
template <typename Api, typename... Args>
94122
Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api,
95123
Args&&... args)
@@ -133,12 +161,23 @@ class CudaDevice : public IBufferedDevice {
133161
return nullptr;
134162
}
135163
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
164+
static std::once_flag gdsOnce_;
136165

137166
public:
167+
static Status InitGdsOnce();
138168
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
139169
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr}
140170
{
141171
}
172+
/**
 * Deregisters every cuFile handle held by the HandlePool<int, CUfileHandle_t> singleton,
 * then destroys this device's stream when one is set.
 *
 * NOTE(review): the pool is obtained through a static Instance(), so destroying any one
 * CudaDevice clears handles for all users of that pool — confirm this is intended.
 */
~CudaDevice()
{
    auto deregister = [](CUfileHandle_t registered) { cuFileHandleDeregister(registered); };
    HandlePool<int, CUfileHandle_t>::Instance().ClearAll(deregister);

    if (this->stream_ == nullptr) {
        return;
    }
    cudaStreamDestroy(this->stream_);
}
142181
Status Setup() override
143182
{
144183
auto status = Status::OK();
@@ -165,6 +204,52 @@ public:
165204
{
166205
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
167206
}
207+
/**
 * Synchronous storage-to-device transfer through cuFileRead.
 *
 * A cuFile handle for `fd` is taken from the shared HandlePool (registered on demand via
 * CreateCuFileHandle), used for one cuFileRead call and handed back to the pool.
 *
 * @param fd          File descriptor to read from.
 * @param address     Buffer base pointer passed to cuFileRead.
 * @param length      Number of bytes requested; anything else read counts as a failure.
 * @param fileOffset  Offset passed to cuFileRead as the file offset.
 * @param devOffset   Offset passed to cuFileRead as the buffer offset.
 * @return Status::OK() when exactly `length` bytes were read; the pool's failure status when
 *         no handle could be obtained; Status::Error() otherwise.
 */
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    auto& handles = HandlePool<int, CUfileHandle_t>::Instance();

    CUfileHandle_t registered = nullptr;
    auto makeHandle = [fd](CUfileHandle_t& out) -> Status { return CreateCuFileHandle(fd, out); };
    auto acquired = handles.Get(fd, registered, makeHandle);
    if (acquired.Failure()) {
        return acquired;
    }

    const ssize_t bytesRead = cuFileRead(registered, address, length, fileOffset, devOffset);

    // Hand the pool reference back before looking at the result, so it is returned on
    // the failure path as well.
    handles.Put(fd, [](CUfileHandle_t h) {
        if (h != nullptr) {
            cuFileHandleDeregister(h);
        }
    });

    const bool complete = bytesRead >= 0 && static_cast<size_t>(bytesRead) == length;
    if (complete) {
        return Status::OK();
    }
    UC_ERROR("cuFileRead failed for fd {}: expected {}, got {}", fd, length, bytesRead);
    return Status::Error();
}
230+
/**
 * Synchronous device-to-storage transfer through cuFileWrite.
 *
 * A cuFile handle for `fd` is taken from the shared HandlePool (registered on demand via
 * CreateCuFileHandle), used for one cuFileWrite call and handed back to the pool.
 *
 * @param fd          File descriptor to write to.
 * @param address     Buffer base pointer passed to cuFileWrite.
 * @param length      Number of bytes requested; anything else written counts as a failure.
 * @param fileOffset  Offset passed to cuFileWrite as the file offset.
 * @param devOffset   Offset passed to cuFileWrite as the buffer offset.
 * @return Status::OK() when exactly `length` bytes were written; the pool's failure status
 *         when no handle could be obtained; Status::Error() otherwise.
 */
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    auto& handles = HandlePool<int, CUfileHandle_t>::Instance();

    CUfileHandle_t registered = nullptr;
    auto makeHandle = [fd](CUfileHandle_t& out) -> Status { return CreateCuFileHandle(fd, out); };
    auto acquired = handles.Get(fd, registered, makeHandle);
    if (acquired.Failure()) {
        return acquired;
    }

    const ssize_t bytesWrite = cuFileWrite(registered, address, length, fileOffset, devOffset);

    // Hand the pool reference back before looking at the result, so it is returned on
    // the failure path as well.
    handles.Put(fd, [](CUfileHandle_t h) {
        if (h != nullptr) {
            cuFileHandleDeregister(h);
        }
    });

    const bool complete = bytesWrite >= 0 && static_cast<size_t>(bytesWrite) == length;
    if (complete) {
        return Status::OK();
    }
    UC_ERROR("cuFileWrite failed for fd {}: expected {}, got {}", fd, length, bytesWrite);
    return Status::Error();
}
168253
Status AppendCallback(std::function<void(bool)> cb) override
169254
{
170255
auto* c = new (std::nothrow) Closure(cb);
@@ -226,6 +311,14 @@ private:
226311
cudaStream_t stream_;
227312
};
228313

314+
/**
 * Device-layer setup hook.
 *
 * @param useDirect  When true, the GDS driver is initialized through CudaDevice::InitGdsOnce().
 * @return The result of InitGdsOnce() when useDirect is set, otherwise Status::OK().
 */
Status DeviceFactory::Setup(bool useDirect)
{
    if (!useDirect) {
        return Status::OK();
    }
    return CudaDevice::InitGdsOnce();
}
321+
229322
std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize,
230323
const size_t bufferNumber)
231324
{
@@ -237,5 +330,20 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
237330
return nullptr;
238331
}
239332
}
333+
std::once_flag CudaDevice::gdsOnce_{};
334+
Status CudaDevice::InitGdsOnce()
335+
{
336+
Status result = Status::OK();
337+
std::call_once(gdsOnce_, [&result]() {
338+
CUfileError_t ret = cuFileDriverOpen();
339+
if (ret.err == CU_FILE_SUCCESS) {
340+
UC_INFO("GDS driver initialized successfully");
341+
} else {
342+
UC_ERROR("GDS driver initialization failed with error code: {}", static_cast<int>(ret.err));
343+
result = Status::Error();
344+
}
345+
});
346+
return result;
347+
}
240348

241349
} // namespace UC

ucm/store/device/idevice.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class IDevice {
4949
const size_t count) = 0;
5050
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
5151
const size_t count) = 0;
52+
virtual Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
53+
virtual Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
5254

5355
protected:
5456
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;
@@ -59,6 +61,7 @@ class IDevice {
5961

6062
class DeviceFactory {
6163
public:
64+
static Status Setup(bool useDirect = false);
6265
static std::unique_ptr<IDevice> Make(const int32_t deviceId, const size_t bufferSize,
6366
const size_t bufferNumber);
6467
};

ucm/store/device/simu/simu_device.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ class SimuDevice : public IBufferedDevice {
7272
this->backend_.Push([=] { std::copy(src, src + count, dst); });
7373
return Status::OK();
7474
}
75+
/**
 * Storage-to-device transfer entry point for the simulated device.
 * None of the arguments are used; the call always reports Status::Unsupported().
 */
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    auto unsupported = Status::Unsupported();
    return unsupported;
}
79+
/**
 * Device-to-storage transfer entry point for the simulated device.
 * None of the arguments are used; the call always reports Status::Unsupported().
 */
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    auto unsupported = Status::Unsupported();
    return unsupported;
}
7583
Status AppendCallback(std::function<void(bool)> cb) override
7684
{
7785
this->backend_.Push([=] { cb(true); });

ucm/store/infra/file/file.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ Status File::Write(const std::string& path, const size_t offset, const size_t le
8787
return status;
8888
}
8989

90+
/**
 * Opens `path` with the given flags and hands the raw descriptor to the caller.
 *
 * @param path   File to open.
 * @param flags  Flags forwarded unchanged to FileImpl::Open.
 * @param fd     Set to the handle reported by GetHandle() on success, or -1 on failure.
 * @return Status::OK() on success, otherwise the status returned by Open().
 */
Status File::OpenForDirectIO(const std::string& path, uint32_t flags, int& fd)
{
    auto impl = std::make_unique<FileImpl>(path);
    auto opened = impl->Open(flags);
    if (!opened.Failure()) {
        fd = impl->GetHandle();
        // The FileImpl object is intentionally released and never deleted.
        // NOTE(review): presumably so its destructor does not close the descriptor that
        // was just handed out — confirm, and confirm who eventually closes fd.
        impl.release();
        return Status::OK();
    }

    UC_ERROR("Failed to open file({}) with flags({}).", path, flags);
    fd = -1;
    return opened;
}
103+
90104
void File::MUnmap(void* addr, size_t size) { FileImpl{{}}.MUnmap(addr, size); }
91105

92106
void File::ShmUnlink(const std::string& path) { FileImpl{path}.ShmUnlink(); }

ucm/store/infra/file/file.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class File {
4141
uintptr_t address, const bool directIo = false);
4242
static Status Write(const std::string& path, const size_t offset, const size_t length,
4343
const uintptr_t address, const bool directIo = false);
44+
static Status OpenForDirectIO(const std::string& path, uint32_t flags, int& fd);
4445
static void MUnmap(void* addr, size_t size);
4546
static void ShmUnlink(const std::string& path);
4647
static void Remove(const std::string& path);

ucm/store/infra/file/ifile.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class IFile {
5555
IFile(const std::string& path) : path_{path} {}
5656
virtual ~IFile() = default;
5757
const std::string& Path() const { return this->path_; }
58+
virtual int32_t GetHandle() const = 0;
5859
virtual Status MkDir() = 0;
5960
virtual Status RmDir() = 0;
6061
virtual Status Rename(const std::string& newName) = 0;

ucm/store/infra/file/posix_file.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class PosixFile : public IFile {
3232
public:
3333
PosixFile(const std::string& path) : IFile{path}, handle_{-1} {}
3434
~PosixFile() override;
35+
int32_t GetHandle() const override { return handle_; }
3536
Status MkDir() override;
3637
Status RmDir() override;
3738
Status Rename(const std::string& newName) override;
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#ifndef UC_INFRA_HANDLE_POOL_H
2+
#define UC_INFRA_HANDLE_POOL_H
3+
4+
#include <functional>
5+
#include "status/status.h"
6+
#include "hashmap.h"
7+
8+
namespace UC {
9+
10+
template <typename KeyType, typename HandleType>
11+
class HandlePool {
12+
private:
13+
struct PoolEntry {
14+
HandleType handle;
15+
uint64_t refCount;
16+
};
17+
using PoolMap = HashMap<KeyType, PoolEntry, std::hash<KeyType>, 10>;
18+
PoolMap pool_;
19+
20+
public:
21+
HandlePool() = default;
22+
HandlePool(const HandlePool&) = delete;
23+
HandlePool& operator=(const HandlePool&) = delete;
24+
25+
static HandlePool& Instance()
26+
{
27+
static HandlePool instance;
28+
return instance;
29+
}
30+
31+
Status Get(const KeyType& key, HandleType& handle,
32+
std::function<Status(HandleType&)> instantiate)
33+
{
34+
auto result = pool_.GetOrCreate(key, [&instantiate](PoolEntry& entry) -> bool {
35+
HandleType h{};
36+
37+
auto status = instantiate(h);
38+
if (status.Failure()) {
39+
return false;
40+
}
41+
42+
entry.handle = h;
43+
entry.refCount = 1;
44+
return true;
45+
});
46+
47+
if (!result.has_value()) {
48+
return Status::Error();
49+
}
50+
51+
auto& entry = result.value().get();
52+
entry.refCount++;
53+
handle = entry.handle;
54+
return Status::OK();
55+
}
56+
57+
void Put(const KeyType& key,
58+
std::function<void(HandleType)> cleanup)
59+
{
60+
pool_.Upsert(key, [&cleanup](PoolEntry& entry) -> bool {
61+
entry.refCount--;
62+
if (entry.refCount > 0) {
63+
return false;
64+
}
65+
cleanup(entry.handle);
66+
return true;
67+
});
68+
}
69+
70+
void ClearAll(std::function<void(HandleType)> cleanup)
71+
{
72+
pool_.ForEach([&cleanup](const KeyType& key, PoolEntry& entry) {
73+
(void)key;
74+
cleanup(entry.handle);
75+
});
76+
pool_.Clear();
77+
}
78+
};
79+
80+
} // namespace UC
81+
82+
#endif
83+

0 commit comments

Comments
 (0)