Skip to content
This repository has been archived by the owner on Jan 26, 2024. It is now read-only.

Commit

Permalink
SWDEV-380024 - Fix performance drop in TF-RCCL models
Browse files Browse the repository at this point in the history
Change-Id: Idc845bb0dab858b94b9d2720cae8308cac2e7328
  • Loading branch information
Anusha GodavarthySurya authored and Anusha Godavarthy Surya committed Feb 15, 2023
1 parent 6d8fc0b commit 1cf8f19
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
11 changes: 8 additions & 3 deletions device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1796,9 +1796,14 @@ class Device : public RuntimeObject {

// Returns the status of HW event, associated with amd::Event
virtual bool IsHwEventReady(
const amd::Event& event, //!< AMD event for HW status validation
bool wait = false //!< If true then forces the event completion
) const {
const amd::Event& event, //!< AMD event for HW status validation
bool wait = false) const { //!< If true then forces the event completion
return false;
};

// Returns the status of HW event, associated with amd::Event
virtual bool IsHwEventReadyForcedWait(
const amd::Event& event) const { //!< AMD event for HW status validation
return false;
};

Expand Down
16 changes: 14 additions & 2 deletions device/rocm/rocdevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2729,10 +2729,22 @@ bool Device::SetClockMode(const cl_set_device_clock_mode_input_amd setClockModeI
return result;
}

// ================================================================================================
bool Device::IsHwEventReadyForcedWait(const amd::Event& event) const {
void* hw_event =
(event.NotifyEvent() != nullptr) ? event.NotifyEvent()->HwEvent() : event.HwEvent();
if (hw_event == nullptr) {
ClPrint(amd::LOG_INFO, amd::LOG_SIG, "No HW event");
return false;
}
static constexpr bool Timeout = true;
return WaitForSignal<Timeout>(reinterpret_cast<ProfilingSignal*>(hw_event)->signal_, false, true);
}

// ================================================================================================
bool Device::IsHwEventReady(const amd::Event& event, bool wait) const {
void* hw_event = (event.NotifyEvent() != nullptr) ?
event.NotifyEvent()->HwEvent() : event.HwEvent();
void* hw_event =
(event.NotifyEvent() != nullptr) ? event.NotifyEvent()->HwEvent() : event.HwEvent();
if (hw_event == nullptr) {
ClPrint(amd::LOG_INFO, amd::LOG_SIG, "No HW event");
return false;
Expand Down
2 changes: 2 additions & 0 deletions device/rocm/rocdevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class NullDevice : public amd::Device {
cl_set_device_clock_mode_output_amd* pSetClockModeOutput) { return true; }

virtual bool IsHwEventReady(const amd::Event& event, bool wait = false) const { return false; }
virtual bool IsHwEventReadyForcedWait(const amd::Event& event) const { return false; }
virtual void getHwEventTime(const amd::Event& event, uint64_t* start, uint64_t* end) const {};
virtual void ReleaseGlobalSignal(void* signal) const {}

Expand Down Expand Up @@ -443,6 +444,7 @@ class Device : public NullDevice {
cl_set_device_clock_mode_output_amd* pSetClockModeOutput);

virtual bool IsHwEventReady(const amd::Event& event, bool wait = false) const;
virtual bool IsHwEventReadyForcedWait(const amd::Event& event) const;
virtual void getHwEventTime(const amd::Event& event, uint64_t* start, uint64_t* end) const;
virtual void ReleaseGlobalSignal(void* signal) const;

Expand Down
6 changes: 3 additions & 3 deletions device/rocm/rocvirtual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ constexpr static uint64_t kUnlimitedWait = std::numeric_limits<uint64_t>::max();

// Active wait time out incase same sdma engine is used again,
// then just wait instead of adding dependency wait signal.
constexpr static uint64_t kSDMAEngineTimeout = 10;
constexpr static uint64_t kForcedTimeout = 10;

template <bool active_wait_timeout = false>
inline bool WaitForSignal(hsa_signal_t signal, bool active_wait = false, bool sdma_wait = false) {
inline bool WaitForSignal(hsa_signal_t signal, bool active_wait = false, bool forced_wait = false) {
if (hsa_signal_load_relaxed(signal) > 0) {
uint64_t timeout = kTimeout100us;
if (active_wait) {
timeout = kUnlimitedWait;
}
if (active_wait_timeout) {
// If diff engine, wait to 10 ms. Otherwise no wait
timeout = (sdma_wait ? kSDMAEngineTimeout : ROC_ACTIVE_WAIT_TIMEOUT) * K;
timeout = (forced_wait ? kForcedTimeout : ROC_ACTIVE_WAIT_TIMEOUT) * K;
if (timeout == 0) {
return false;
}
Expand Down

0 comments on commit 1cf8f19

Please sign in to comment.