Skip to content

Commit

Permalink
apacheGH-37364: [C++][GPU] Add CUDA impl of Device Event/Stream (apac…
Browse files Browse the repository at this point in the history
…he#37365)

### What changes are included in this PR?
Adding `CudaDevice::SyncEvent` and `CudaDevice::Stream` implementations which provide more idiomatic handling of Events and Streams.

### Are these changes tested?
Simple SyncEvent test added. More stream tests still being added.

* Closes: apache#37364

Authored-by: Matt Topol <zotthewizard@gmail.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
zeroshade committed Aug 30, 2023
1 parent 602083b commit 3b8ab8e
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 13 deletions.
4 changes: 3 additions & 1 deletion cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,9 @@ class MyDevice : public Device {

virtual ~MySyncEvent() = default;
Status Wait() override { return Status::OK(); }
Status Record(const Device::Stream&) override { return Status::OK(); }
// Test stub: ignores both the stream and the flags and always reports success.
Status Record(const Device::Stream&, const unsigned int) override {
return Status::OK();
}
};

protected:
Expand Down
46 changes: 40 additions & 6 deletions cpp/src/arrow/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>

Expand Down Expand Up @@ -109,23 +110,54 @@ class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
/// should be trivially constructible from its device-specific counterparts.
class ARROW_EXPORT Stream {
public:
virtual const void* get_raw() const { return NULLPTR; }
using release_fn_t = std::function<void(void*)>;

virtual ~Stream() = default;

virtual const void* get_raw() const { return stream_.get(); }

/// \brief Make the stream wait on the provided event.
///
/// Tells the stream that it should wait until the synchronization
/// event is completed without blocking the CPU.
virtual Status WaitEvent(const SyncEvent&) = 0;

/// \brief Blocks the current thread until a stream's remaining tasks are completed
virtual Status Synchronize() const = 0;

protected:
Stream() = default;
virtual ~Stream() = default;
explicit Stream(void* stream, release_fn_t release_stream)
: stream_{stream, release_stream} {}

std::unique_ptr<void, release_fn_t> stream_;
};

/// \brief Create a new device stream without specifying creation flags
///
/// This base implementation creates nothing and returns NULLPTR.
virtual Result<std::shared_ptr<Stream>> MakeStream() { return NULLPTR; }

/// \brief Create a new device stream
///
/// This should create the appropriate stream type for the device,
/// derived from Device::Stream to allow for stream ordered events
/// and memory allocations.
///
/// \param flags device-specific stream creation flags; this base
/// implementation ignores them
/// \return this base implementation always returns NULLPTR
virtual Result<std::shared_ptr<Stream>> MakeStream(unsigned int flags) {
return NULLPTR;
}

/// @brief Wrap an existing device stream alongside a release function
///
/// @param device_stream a pointer to the stream to wrap
/// @param release_fn a function to call during destruction, `nullptr` or
/// a no-op function can be passed to indicate ownership is maintained
/// externally
/// @return this base implementation wraps nothing and always returns NULLPTR
virtual Result<std::shared_ptr<Stream>> WrapStream(void* device_stream,
Stream::release_fn_t release_fn) {
return NULLPTR;
}

/// \brief EXPERIMENTAL: An object that provides event/stream sync primitives
class ARROW_EXPORT SyncEvent {
public:
using release_fn_t = void (*)(void*);
using release_fn_t = std::function<void(void*)>;

virtual ~SyncEvent() = default;

Expand All @@ -134,9 +166,11 @@ class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
/// @brief Block until sync event is completed.
virtual Status Wait() = 0;

/// @brief Record the wrapped event on the stream, passing 0 as the flags
/// value to the two-argument overload.
inline Status Record(const Stream& st) { return Record(st, 0); }

/// @brief Record the wrapped event on the stream so it triggers
/// the event when the stream gets to that point in its queue.
virtual Status Record(const Stream&) = 0;
virtual Status Record(const Stream&, const unsigned int flags) = 0;

protected:
/// If creating this with a passed in event, the caller must ensure
Expand Down Expand Up @@ -225,7 +259,7 @@ class ARROW_EXPORT MemoryManager : public std::enable_shared_from_this<MemoryMan

/// \brief Wrap an event into a SyncEvent.
///
/// @param sync_event passed in sync_event from the imported device array.
/// @param sync_event passed in sync_event (should be a pointer to the appropriate type)
/// @param release_sync_event destructor to free sync_event. `nullptr` may be
/// passed to indicate that no destruction/freeing is necessary
virtual Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
Expand Down
106 changes: 100 additions & 6 deletions cpp/src/arrow/gpu/cuda_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
#include <utility>
#include <vector>

#include <cuda.h>

#include "arrow/gpu/cuda_internal.h"
#include "arrow/gpu/cuda_memory.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"

namespace arrow {

Expand Down Expand Up @@ -273,6 +272,35 @@ bool IsCudaDevice(const Device& device) {
return device.type_name() == kCudaDeviceTypeName;
}

// Create a new CUstream with the given creation flags and wrap it in a
// CudaDevice::Stream. The wrapper receives a release callback that destroys
// the stream and frees the heap cell holding its handle.
Result<std::shared_ptr<Device::Stream>> CudaDevice::MakeStream(unsigned int flags) {
  ARROW_ASSIGN_OR_RAISE(auto ctx, GetContext());
  // NOTE(review): ContextSaver is assumed to make `ctx` current for this
  // scope -- confirm against its definition.
  ContextSaver scoped_context(reinterpret_cast<CUcontext>(ctx->handle()));

  CUstream raw_stream;
  CU_RETURN_NOT_OK("cuStreamCreate", cuStreamCreate(&raw_stream, flags));

  Device::Stream::release_fn_t destroy_stream = [](void* opaque) {
    auto* handle = reinterpret_cast<CUstream*>(opaque);
    // DCHECK_OK still evaluates its argument in release mode
    // but in debug mode it'll also throw if it fails
    DCHECK_OK(internal::StatusFromCuda(cuStreamDestroy(*handle), "cuStreamDestroy"));
    delete handle;
  };

  // The handle is copied to the heap because the wrapper stores a CUstream*.
  return std::shared_ptr<Device::Stream>(
      new CudaDevice::Stream(ctx, new CUstream(raw_stream), destroy_stream));
}

// Wrap an existing stream (passed as a CUstream*) in a CudaDevice::Stream
// bound to this device's context. An empty release callback is replaced by
// a no-op, so the wrapper never releases a stream it was not asked to.
Result<std::shared_ptr<Device::Stream>> CudaDevice::WrapStream(
    void* stream, Device::Stream::release_fn_t release_fn) {
  ARROW_ASSIGN_OR_RAISE(auto ctx, GetContext());

  if (release_fn == nullptr) {
    release_fn = [](void*) {};
  }

  auto* typed_stream = reinterpret_cast<CUstream*>(stream);
  return std::shared_ptr<Device::Stream>(
      new CudaDevice::Stream(ctx, typed_stream, release_fn));
}

Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>& device) {
if (IsCudaDevice(*device)) {
return checked_pointer_cast<CudaDevice>(device);
Expand All @@ -281,6 +309,48 @@ Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>&
}
}

// Enqueue a wait on `event` into this stream via cuStreamWaitEvent.
//
// Returns Status::Invalid when the event is not a CudaDevice::SyncEvent or
// when it wraps a null CUevent.
Status CudaDevice::Stream::WaitEvent(const Device::SyncEvent& event) {
  // NOTE(review): if checked_cast degrades to static_cast in release builds,
  // the null check below can only fire in debug builds -- confirm.
  const auto* typed_event =
      checked_cast<const CudaDevice::SyncEvent*, const Device::SyncEvent*>(&event);
  if (!typed_event) {
    return Status::Invalid("CudaDevice::Stream cannot Wait on non-cuda event");
  }

  const auto raw_event = typed_event->value();
  if (!raw_event) {
    return Status::Invalid("Cuda Stream cannot wait on null event");
  }

  ContextSaver scoped_context(reinterpret_cast<CUcontext>(context_->handle()));
  const auto raw_stream = value();
  CU_RETURN_NOT_OK("cuStreamWaitEvent",
                   cuStreamWaitEvent(raw_stream, raw_event, CU_EVENT_WAIT_DEFAULT));
  return Status::OK();
}

// Call cuStreamSynchronize on the wrapped stream, translating a CUDA
// failure into a non-OK Status.
Status CudaDevice::Stream::Synchronize() const {
  const auto raw_stream = value();
  ContextSaver scoped_context(reinterpret_cast<CUcontext>(context_->handle()));
  CU_RETURN_NOT_OK("cuStreamSynchronize", cuStreamSynchronize(raw_stream));
  return Status::OK();
}

// Call cuEventSynchronize on the wrapped event, translating a CUDA
// failure into a non-OK Status.
Status CudaDevice::SyncEvent::Wait() {
  const auto raw_event = value();
  ContextSaver scoped_context(reinterpret_cast<CUcontext>(context_->handle()));
  CU_RETURN_NOT_OK("cuEventSynchronize", cuEventSynchronize(raw_event));
  return Status::OK();
}

// Record the wrapped event on `st` via cuEventRecordWithFlags, forwarding
// `flags` unchanged.
//
// Returns Status::Invalid when `st` is not a CudaDevice::Stream.
Status CudaDevice::SyncEvent::Record(const Device::Stream& st, const unsigned int flags) {
  // NOTE(review): if checked_cast degrades to static_cast in release builds,
  // the null check below can only fire in debug builds -- confirm.
  const auto* typed_stream =
      checked_cast<const CudaDevice::Stream*, const Device::Stream*>(&st);
  if (!typed_stream) {
    return Status::Invalid("CudaDevice::Event cannot record on non-cuda stream");
  }

  ContextSaver scoped_context(reinterpret_cast<CUcontext>(context_->handle()));
  const auto raw_event = value();
  const auto raw_stream = typed_stream->value();
  CU_RETURN_NOT_OK("cuEventRecordWithFlags",
                   cuEventRecordWithFlags(raw_event, raw_stream, flags));
  return Status::OK();
}

// ----------------------------------------------------------------------
// CudaMemoryManager implementation

Expand All @@ -293,11 +363,35 @@ std::shared_ptr<CudaDevice> CudaMemoryManager::cuda_device() const {
return checked_pointer_cast<CudaDevice>(device_);
}

// Create a new CUevent (CU_EVENT_DEFAULT) and wrap it in a
// CudaDevice::SyncEvent. The wrapper receives a release callback that
// destroys the event and frees the heap cell holding its handle.
Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::MakeDeviceSyncEvent() {
  ARROW_ASSIGN_OR_RAISE(auto ctx, cuda_device()->GetContext());
  ContextSaver scoped_context(reinterpret_cast<CUcontext>(ctx->handle()));

  // TODO: event creation flags
  CUevent raw_event;
  CU_RETURN_NOT_OK("cuEventCreate", cuEventCreate(&raw_event, CU_EVENT_DEFAULT));

  Device::SyncEvent::release_fn_t destroy_event = [](void* opaque) {
    auto* handle = reinterpret_cast<CUevent*>(opaque);
    // DCHECK_OK still evaluates its argument in release mode
    // but in debug mode it'll also throw if it fails
    DCHECK_OK(internal::StatusFromCuda(cuEventDestroy(*handle), "cuEventDestroy"));
    delete handle;
  };

  // The handle is copied to the heap because the wrapper stores a CUevent*.
  return std::shared_ptr<Device::SyncEvent>(
      new CudaDevice::SyncEvent(ctx, new CUevent(raw_event), destroy_event));
}

// Wrap an existing event (passed as a CUevent*) in a CudaDevice::SyncEvent
// bound to this manager's device context.
//
// `release_sync_event` is handed to the wrapper as its release callback; an
// empty callback is replaced by a no-op so that events owned elsewhere are
// left untouched.
//
// The previous placeholder (`return nullptr;` followed by commented-out code)
// made the implementation below unreachable and has been removed.
Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::WrapDeviceSyncEvent(
    void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
  if (!release_sync_event) {
    release_sync_event = [](void*) {};
  }

  auto* ev = static_cast<CUevent*>(sync_event);
  ARROW_ASSIGN_OR_RAISE(auto context, cuda_device()->GetContext());
  return std::shared_ptr<Device::SyncEvent>(
      new CudaDevice::SyncEvent(context, ev, release_sync_event));
}

Result<std::shared_ptr<io::RandomAccessFile>> CudaMemoryManager::GetBufferReader(
Expand Down Expand Up @@ -440,7 +534,7 @@ class CudaDeviceManager::Impl {
Status AllocateHost(int device_number, int64_t nbytes, uint8_t** out) {
RETURN_NOT_OK(CheckDeviceNum(device_number));
ARROW_ASSIGN_OR_RAISE(auto ctx, GetContext(device_number));
ContextSaver set_temporary((CUcontext)(ctx.get()->handle()));
ContextSaver set_temporary(reinterpret_cast<CUcontext>(ctx.get()->handle()));
CU_RETURN_NOT_OK("cuMemHostAlloc", cuMemHostAlloc(reinterpret_cast<void**>(out),
static_cast<size_t>(nbytes),
CU_MEMHOSTALLOC_PORTABLE));
Expand Down
97 changes: 97 additions & 0 deletions cpp/src/arrow/gpu/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <memory>
#include <string>

#include <cuda.h>

#include "arrow/device.h"
#include "arrow/result.h"
#include "arrow/util/visibility.h"
Expand Down Expand Up @@ -140,6 +142,90 @@ class ARROW_EXPORT CudaDevice : public Device {
/// \param[in] size The buffer size in bytes
Result<std::shared_ptr<CudaHostBuffer>> AllocateHostBuffer(int64_t size);

/// \brief EXPERIMENTAL: Wrapper for CUstreams
///
/// Holds a pointer to a CUstream together with a release callback and the
/// CudaContext the stream belongs to. Whether the wrapper owns the CUstream
/// depends on the callback it was given: CudaDevice::MakeStream supplies one
/// that calls cuStreamDestroy, while CudaDevice::WrapStream forwards the
/// caller's callback (or a no-op when none is given).
///
/// Instances are created through CudaDevice; construction from a literal 0
/// or nullptr is explicitly disabled.
class ARROW_EXPORT Stream : public Device::Stream {
 public:
  ~Stream() override = default;

  /// \brief Return the wrapped CUstream handle
  ///
  /// Yields a value-initialized CUstream when no stream pointer is held.
  [[nodiscard]] CUstream value() const noexcept {
    if (!stream_) {
      return CUstream{};
    }
    return *static_cast<CUstream*>(stream_.get());
  }
  operator CUstream() const noexcept { return value(); }

  /// \brief Return the stored pointer (a CUstream*), or null if none is held
  const void* get_raw() const noexcept override { return stream_.get(); }

  /// \brief Make this stream wait on a CudaDevice::SyncEvent
  Status WaitEvent(const Device::SyncEvent&) override;

  /// \brief Block the calling thread until the stream's queued work completes
  Status Synchronize() const override;

 protected:
  friend class CudaDevice;

  /// \param ctx context the stream belongs to
  /// \param st pointer to the CUstream handle to wrap
  /// \param release_fn callback handed to the base class as the release
  /// function for `st`
  explicit Stream(std::shared_ptr<CudaContext> ctx, CUstream* st,
                  Device::Stream::release_fn_t release_fn)
      : Device::Stream(static_cast<void*>(st), std::move(release_fn)),
        context_{std::move(ctx)} {}

  // disable construction from literal 0
  explicit Stream(std::shared_ptr<CudaContext>, int,
                  Device::Stream::release_fn_t) = delete;  // Prevent cast from 0
  explicit Stream(std::shared_ptr<CudaContext>, std::nullptr_t,
                  Device::Stream::release_fn_t) = delete;  // Prevent cast from nullptr

 private:
  std::shared_ptr<CudaContext> context_;
};

/// \brief Create a CUstream wrapper, passing 0 as the creation flags
Result<std::shared_ptr<Device::Stream>> MakeStream() override { return MakeStream(0); }

/// \brief Create a new CUstream in this device's context and wrap it
///
/// The returned wrapper destroys the stream with cuStreamDestroy when it
/// is itself destroyed.
///
/// \param[in] flags creation flags forwarded to cuStreamCreate
Result<std::shared_ptr<Device::Stream>> MakeStream(unsigned int flags) override;

/// @brief Wrap a pointer to an existing stream
///
/// @param device_stream passed in stream (should be a CUstream*)
/// @param release_fn destructor to free the stream. `nullptr` may be passed
/// to indicate there is no destruction/freeing necessary.
Result<std::shared_ptr<Device::Stream>> WrapStream(
void* device_stream, Stream::release_fn_t release_fn) override;

/// \brief EXPERIMENTAL: Wrapper for CUevents
///
/// Holds a pointer to a CUevent together with a release callback and the
/// CudaContext the event belongs to. Instances are created through
/// CudaMemoryManager.
class ARROW_EXPORT SyncEvent : public Device::SyncEvent {
 public:
  /// \brief Return the wrapped CUevent handle
  ///
  /// Yields a value-initialized CUevent when no event pointer is held.
  [[nodiscard]] CUevent value() const noexcept {
    if (!sync_event_) {
      return CUevent{};
    }
    return *static_cast<CUevent*>(sync_event_.get());
  }
  operator CUevent() const noexcept { return value(); }

  /// @brief Block until the sync event is marked completed
  Status Wait() override;

  /// @brief Record the wrapped event on the stream
  ///
  /// Once the stream completes the tasks previously added to it,
  /// it will trigger the event.
  Status Record(const Device::Stream&, const unsigned int) override;

 protected:
  friend class CudaMemoryManager;

  /// @param ctx context the event belongs to
  /// @param ev pointer to the CUevent handle to wrap
  /// @param release_ev callback handed to the base class as the release
  /// function for `ev`
  explicit SyncEvent(std::shared_ptr<CudaContext> ctx, CUevent* ev,
                     Device::SyncEvent::release_fn_t release_ev)
      : Device::SyncEvent(static_cast<void*>(ev), std::move(release_ev)),
        context_{std::move(ctx)} {}

 private:
  std::shared_ptr<CudaContext> context_;
};

protected:
struct Impl;

Expand Down Expand Up @@ -179,6 +265,17 @@ class ARROW_EXPORT CudaMemoryManager : public MemoryManager {
/// having to cast the `device()` result.
std::shared_ptr<CudaDevice> cuda_device() const;

/// \brief Creates a wrapped CUevent.
///
/// Will call cuEventCreate and it will call cuEventDestroy internally
/// when the event is destructed.
Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent() override;

/// \brief Wraps an existing event into a sync event.
///
/// @param sync_event the event to wrap, must be a CUevent*
/// @param release_sync_event a function to call during destruction, `nullptr` or
/// a no-op function can be passed to indicate ownership is maintained externally
Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) override;

Expand Down
Loading

0 comments on commit 3b8ab8e

Please sign in to comment.