Skip to content

Commit

Permalink
Allow MessageQueue to update 2+ tags at once
Browse files Browse the repository at this point in the history
  • Loading branch information
knelli2 authored and nilsdeppe committed Feb 6, 2025
1 parent bb4d647 commit 9f42a47
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 72 deletions.
19 changes: 9 additions & 10 deletions src/ControlSystem/Systems/Expansion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ControlSystem/UpdateControlSystem.hpp"
#include "DataStructures/DataBox/DataBox.hpp"
#include "DataStructures/DataBox/Tag.hpp"
#include "DataStructures/DataVector.hpp"
#include "DataStructures/LinkedMessageId.hpp"
#include "DataStructures/LinkedMessageQueue.hpp"
#include "Domain/Structure/ObjectLabel.hpp"
Expand Down Expand Up @@ -117,9 +118,9 @@ struct Expansion : tt::ConformsTo<protocols::ControlSystem> {
DataVector center(horizon_strahlkorper.physical_center());

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<Horizon>, MeasurementQueue,
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
std::move(center));
MeasurementQueue, UpdateControlSystem<Expansion>,
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
std::move(center));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand All @@ -140,13 +141,11 @@ struct Expansion : tt::ConformsTo<protocols::ControlSystem> {
cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
DataVector(center_a));
Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
DataVector(center_b));
MeasurementQueue, UpdateControlSystem<Expansion>,
QueueTags::Center<::domain::ObjectLabel::A>,
QueueTags::Center<::domain::ObjectLabel::B>>>(
control_sys_proxy, measurement_id, DataVector(center_a),
DataVector(center_b));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand Down
18 changes: 8 additions & 10 deletions src/ControlSystem/Systems/Rotation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ struct Rotation : tt::ConformsTo<protocols::ControlSystem> {
DataVector center(strahlkorper.physical_center());

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<Horizon>, MeasurementQueue,
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
std::move(center));
MeasurementQueue, UpdateControlSystem<Rotation>,
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
std::move(center));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand All @@ -146,13 +146,11 @@ struct Rotation : tt::ConformsTo<protocols::ControlSystem> {
cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
DataVector(center_a));
Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
DataVector(center_b));
MeasurementQueue, UpdateControlSystem<Rotation>,
QueueTags::Center<::domain::ObjectLabel::A>,
QueueTags::Center<::domain::ObjectLabel::B>>>(
control_sys_proxy, measurement_id, DataVector(center_a),
DataVector(center_b));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand Down
12 changes: 6 additions & 6 deletions src/ControlSystem/Systems/Shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ struct Shape : tt::ConformsTo<protocols::ControlSystem> {
ControlComponent<Metavariables, Shape>>(cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Horizon<Frame::Distorted>, MeasurementQueue,
UpdateControlSystem<Shape>>>(control_sys_proxy, measurement_id,
strahlkorper);
MeasurementQueue, UpdateControlSystem<Shape>,
QueueTags::Horizon<Frame::Distorted>>>(control_sys_proxy,
measurement_id, strahlkorper);

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand All @@ -145,9 +145,9 @@ struct Shape : tt::ConformsTo<protocols::ControlSystem> {
ControlComponent<Metavariables, Shape>>(cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Horizon<Frame::Distorted>, MeasurementQueue,
UpdateControlSystem<Shape>>>(control_sys_proxy, measurement_id,
strahlkorper);
MeasurementQueue, UpdateControlSystem<Shape>,
QueueTags::Horizon<Frame::Distorted>>>(
control_sys_proxy, measurement_id, strahlkorper);

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand Down
8 changes: 4 additions & 4 deletions src/ControlSystem/Systems/Size.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ struct Size : tt::ConformsTo<protocols::ControlSystem> {
measurement_id.id);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::SizeExcisionQuantities<Frame::Distorted>, MeasurementQueue,
UpdateControlSystem<Size>>>(
MeasurementQueue, UpdateControlSystem<Size>,
QueueTags::SizeExcisionQuantities<Frame::Distorted>>>(
control_sys_proxy, measurement_id,
QueueTags::SizeExcisionQuantities<Frame::Distorted>::type{
std::move(distorted_excision_surface), lapse, shifty_quantity,
Expand All @@ -154,8 +154,8 @@ struct Size : tt::ConformsTo<protocols::ControlSystem> {
}

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::SizeHorizonQuantities<Frame::Distorted>, MeasurementQueue,
UpdateControlSystem<Size>>>(
MeasurementQueue, UpdateControlSystem<Size>,
QueueTags::SizeHorizonQuantities<Frame::Distorted>>>(
control_sys_proxy, measurement_id,
QueueTags::SizeHorizonQuantities<Frame::Distorted>::type{
horizon, time_deriv_horizon});
Expand Down
22 changes: 10 additions & 12 deletions src/ControlSystem/Systems/Translation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
ControlComponent<Metavariables, Translation>>(cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::None>, MeasurementQueue,
UpdateControlSystem<Translation>>>(
MeasurementQueue, UpdateControlSystem<Translation>,
QueueTags::Center<::domain::ObjectLabel::None>>>(
control_sys_proxy, measurement_id,
DataVector{strahlkorper.physical_center()});

Expand All @@ -137,9 +137,9 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
DataVector center(strahlkorper.physical_center());

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<Horizon>, MeasurementQueue,
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
std::move(center));
MeasurementQueue, UpdateControlSystem<Translation>,
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
std::move(center));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand All @@ -159,13 +159,11 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
ControlComponent<Metavariables, Translation>>(cache);

Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
DataVector(center_a));
Parallel::simple_action<::Actions::UpdateMessageQueue<
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
DataVector(center_b));
MeasurementQueue, UpdateControlSystem<Translation>,
QueueTags::Center<::domain::ObjectLabel::A>,
QueueTags::Center<::domain::ObjectLabel::B>>>(
control_sys_proxy, measurement_id, DataVector(center_a),
DataVector(center_b));

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
Expand Down
47 changes: 41 additions & 6 deletions src/DataStructures/LinkedMessageQueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <optional>
#include <ostream>
#include <pup.h>
#include <type_traits>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -43,6 +44,17 @@ class LinkedMessageQueue<Id, tmpl::list<QueueTags...>> {
void insert(const LinkedMessageId<Id>& id_and_previous,
typename Tag::type message);

/// Insert multiple data at once into a given queue at a given ID. All queues
/// must receive data with the same collection of \p id_and_previous, but are
/// not required to receive them in the same order.
///
/// \details Tags are inserted in the order they are passed in. Duplicate tags
/// are not allowed.
template <typename Tag1, typename Tag2, typename... Tags>
void insert(const LinkedMessageId<Id>& id_and_previous,
typename Tag1::type message_1, typename Tag2::type message_2,
typename Tags::type... messages);

/// The next ID in the received sequence, if all queues have
/// received messages at that ID.
std::optional<Id> next_ready_id() const;
Expand Down Expand Up @@ -84,16 +96,39 @@ void LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::insert(
std::pair{id_and_previous.id, OptionalTuple{}}})
.first;
auto& [id, tuple] = entry->second;
ASSERT(id_and_previous.id == id,
"Received messages with different ids (" << id << " and "
<< id_and_previous.id << ") but the same previous id ("
<< id_and_previous.previous << ").");
ASSERT(id_and_previous.id == id, "Received messages with different ids ("
<< id << " and " << id_and_previous.id
<< ") but the same previous id ("
<< id_and_previous.previous << ").");
ASSERT(not tuples::get<Optional<Tag>>(tuple).has_value(),
"Received duplicate messages at id " << id << " and previous id "
<< id_and_previous.previous << ".");
"Received duplicate messages at id "
<< id << " and previous id " << id_and_previous.previous << ".");
tuples::get<Optional<Tag>>(tuple) = std::move(message);
}

template <typename Id, typename... QueueTags>
template <typename Tag1, typename Tag2, typename... Tags>
void LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::insert(
const LinkedMessageId<Id>& id_and_previous, typename Tag1::type message_1,
typename Tag2::type message_2, typename Tags::type... messages) {
static_assert(
tmpl::size<
tmpl::remove_duplicates<tmpl::list<Tag1, Tag2, Tags...>>>::value ==
sizeof...(Tags) + 2,
"Must have unique tags in LinkedMessageQueue insert.");
insert<Tag1>(id_and_previous, std::move(message_1));
insert<Tag2>(id_and_previous, std::move(message_2));

[[maybe_unused]] const auto insert_pack =
[this, &id_and_previous](const auto& tag_v, auto message) {
(void)this;
using tag = std::decay_t<decltype(tag_v)>;
insert<tag>(id_and_previous, std::move(message));
};

EXPAND_PACK_LEFT_TO_RIGHT(insert_pack(Tags{}, std::move(messages)));
}

template <typename Id, typename... QueueTags>
std::optional<Id>
LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::next_ready_id() const {
Expand Down
11 changes: 7 additions & 4 deletions src/ParallelAlgorithms/Actions/UpdateMessageQueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace Actions {
/// `Queue1` and `Queue2` with ID type `int` is:
///
/// \snippet Test_UpdateMessageQueue.cpp Processor::apply
template <typename QueueTag, typename LinkedMessageQueueTag, typename Processor>
template <typename LinkedMessageQueueTag, typename Processor,
typename... QueueTags>
struct UpdateMessageQueue {
template <typename ParallelComponent, typename DbTags, typename Metavariables,
typename ArrayIndex>
Expand All @@ -43,16 +44,18 @@ struct UpdateMessageQueue {
const ArrayIndex& array_index,
const LinkedMessageId<typename LinkedMessageQueueTag::type::IdType>&
id_and_previous,
typename QueueTag::type message) {
typename QueueTags::type... messages) {
if (not domain::functions_of_time_are_ready_simple_action_callback<
domain::Tags::FunctionsOfTime, UpdateMessageQueue>(
cache, array_index, std::add_pointer_t<ParallelComponent>{nullptr},
id_and_previous.id, std::nullopt, id_and_previous, message)) {
id_and_previous.id, std::nullopt, id_and_previous,
std::move(messages)...)) {
return;
}
auto& queue =
db::get_mutable_reference<LinkedMessageQueueTag>(make_not_null(&box));
queue.template insert<QueueTag>(id_and_previous, std::move(message));
queue.template insert<QueueTags...>(id_and_previous,
std::move(messages)...);
while (auto id = queue.next_ready_id()) {
Processor::apply(make_not_null(&box), cache, array_index, *id,
queue.extract());
Expand Down
15 changes: 8 additions & 7 deletions tests/Unit/ControlSystem/Systems/Test_Size.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ using all_tags = measurement_queue::type::queue_tags_list;

size_t message_queue_call_count = 0;

template <typename QueueTag, typename LinkedMessageQueueTag, typename Processor>
template <typename LinkedMessageQueueTag, typename Processor,
typename... QueueTags>
struct MockUpdateMessageQueue {
template <typename ParallelComponent, typename DbTags, typename Metavariables,
typename ArrayIndex>
Expand All @@ -65,21 +66,21 @@ struct MockUpdateMessageQueue {
const ArrayIndex& /*array_index*/,
const LinkedMessageId<typename LinkedMessageQueueTag::type::IdType>&
/*id_and_previous*/,
typename QueueTag::type /*message*/) {
typename QueueTags::type... /*message*/) {
++message_queue_call_count;
}
};

// The Nvidia compiler crashes if we define these lists inside the MockComponent
// struct.
using replace_these_simple_actions_mock_component = tmpl::transform<
all_tags, tmpl::bind<::Actions::UpdateMessageQueue, tmpl::_1,
tmpl::pin<measurement_queue>,
tmpl::pin<control_system::UpdateControlSystem<size>>>>;
all_tags,
tmpl::bind<::Actions::UpdateMessageQueue, tmpl::pin<measurement_queue>,
tmpl::pin<control_system::UpdateControlSystem<size>>, tmpl::_1>>;
using with_these_simple_actions_mock_component = tmpl::transform<
all_tags,
tmpl::bind<MockUpdateMessageQueue, tmpl::_1, tmpl::pin<measurement_queue>,
tmpl::pin<control_system::UpdateControlSystem<size>>>>;
tmpl::bind<MockUpdateMessageQueue, tmpl::pin<measurement_queue>,
tmpl::pin<control_system::UpdateControlSystem<size>>, tmpl::_1>>;

template <typename Metavariables>
struct MockComponent
Expand Down
7 changes: 4 additions & 3 deletions tests/Unit/ControlSystem/Test_RunCallbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ struct System : tt::ConformsTo<control_system::protocols::ControlSystem> {
auto& control_system_proxy = Parallel::get_parallel_component<
ControlComponent<Metavariables, System>>(cache);
Parallel::simple_action<::Actions::UpdateMessageQueue<
SubmeasurementQueueTag, MeasurementQueue,
control_system::TestHelpers::SomeControlSystemUpdater>>(
control_system_proxy, measurement_id, measurement_result);
MeasurementQueue,
control_system::TestHelpers::SomeControlSystemUpdater,
SubmeasurementQueueTag>>(control_system_proxy, measurement_id,
measurement_result);
}
};
};
Expand Down
5 changes: 2 additions & 3 deletions tests/Unit/DataStructures/Test_LinkedMessageQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ void test_queue() {
LinkedMessageQueue<int, tmpl::list<MyQueue<Label1>, MyQueue<Label2>>> queue{};
CHECK(not queue.next_ready_id().has_value());

queue.insert<MyQueue<Label1>>({1, {}}, std::make_unique<double>(1.1));
CHECK(not queue.next_ready_id().has_value());
queue.insert<MyQueue<Label2>>({1, {}}, std::make_unique<double>(-1.1));
queue.insert<MyQueue<Label1>, MyQueue<Label2>>(
{1, {}}, std::make_unique<double>(1.1), std::make_unique<double>(-1.1));

CHECK(queue.next_ready_id() == std::optional{1});
{
Expand Down
6 changes: 3 additions & 3 deletions tests/Unit/Helpers/ControlSystem/Examples.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ struct ExampleControlSystem
auto& control_system_proxy = Parallel::get_parallel_component<
ControlComponent<Metavariables, ExampleControlSystem>>(cache);
Parallel::simple_action<::Actions::UpdateMessageQueue<
ExampleSubmeasurementQueueTag, MeasurementQueue,
SomeControlSystemUpdater>>(control_system_proxy, measurement_id,
measurement_result);
MeasurementQueue, SomeControlSystemUpdater,
ExampleSubmeasurementQueueTag>>(control_system_proxy, measurement_id,
measurement_result);
}
};
};
Expand Down
19 changes: 15 additions & 4 deletions tests/Unit/ParallelAlgorithms/Actions/Test_UpdateMessageQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
const LinkedMessageId<int>& id,
auto data) -> decltype(auto) {
ActionTesting::simple_action<
component, Actions::UpdateMessageQueue<
decltype(queue_v), LinkedMessageQueueTag, Processor>>(
component, Actions::UpdateMessageQueue<LinkedMessageQueueTag, Processor,
decltype(queue_v)>>(
make_not_null(&runner), 0, id, std::move(data));
return db::mutate<ProcessorCalls>(
[](const gsl::not_null<ProcessorCalls::type*> calls) {
Expand All @@ -131,9 +131,20 @@ SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
&ActionTesting::get_databox<component>(make_not_null(&runner), 0)));
};

CHECK(processed_by_call(Queue1{}, {0, {}}, 1.23).empty());
{
const auto processed = processed_by_call(Queue2{}, {0, {}}, 2.34);
// Test two tags at once
ActionTesting::simple_action<
component, Actions::UpdateMessageQueue<LinkedMessageQueueTag, Processor,
Queue1, Queue2>>(
make_not_null(&runner), 0, LinkedMessageId<int>{0, {}}, 1.23, 2.34);
const auto processed = db::mutate<ProcessorCalls>(
[](const gsl::not_null<ProcessorCalls::type*> calls) {
auto ret = std::move(*calls);
calls->clear();
return ret;
},
make_not_null(
&ActionTesting::get_databox<component>(make_not_null(&runner), 0)));
CHECK(processed.size() == 1);

CHECK(processed[0].first == 0);
Expand Down

0 comments on commit 9f42a47

Please sign in to comment.