diff --git a/device/device.hpp b/device/device.hpp
index 64d67465..ee9ac0ee 100644
--- a/device/device.hpp
+++ b/device/device.hpp
@@ -1796,9 +1796,14 @@ class Device : public RuntimeObject {
 
   // Returns the status of HW event, associated with amd::Event
   virtual bool IsHwEventReady(
-      const amd::Event& event,  //!< AMD event for HW status validation
-      bool wait = false         //!< If true then forces the event completion
-  ) const {
+      const amd::Event& event,    //!< AMD event for HW status validation
+      bool wait = false) const {  //!< If true then forces the event completion
+    return false;
+  };
+
+  // Returns the status of HW event, associated with amd::Event, after a forced wait
+  virtual bool IsHwEventReadyForcedWait(
+      const amd::Event& event) const {  //!< AMD event for HW status validation
     return false;
   };
 
diff --git a/device/rocm/rocdevice.cpp b/device/rocm/rocdevice.cpp
index f5d0f965..34395a24 100644
--- a/device/rocm/rocdevice.cpp
+++ b/device/rocm/rocdevice.cpp
@@ -2729,10 +2729,22 @@ bool Device::SetClockMode(const cl_set_device_clock_mode_input_amd setClockModeI
   return result;
 }
 
+// ================================================================================================
+bool Device::IsHwEventReadyForcedWait(const amd::Event& event) const {
+  void* hw_event =
+      (event.NotifyEvent() != nullptr) ? event.NotifyEvent()->HwEvent() : event.HwEvent();
+  if (hw_event == nullptr) {
+    ClPrint(amd::LOG_INFO, amd::LOG_SIG, "No HW event");
+    return false;
+  }
+  static constexpr bool Timeout = true;
+  return WaitForSignal<Timeout>(reinterpret_cast<ProfilingSignal*>(hw_event)->signal_, false, true);
+}
+
 // ================================================================================================
 bool Device::IsHwEventReady(const amd::Event& event, bool wait) const {
-  void* hw_event = (event.NotifyEvent() != nullptr) ?
-    event.NotifyEvent()->HwEvent() : event.HwEvent();
+  void* hw_event =
+      (event.NotifyEvent() != nullptr) ? event.NotifyEvent()->HwEvent() : event.HwEvent();
   if (hw_event == nullptr) {
     ClPrint(amd::LOG_INFO, amd::LOG_SIG, "No HW event");
     return false;
diff --git a/device/rocm/rocdevice.hpp b/device/rocm/rocdevice.hpp
index b3da3783..9619abe6 100644
--- a/device/rocm/rocdevice.hpp
+++ b/device/rocm/rocdevice.hpp
@@ -258,6 +258,7 @@ class NullDevice : public amd::Device {
                             cl_set_device_clock_mode_output_amd* pSetClockModeOutput) { return true; }
 
   virtual bool IsHwEventReady(const amd::Event& event, bool wait = false) const { return false; }
+  virtual bool IsHwEventReadyForcedWait(const amd::Event& event) const { return false; }
   virtual void getHwEventTime(const amd::Event& event, uint64_t* start, uint64_t* end) const {};
   virtual void ReleaseGlobalSignal(void* signal) const {}
 
@@ -443,6 +444,7 @@ class Device : public NullDevice {
                             cl_set_device_clock_mode_output_amd* pSetClockModeOutput);
 
   virtual bool IsHwEventReady(const amd::Event& event, bool wait = false) const;
+  virtual bool IsHwEventReadyForcedWait(const amd::Event& event) const;
   virtual void getHwEventTime(const amd::Event& event, uint64_t* start, uint64_t* end) const;
   virtual void ReleaseGlobalSignal(void* signal) const;
 
diff --git a/device/rocm/rocvirtual.hpp b/device/rocm/rocvirtual.hpp
index 18cc34ec..af597ef2 100644
--- a/device/rocm/rocvirtual.hpp
+++ b/device/rocm/rocvirtual.hpp
@@ -46,10 +46,10 @@ constexpr static uint64_t kUnlimitedWait = std::numeric_limits<uint64_t>::max();
 
 // Active wait time out incase same sdma engine is used again,
 // then just wait instead of adding dependency wait signal.
-constexpr static uint64_t kSDMAEngineTimeout = 10;
+constexpr static uint64_t kForcedTimeout = 10;
 
 template <bool active_wait_timeout = false>
-inline bool WaitForSignal(hsa_signal_t signal, bool active_wait = false, bool sdma_wait = false) {
+inline bool WaitForSignal(hsa_signal_t signal, bool active_wait = false, bool forced_wait = false) {
   if (hsa_signal_load_relaxed(signal) > 0) {
     uint64_t timeout = kTimeout100us;
     if (active_wait) {
@@ -57,7 +57,7 @@ inline bool WaitForSignal(hsa_signal_t signal, bool active_wait = false, bool sd
     }
     if (active_wait_timeout) {
       // If diff engine, wait to 10 ms. Otherwise no wait
-      timeout = (sdma_wait ? kSDMAEngineTimeout : ROC_ACTIVE_WAIT_TIMEOUT) * K;
+      timeout = (forced_wait ? kForcedTimeout : ROC_ACTIVE_WAIT_TIMEOUT) * K;
       if (timeout == 0) {
         return false;
       }