Skip to content

Commit

Permalink
Add TestScheduler to support testing time-based coroutines without …
Browse files Browse the repository at this point in the history
…waiting for timeouts (nv-morpheus#453)

Adds a manually driven TestScheduler that can fast-forward through delayed coroutines.

Required for nv-morpheus/Morpheus#1548

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#453
  • Loading branch information
cwharris authored Mar 25, 2024
1 parent 9cf1ebc commit bd7955e
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/mrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ add_library(libmrc
src/public/coroutines/io_scheduler.cpp
src/public/coroutines/sync_wait.cpp
src/public/coroutines/task_container.cpp
src/public/coroutines/test_scheduler.cpp
src/public/coroutines/thread_local_context.cpp
src/public/coroutines/thread_pool.cpp
src/public/cuda/device_guard.cpp
Expand Down
105 changes: 105 additions & 0 deletions cpp/mrc/include/mrc/coroutines/test_scheduler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mrc/coroutines/scheduler.hpp"
#include "mrc/coroutines/task.hpp"

#include <chrono>
#include <coroutine>
#include <queue>
#include <utility>
#include <vector>

#pragma once

namespace mrc::coroutines {

class TestScheduler : public Scheduler
{
private:
struct Operation
{
public:
Operation(TestScheduler* self, std::chrono::time_point<std::chrono::steady_clock> time);

static constexpr bool await_ready()
{
return false;
}

void await_suspend(std::coroutine_handle<> handle);

void await_resume() {}

private:
TestScheduler* m_self;
std::chrono::time_point<std::chrono::steady_clock> m_time;
};

using item_t = std::pair<std::coroutine_handle<>, std::chrono::time_point<std::chrono::steady_clock>>;
struct ItemCompare
{
bool operator()(item_t& lhs, item_t& rhs);
};

std::priority_queue<item_t, std::vector<item_t>, ItemCompare> m_queue;
std::chrono::time_point<std::chrono::steady_clock> m_time = std::chrono::steady_clock::now();

public:
/**
* @brief Enqueue's the coroutine handle to be resumed at the current logical time.
*/
void resume(std::coroutine_handle<> handle) noexcept override;

/**
* Suspends the current function and enqueue's it to be resumed at the current logical time.
*/
mrc::coroutines::Task<> yield() override;

/**
* Suspends the current function and enqueue's it to be resumed at the current logica time + the given duration.
*/
mrc::coroutines::Task<> yield_for(std::chrono::milliseconds time) override;

/**
* Suspends the current function and enqueue's it to be resumed at the given logical time.
*/
mrc::coroutines::Task<> yield_until(std::chrono::time_point<std::chrono::steady_clock> time) override;

/**
* Immediately resumes the next-in-queue coroutine handle.
*
* @return true if more coroutines exist in the queue after resuming, false otherwise.
*/
bool resume_next();

/**
* Immediately resumes next-in-queue coroutines up to the current logical time + the given duration, in-order.
*
* @return true if more coroutines exist in the queue after resuming, false otherwise.
*/
bool resume_for(std::chrono::milliseconds time);

/**
* Immediately resumes next-in-queue coroutines up to the given logical time.
*
* @return true if more coroutines exist in the queue after resuming, false otherwise.
*/
bool resume_until(std::chrono::time_point<std::chrono::steady_clock> time);
};

} // namespace mrc::coroutines
102 changes: 102 additions & 0 deletions cpp/mrc/src/public/coroutines/test_scheduler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mrc/coroutines/test_scheduler.hpp"

#include <compare>

namespace mrc::coroutines {

TestScheduler::Operation::Operation(TestScheduler* self, std::chrono::time_point<std::chrono::steady_clock> time) :
m_self(self),
m_time(time)
{}

bool TestScheduler::ItemCompare::operator()(item_t& lhs, item_t& rhs)
{
return lhs.second > rhs.second;
}

void TestScheduler::Operation::await_suspend(std::coroutine_handle<> handle)
{
m_self->m_queue.emplace(std::move(handle), m_time);
}

void TestScheduler::resume(std::coroutine_handle<> handle) noexcept
{
m_queue.emplace(std::move(handle), std::chrono::steady_clock::now());
}

mrc::coroutines::Task<> TestScheduler::yield()
{
co_return co_await TestScheduler::Operation{this, m_time};
}

mrc::coroutines::Task<> TestScheduler::yield_for(std::chrono::milliseconds time)
{
co_return co_await TestScheduler::Operation{this, m_time + time};
}

mrc::coroutines::Task<> TestScheduler::yield_until(std::chrono::time_point<std::chrono::steady_clock> time)
{
co_return co_await TestScheduler::Operation{this, time};
}

bool TestScheduler::resume_next()
{
if (m_queue.empty())
{
return false;
}

auto handle = m_queue.top();

m_queue.pop();

m_time = handle.second;

handle.first.resume();

return true;
}

bool TestScheduler::resume_for(std::chrono::milliseconds time)
{
return resume_until(m_time + time);
}

bool TestScheduler::resume_until(std::chrono::time_point<std::chrono::steady_clock> time)
{
m_time = time;

while (not m_queue.empty())
{
if (m_queue.top().second <= m_time)
{
m_queue.top().first.resume();
m_queue.pop();
}
else
{
return true;
}
}

return false;
}

} // namespace mrc::coroutines

0 comments on commit bd7955e

Please sign in to comment.