Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 126 additions & 24 deletions mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,80 @@ struct SessionHeader {
uint8_t opcode;
};

class GpuRuntime {
public:
static GpuRuntime &instance() {
static GpuRuntime inst;
return inst;
}

bool isAvailable() const { return handle_ != nullptr; }

bool isDevicePtr(const void *addr) const {
if (!isAvailable() || !pGetAttr_) return false;
Attr attr{};
int status = pGetAttr_(&attr, addr);
if (status != 0) return false;
return attr.type == kMemoryTypeDevice;
}

bool copy(void *dst, const void *src, size_t bytes, int kind) const {
if (!isAvailable() || !pMemcpy_) return false;
int status = pMemcpy_(dst, src, bytes, kind);
return (status == 0);
}

private:
GpuRuntime() { init(); }
~GpuRuntime() {
if (handle_) dlclose(handle_);
}
GpuRuntime(const GpuRuntime &) = delete;
GpuRuntime &operator=(const GpuRuntime &) = delete;

void init() {
const char *libs[] = {
"libcudart.so", // CUDA
"libmusa_runtime.so" // MUSA
};

for (auto lib : libs) {
handle_ = dlopen(lib, RTLD_LAZY);
if (!handle_) continue;

pGetAttr_ = reinterpret_cast<GetAttrFn>(
dlsym(handle_, "cudaPointerGetAttributes"));
if (!pGetAttr_)
pGetAttr_ = reinterpret_cast<GetAttrFn>(
dlsym(handle_, "musaPointerGetAttributes"));

pMemcpy_ = reinterpret_cast<MemcpyFn>(dlsym(handle_, "cudaMemcpy"));
if (!pMemcpy_)
pMemcpy_ =
reinterpret_cast<MemcpyFn>(dlsym(handle_, "musaMemcpy"));

if (pGetAttr_ && pMemcpy_) {
std::cout << "[GpuRuntime] Loaded GPU runtime: " << lib << "\n";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This debug message uses std::cout. For consistency with the rest of the codebase which uses glog, it's better to use LOG(INFO) for logging.

Suggested change
std::cout << "[GpuRuntime] Loaded GPU runtime: " << lib << "\n";
LOG(INFO) << "[GpuRuntime] Loaded GPU runtime: " << lib;

return;
}

dlclose(handle_);
handle_ = nullptr;
}
}

private:
void *handle_ = nullptr;
struct Attr {
int type;
};
using GetAttrFn = int (*)(Attr *, const void *);
using MemcpyFn = int (*)(void *, const void *, size_t, int);
GetAttrFn pGetAttr_ = nullptr;
MemcpyFn pMemcpy_ = nullptr;
static constexpr int kMemoryTypeDevice = 2;
};

