Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 54 additions & 57 deletions src/runtime_src/core/common/api/xrt_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,14 @@ class kernel_command : public xrt_core::command
public:
using execbuf_type = xrt_core::bo_cache::cmd_bo<ert_start_kernel_cmd>;
using callback_function_type = std::function<void(ert_cmd_state)>;
using callback_list = std::vector<callback_function_type>;

private:
struct callback_properties {
callback_function_type fn;
mutable bool called;
};
using callback_list = std::vector<callback_properties>;

// Return state of underlying exec buffer packet This is an
// asynchronous call, the command object may not be in the same
// state as reflected by the return value.
Expand Down Expand Up @@ -758,7 +763,9 @@ class kernel_command : public xrt_core::command
bool
is_done() const
{
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_done && !m_managed) {
m_done = get_state() >= ERT_CMD_STATE_COMPLETED;
}
return m_done;
}

Expand All @@ -772,9 +779,7 @@ class kernel_command : public xrt_core::command
// is a no-op on platforms where command state is live.
m_hwqueue.poll(this);

auto state = get_state_raw();
notify(state); // update command state accordingly
return state;
return get_state_raw();
}

// Return kernel return code from command object for PS kernels
Expand Down Expand Up @@ -814,52 +819,25 @@ class kernel_command : public xrt_core::command
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_managed && !m_done)
throw xrt_core::error(ENOTSUP, "Cannot add callback to running unmanaged command");
if (!m_callbacks)
m_callbacks = std::make_unique<callback_list>();
m_callbacks->emplace_back(std::move(fcn));
auto pkt = get_ert_packet();
state = static_cast<ert_cmd_state>(pkt->state);
complete = m_done && state >= ERT_CMD_STATE_COMPLETED;
m_callbacks.emplace_back(callback_properties({std::move(fcn), complete}));
}

// lock must not be helt while calling callback function
// lock must not be held while calling callback function
if (complete)
m_callbacks.get()->back()(state);
fcn(state);
}

// Remove last added callback
void
pop_callback()
{
if (m_callbacks && m_callbacks->size())
m_callbacks->pop_back();
}

// Run registered callbacks.
void
run_callbacks(ert_cmd_state state) const
{
{
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_callbacks)
return;
}

// cannot lock mutex while calling the callbacks
// so copy address of callbacks while holding the lock
// then execute callbacks without lock
std::vector<callback_function_type*> copy;
copy.reserve(m_callbacks->size());

{
std::lock_guard<std::mutex> lk(m_mutex);
std::transform(m_callbacks->begin(),m_callbacks->end()
,std::back_inserter(copy)
,[](callback_function_type& cb) { return &cb; });
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_callbacks.empty()) {
m_callbacks.pop_back();
}

for (auto cb : copy)
(*cb)(state);
}

// Submit the command for execution.
Expand All @@ -870,7 +848,9 @@ class kernel_command : public xrt_core::command
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_done)
throw std::runtime_error("bad command state, can't launch");
m_managed = (m_callbacks && !m_callbacks->empty());
m_managed = !m_callbacks.empty();
for (auto& cb : m_callbacks)
cb.called = false;
m_done = false;
}

Expand All @@ -883,7 +863,6 @@ class kernel_command : public xrt_core::command
catch (...) {
// Start failed, m_done remains true
// command can be retried if needed
std::lock_guard<std::mutex> lk(m_mutex);
m_done = true;
throw;
}
Expand All @@ -902,7 +881,12 @@ class kernel_command : public xrt_core::command
m_hwqueue.wait(this);
}

return get_state_raw(); // state wont change after wait
auto state = get_state_raw();
if (!m_done) {
m_done = state >= ERT_CMD_STATE_COMPLETED;
}

return state; // state wont change after wait
}

std::pair<ert_cmd_state, std::cv_status>
Expand All @@ -919,7 +903,11 @@ class kernel_command : public xrt_core::command
return {get_state_raw(), std::cv_status::timeout};
}

return {get_state_raw(), std::cv_status::no_timeout};
auto state = get_state_raw();
if (!m_done) {
m_done = state >= ERT_CMD_STATE_COMPLETED;
}
return {state, std::cv_status::no_timeout};
}

////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -954,27 +942,36 @@ class kernel_command : public xrt_core::command
void
notify(ert_cmd_state s) const override
{
bool complete = false;
bool callbacks = false;
if (s >= ERT_CMD_STATE_COMPLETED) {

if (s < ERT_CMD_STATE_COMPLETED)
throw std::runtime_error("bad command state, notify() is only expected to be called after command completes");

std::vector<callback_function_type> copy;
{
std::lock_guard<std::mutex> lk(m_mutex);

// Handle potential race if multiple threads end up here. This
// condition is by design because there are multiple paths into
// this function and first conditional check should not be locked
if (m_done)
return;

XRT_DEBUGF("kernel_command::notify() m_uid(%d) m_state(%d)\n", m_uid, s);
complete = m_done = true;
callbacks = (m_callbacks && !m_callbacks->empty());
m_done = true;

// cannot lock mutex while calling the callbacks
// so copy address of callbacks while holding the lock
// then execute callbacks without lock
if (!m_callbacks.empty()) {
copy.reserve(m_callbacks.size());
for (auto& cb : m_callbacks) {
if (!cb.called) {
cb.called = true;
copy.emplace_back(cb.fn);
}
}
}
}

if (complete) {
m_exec_done.notify_all();
if (callbacks)
run_callbacks(s);
}
m_exec_done.notify_all();
for (const auto& cb : copy)
cb(s);
}

void
Expand All @@ -998,7 +995,7 @@ class kernel_command : public xrt_core::command
mutable std::mutex m_mutex;
mutable std::condition_variable m_exec_done;

std::unique_ptr<callback_list> m_callbacks;
callback_list m_callbacks; // don't see any reason this was a pointer?
};

// class argument - get argument value from va_arg
Expand Down
Loading