#if defined(USE_CUDA) || defined(USE_MUSA)
static bool isCudaMemory(void *addr) {
cudaPointerAttributes attributes;
Expand All @@ -57,6 +131,12 @@ static bool isCudaMemory(void *addr) {
if (attributes.type == cudaMemoryTypeDevice) return true;
return false;
}
#else
static bool isCudaMemory(void *addr) {
auto &gpu = GpuRuntime::instance();
if (!gpu.isAvailable()) return false;
return gpu.isDevicePtr(addr);
}
#endif

struct Session : public std::enable_shared_from_this<Session> {
Expand Down Expand Up @@ -156,26 +236,39 @@ struct Session : public std::enable_shared_from_this<Session> {
cudaMemcpy(dram_buffer, addr + total_transferred_bytes_,
buffer_size, cudaMemcpyDefault);
}
#else
if (isCudaMemory(addr)) {
dram_buffer = new char[buffer_size];
auto &gpu = GpuRuntime::instance();
gpu.copy(dram_buffer, addr + total_transferred_bytes_, buffer_size,
4);
Comment on lines +243 to +244

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The magic number 4 corresponds to cudaMemcpyDefault. To improve readability and maintainability, it's better to define this as a named constant in the GpuRuntime class and use it here. A similar change should be applied in readBody as well.

For example, you can add this to GpuRuntime:

static constexpr int kMemcpyDefault = 4; // cudaMemcpyDefault
Suggested change
gpu.copy(dram_buffer, addr + total_transferred_bytes_, buffer_size,
4);
gpu.copy(dram_buffer, addr + total_transferred_bytes_, buffer_size,
GpuRuntime::kMemcpyDefault);

}
#endif

asio::async_write(
socket_, asio::buffer(dram_buffer, buffer_size),
[this, addr, dram_buffer, self](const asio::error_code &ec,
std::size_t transferred_bytes) {
#if defined(USE_CUDA) || defined(USE_MUSA)
if (isCudaMemory(addr)) {
delete[] dram_buffer;
}
#endif
if (ec) {
LOG(ERROR)
<< "Session::writeBody failed. "
<< "Attempt to write data " << addr << " using buffer "
<< dram_buffer << ". Error: " << ec.message()
<< " (value: " << ec.value() << ")"
<< ", total_transferred_bytes_: "
<< total_transferred_bytes_
<< ", current transferred_bytes: " << transferred_bytes;
if (ec.value() == 14 /* Bad address */) {
LOG(FATAL) << "Unable to transfer GPU memory vis TCP "
"transport without CUDA support. "
"Please rebuild the Python wheel with "
"-DUSE_CUDA=ON";
} else {
LOG(ERROR) << "Session::writeBody failed. "
<< "Attempt to write data " << addr
<< " using buffer " << dram_buffer
<< ". Error: " << ec.message()
<< " (value: " << ec.value() << ")"
<< ", total_transferred_bytes_: "
<< total_transferred_bytes_
<< ", current transferred_bytes: "
<< transferred_bytes;
}
if (on_finalize_) on_finalize_(TransferStatusEnum::FAILED);
session_mutex_.unlock();
return;
Expand All @@ -201,39 +294,48 @@ struct Session : public std::enable_shared_from_this<Session> {

char *dram_buffer = addr + total_transferred_bytes_;

#if defined(USE_CUDA) || defined(USE_MUSA)
bool is_cuda_memory = isCudaMemory(addr);
if (is_cuda_memory) {
dram_buffer = new char[buffer_size];
}
#else
bool is_cuda_memory = false;
#endif

asio::async_read(
socket_, asio::buffer(dram_buffer, buffer_size),
[this, addr, dram_buffer, is_cuda_memory, self](
const asio::error_code &ec, std::size_t transferred_bytes) {
if (ec) {
LOG(ERROR)
<< "Session::readBody failed. "
<< "Attempt to read data " << addr << " using buffer "
<< dram_buffer << ". Error: " << ec.message()
<< " (value: " << ec.value() << ")"
<< ", total_transferred_bytes_: "
<< total_transferred_bytes_
<< ", current transferred_bytes: " << transferred_bytes;
if (ec.value() == 14 /* Bad address */) {
LOG(FATAL) << "Unable to transfer GPU memory vis TCP "
"transport without CUDA support. "
"Please rebuild the Python wheel with "
"-DUSE_CUDA=ON";
} else {
LOG(ERROR) << "Session::readBody failed. "
<< "Attempt to read data " << addr
<< " using buffer " << dram_buffer
<< ". Error: " << ec.message()
<< " (value: " << ec.value() << ")"
<< ", total_transferred_bytes_: "
<< total_transferred_bytes_
<< ", current transferred_bytes: "
<< transferred_bytes;
}
if (on_finalize_) on_finalize_(TransferStatusEnum::FAILED);
#if defined(USE_CUDA) || defined(USE_MUSA)
if (is_cuda_memory) delete[] dram_buffer;
#endif
session_mutex_.unlock();
return;
}
#if defined(USE_CUDA) || defined(USE_MUSA)
cudaMemcpy(addr + total_transferred_bytes_, dram_buffer,
transferred_bytes, cudaMemcpyDefault);
if (is_cuda_memory) delete[] dram_buffer;
#else
if (isCudaMemory(addr)) {
auto &gpu = GpuRuntime::instance();
gpu.copy(addr + total_transferred_bytes_, dram_buffer,
transferred_bytes, 4);
if (is_cuda_memory) delete[] dram_buffer;
}
Comment on lines +333 to +338

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in this #else block can be simplified. The is_cuda_memory variable is already captured by the lambda and holds the result of isCudaMemory(addr), so you can use it in the if condition to avoid a redundant function call. Additionally, the inner if (is_cuda_memory) check before delete[] dram_buffer is redundant.

I'd also recommend using a named constant for the magic number 4 as mentioned in another comment.

Suggested change
if (isCudaMemory(addr)) {
auto &gpu = GpuRuntime::instance();
gpu.copy(addr + total_transferred_bytes_, dram_buffer,
transferred_bytes, 4);
if (is_cuda_memory) delete[] dram_buffer;
}
if (is_cuda_memory) {
auto &gpu = GpuRuntime::instance();
gpu.copy(addr + total_transferred_bytes_, dram_buffer,
transferred_bytes, 4 /* kMemcpyDefault */);
delete[] dram_buffer;
}

#endif
total_transferred_bytes_ += transferred_bytes;
readBody();
Expand Down
Loading