Skip to content

Commit ce61833

Browse files
authored
Merge pull request #6476 from nilsdeppe/bns_improvements_multiple_messages
Allow MessageQueue to update 2+ tags at once
2 parents f6cb742 + 9f42a47 commit ce61833

File tree

12 files changed

+117
-72
lines changed

12 files changed

+117
-72
lines changed

src/ControlSystem/Systems/Expansion.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "ControlSystem/UpdateControlSystem.hpp"
2121
#include "DataStructures/DataBox/DataBox.hpp"
2222
#include "DataStructures/DataBox/Tag.hpp"
23+
#include "DataStructures/DataVector.hpp"
2324
#include "DataStructures/LinkedMessageId.hpp"
2425
#include "DataStructures/LinkedMessageQueue.hpp"
2526
#include "Domain/Structure/ObjectLabel.hpp"
@@ -117,9 +118,9 @@ struct Expansion : tt::ConformsTo<protocols::ControlSystem> {
117118
DataVector center(horizon_strahlkorper.physical_center());
118119

119120
Parallel::simple_action<::Actions::UpdateMessageQueue<
120-
QueueTags::Center<Horizon>, MeasurementQueue,
121-
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
122-
std::move(center));
121+
MeasurementQueue, UpdateControlSystem<Expansion>,
122+
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
123+
std::move(center));
123124

124125
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
125126
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
@@ -140,13 +141,11 @@ struct Expansion : tt::ConformsTo<protocols::ControlSystem> {
140141
cache);
141142

142143
Parallel::simple_action<::Actions::UpdateMessageQueue<
143-
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
144-
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
145-
DataVector(center_a));
146-
Parallel::simple_action<::Actions::UpdateMessageQueue<
147-
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
148-
UpdateControlSystem<Expansion>>>(control_sys_proxy, measurement_id,
149-
DataVector(center_b));
144+
MeasurementQueue, UpdateControlSystem<Expansion>,
145+
QueueTags::Center<::domain::ObjectLabel::A>,
146+
QueueTags::Center<::domain::ObjectLabel::B>>>(
147+
control_sys_proxy, measurement_id, DataVector(center_a),
148+
DataVector(center_b));
150149

151150
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
152151
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",

src/ControlSystem/Systems/Rotation.hpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ struct Rotation : tt::ConformsTo<protocols::ControlSystem> {
123123
DataVector center(strahlkorper.physical_center());
124124

125125
Parallel::simple_action<::Actions::UpdateMessageQueue<
126-
QueueTags::Center<Horizon>, MeasurementQueue,
127-
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
128-
std::move(center));
126+
MeasurementQueue, UpdateControlSystem<Rotation>,
127+
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
128+
std::move(center));
129129

130130
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
131131
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
@@ -146,13 +146,11 @@ struct Rotation : tt::ConformsTo<protocols::ControlSystem> {
146146
cache);
147147

148148
Parallel::simple_action<::Actions::UpdateMessageQueue<
149-
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
150-
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
151-
DataVector(center_a));
152-
Parallel::simple_action<::Actions::UpdateMessageQueue<
153-
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
154-
UpdateControlSystem<Rotation>>>(control_sys_proxy, measurement_id,
155-
DataVector(center_b));
149+
MeasurementQueue, UpdateControlSystem<Rotation>,
150+
QueueTags::Center<::domain::ObjectLabel::A>,
151+
QueueTags::Center<::domain::ObjectLabel::B>>>(
152+
control_sys_proxy, measurement_id, DataVector(center_a),
153+
DataVector(center_b));
156154

157155
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
158156
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",

src/ControlSystem/Systems/Shape.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ struct Shape : tt::ConformsTo<protocols::ControlSystem> {
120120
ControlComponent<Metavariables, Shape>>(cache);
121121

122122
Parallel::simple_action<::Actions::UpdateMessageQueue<
123-
QueueTags::Horizon<Frame::Distorted>, MeasurementQueue,
124-
UpdateControlSystem<Shape>>>(control_sys_proxy, measurement_id,
125-
strahlkorper);
123+
MeasurementQueue, UpdateControlSystem<Shape>,
124+
QueueTags::Horizon<Frame::Distorted>>>(control_sys_proxy,
125+
measurement_id, strahlkorper);
126126

127127
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
128128
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
@@ -145,9 +145,9 @@ struct Shape : tt::ConformsTo<protocols::ControlSystem> {
145145
ControlComponent<Metavariables, Shape>>(cache);
146146

147147
Parallel::simple_action<::Actions::UpdateMessageQueue<
148-
QueueTags::Horizon<Frame::Distorted>, MeasurementQueue,
149-
UpdateControlSystem<Shape>>>(control_sys_proxy, measurement_id,
150-
strahlkorper);
148+
MeasurementQueue, UpdateControlSystem<Shape>,
149+
QueueTags::Horizon<Frame::Distorted>>>(
150+
control_sys_proxy, measurement_id, strahlkorper);
151151

152152
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
153153
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",

src/ControlSystem/Systems/Size.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ struct Size : tt::ConformsTo<protocols::ControlSystem> {
129129
measurement_id.id);
130130

131131
Parallel::simple_action<::Actions::UpdateMessageQueue<
132-
QueueTags::SizeExcisionQuantities<Frame::Distorted>, MeasurementQueue,
133-
UpdateControlSystem<Size>>>(
132+
MeasurementQueue, UpdateControlSystem<Size>,
133+
QueueTags::SizeExcisionQuantities<Frame::Distorted>>>(
134134
control_sys_proxy, measurement_id,
135135
QueueTags::SizeExcisionQuantities<Frame::Distorted>::type{
136136
std::move(distorted_excision_surface), lapse, shifty_quantity,
@@ -154,8 +154,8 @@ struct Size : tt::ConformsTo<protocols::ControlSystem> {
154154
}
155155

156156
Parallel::simple_action<::Actions::UpdateMessageQueue<
157-
QueueTags::SizeHorizonQuantities<Frame::Distorted>, MeasurementQueue,
158-
UpdateControlSystem<Size>>>(
157+
MeasurementQueue, UpdateControlSystem<Size>,
158+
QueueTags::SizeHorizonQuantities<Frame::Distorted>>>(
159159
control_sys_proxy, measurement_id,
160160
QueueTags::SizeHorizonQuantities<Frame::Distorted>::type{
161161
horizon, time_deriv_horizon});

src/ControlSystem/Systems/Translation.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
113113
ControlComponent<Metavariables, Translation>>(cache);
114114

115115
Parallel::simple_action<::Actions::UpdateMessageQueue<
116-
QueueTags::Center<::domain::ObjectLabel::None>, MeasurementQueue,
117-
UpdateControlSystem<Translation>>>(
116+
MeasurementQueue, UpdateControlSystem<Translation>,
117+
QueueTags::Center<::domain::ObjectLabel::None>>>(
118118
control_sys_proxy, measurement_id,
119119
DataVector{strahlkorper.physical_center()});
120120

@@ -137,9 +137,9 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
137137
DataVector center(strahlkorper.physical_center());
138138

139139
Parallel::simple_action<::Actions::UpdateMessageQueue<
140-
QueueTags::Center<Horizon>, MeasurementQueue,
141-
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
142-
std::move(center));
140+
MeasurementQueue, UpdateControlSystem<Translation>,
141+
QueueTags::Center<Horizon>>>(control_sys_proxy, measurement_id,
142+
std::move(center));
143143

144144
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
145145
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",
@@ -159,13 +159,11 @@ struct Translation : tt::ConformsTo<protocols::ControlSystem> {
159159
ControlComponent<Metavariables, Translation>>(cache);
160160

161161
Parallel::simple_action<::Actions::UpdateMessageQueue<
162-
QueueTags::Center<::domain::ObjectLabel::A>, MeasurementQueue,
163-
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
164-
DataVector(center_a));
165-
Parallel::simple_action<::Actions::UpdateMessageQueue<
166-
QueueTags::Center<::domain::ObjectLabel::B>, MeasurementQueue,
167-
UpdateControlSystem<Translation>>>(control_sys_proxy, measurement_id,
168-
DataVector(center_b));
162+
MeasurementQueue, UpdateControlSystem<Translation>,
163+
QueueTags::Center<::domain::ObjectLabel::A>,
164+
QueueTags::Center<::domain::ObjectLabel::B>>>(
165+
control_sys_proxy, measurement_id, DataVector(center_a),
166+
DataVector(center_b));
169167

170168
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Verbose) {
171169
Parallel::printf("%s, time = %.16f: Received measurement '%s'.\n",

src/DataStructures/LinkedMessageQueue.hpp

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <optional>
77
#include <ostream>
88
#include <pup.h>
9+
#include <type_traits>
910
#include <unordered_map>
1011
#include <utility>
1112

@@ -43,6 +44,17 @@ class LinkedMessageQueue<Id, tmpl::list<QueueTags...>> {
4344
void insert(const LinkedMessageId<Id>& id_and_previous,
4445
typename Tag::type message);
4546

47+
/// Insert multiple data at once into a given queue at a given ID. All queues
48+
/// must receive data with the same collection of \p id_and_previous, but are
49+
/// not required to receive them in the same order.
50+
///
51+
/// \details Tags are inserted in the order they are passed in. Duplicate tags
52+
/// are not allowed.
53+
template <typename Tag1, typename Tag2, typename... Tags>
54+
void insert(const LinkedMessageId<Id>& id_and_previous,
55+
typename Tag1::type message_1, typename Tag2::type message_2,
56+
typename Tags::type... messages);
57+
4658
/// The next ID in the received sequence, if all queues have
4759
/// received messages at that ID.
4860
std::optional<Id> next_ready_id() const;
@@ -84,16 +96,39 @@ void LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::insert(
8496
std::pair{id_and_previous.id, OptionalTuple{}}})
8597
.first;
8698
auto& [id, tuple] = entry->second;
87-
ASSERT(id_and_previous.id == id,
88-
"Received messages with different ids (" << id << " and "
89-
<< id_and_previous.id << ") but the same previous id ("
90-
<< id_and_previous.previous << ").");
99+
ASSERT(id_and_previous.id == id, "Received messages with different ids ("
100+
<< id << " and " << id_and_previous.id
101+
<< ") but the same previous id ("
102+
<< id_and_previous.previous << ").");
91103
ASSERT(not tuples::get<Optional<Tag>>(tuple).has_value(),
92-
"Received duplicate messages at id " << id << " and previous id "
93-
<< id_and_previous.previous << ".");
104+
"Received duplicate messages at id "
105+
<< id << " and previous id " << id_and_previous.previous << ".");
94106
tuples::get<Optional<Tag>>(tuple) = std::move(message);
95107
}
96108

109+
template <typename Id, typename... QueueTags>
110+
template <typename Tag1, typename Tag2, typename... Tags>
111+
void LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::insert(
112+
const LinkedMessageId<Id>& id_and_previous, typename Tag1::type message_1,
113+
typename Tag2::type message_2, typename Tags::type... messages) {
114+
static_assert(
115+
tmpl::size<
116+
tmpl::remove_duplicates<tmpl::list<Tag1, Tag2, Tags...>>>::value ==
117+
sizeof...(Tags) + 2,
118+
"Must have unique tags in LinkedMessageQueue insert.");
119+
insert<Tag1>(id_and_previous, std::move(message_1));
120+
insert<Tag2>(id_and_previous, std::move(message_2));
121+
122+
[[maybe_unused]] const auto insert_pack =
123+
[this, &id_and_previous](const auto& tag_v, auto message) {
124+
(void)this;
125+
using tag = std::decay_t<decltype(tag_v)>;
126+
insert<tag>(id_and_previous, std::move(message));
127+
};
128+
129+
EXPAND_PACK_LEFT_TO_RIGHT(insert_pack(Tags{}, std::move(messages)));
130+
}
131+
97132
template <typename Id, typename... QueueTags>
98133
std::optional<Id>
99134
LinkedMessageQueue<Id, tmpl::list<QueueTags...>>::next_ready_id() const {

src/ParallelAlgorithms/Actions/UpdateMessageQueue.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ namespace Actions {
3434
/// `Queue1` and `Queue2` with ID type `int` is:
3535
///
3636
/// \snippet Test_UpdateMessageQueue.cpp Processor::apply
37-
template <typename QueueTag, typename LinkedMessageQueueTag, typename Processor>
37+
template <typename LinkedMessageQueueTag, typename Processor,
38+
typename... QueueTags>
3839
struct UpdateMessageQueue {
3940
template <typename ParallelComponent, typename DbTags, typename Metavariables,
4041
typename ArrayIndex>
@@ -43,16 +44,18 @@ struct UpdateMessageQueue {
4344
const ArrayIndex& array_index,
4445
const LinkedMessageId<typename LinkedMessageQueueTag::type::IdType>&
4546
id_and_previous,
46-
typename QueueTag::type message) {
47+
typename QueueTags::type... messages) {
4748
if (not domain::functions_of_time_are_ready_simple_action_callback<
4849
domain::Tags::FunctionsOfTime, UpdateMessageQueue>(
4950
cache, array_index, std::add_pointer_t<ParallelComponent>{nullptr},
50-
id_and_previous.id, std::nullopt, id_and_previous, message)) {
51+
id_and_previous.id, std::nullopt, id_and_previous,
52+
std::move(messages)...)) {
5153
return;
5254
}
5355
auto& queue =
5456
db::get_mutable_reference<LinkedMessageQueueTag>(make_not_null(&box));
55-
queue.template insert<QueueTag>(id_and_previous, std::move(message));
57+
queue.template insert<QueueTags...>(id_and_previous,
58+
std::move(messages)...);
5659
while (auto id = queue.next_ready_id()) {
5760
Processor::apply(make_not_null(&box), cache, array_index, *id,
5861
queue.extract());

tests/Unit/ControlSystem/Systems/Test_Size.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ using all_tags = measurement_queue::type::queue_tags_list;
5555

5656
size_t message_queue_call_count = 0;
5757

58-
template <typename QueueTag, typename LinkedMessageQueueTag, typename Processor>
58+
template <typename LinkedMessageQueueTag, typename Processor,
59+
typename... QueueTags>
5960
struct MockUpdateMessageQueue {
6061
template <typename ParallelComponent, typename DbTags, typename Metavariables,
6162
typename ArrayIndex>
@@ -65,21 +66,21 @@ struct MockUpdateMessageQueue {
6566
const ArrayIndex& /*array_index*/,
6667
const LinkedMessageId<typename LinkedMessageQueueTag::type::IdType>&
6768
/*id_and_previous*/,
68-
typename QueueTag::type /*message*/) {
69+
typename QueueTags::type... /*message*/) {
6970
++message_queue_call_count;
7071
}
7172
};
7273

7374
// The Nvidia compiler crashes if we define these lists inside the MockComponent
7475
// struct.
7576
using replace_these_simple_actions_mock_component = tmpl::transform<
76-
all_tags, tmpl::bind<::Actions::UpdateMessageQueue, tmpl::_1,
77-
tmpl::pin<measurement_queue>,
78-
tmpl::pin<control_system::UpdateControlSystem<size>>>>;
77+
all_tags,
78+
tmpl::bind<::Actions::UpdateMessageQueue, tmpl::pin<measurement_queue>,
79+
tmpl::pin<control_system::UpdateControlSystem<size>>, tmpl::_1>>;
7980
using with_these_simple_actions_mock_component = tmpl::transform<
8081
all_tags,
81-
tmpl::bind<MockUpdateMessageQueue, tmpl::_1, tmpl::pin<measurement_queue>,
82-
tmpl::pin<control_system::UpdateControlSystem<size>>>>;
82+
tmpl::bind<MockUpdateMessageQueue, tmpl::pin<measurement_queue>,
83+
tmpl::pin<control_system::UpdateControlSystem<size>>, tmpl::_1>>;
8384

8485
template <typename Metavariables>
8586
struct MockComponent

tests/Unit/ControlSystem/Test_RunCallbacks.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ struct System : tt::ConformsTo<control_system::protocols::ControlSystem> {
6464
auto& control_system_proxy = Parallel::get_parallel_component<
6565
ControlComponent<Metavariables, System>>(cache);
6666
Parallel::simple_action<::Actions::UpdateMessageQueue<
67-
SubmeasurementQueueTag, MeasurementQueue,
68-
control_system::TestHelpers::SomeControlSystemUpdater>>(
69-
control_system_proxy, measurement_id, measurement_result);
67+
MeasurementQueue,
68+
control_system::TestHelpers::SomeControlSystemUpdater,
69+
SubmeasurementQueueTag>>(control_system_proxy, measurement_id,
70+
measurement_result);
7071
}
7172
};
7273
};

tests/Unit/DataStructures/Test_LinkedMessageQueue.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ void test_queue() {
2727
LinkedMessageQueue<int, tmpl::list<MyQueue<Label1>, MyQueue<Label2>>> queue{};
2828
CHECK(not queue.next_ready_id().has_value());
2929

30-
queue.insert<MyQueue<Label1>>({1, {}}, std::make_unique<double>(1.1));
31-
CHECK(not queue.next_ready_id().has_value());
32-
queue.insert<MyQueue<Label2>>({1, {}}, std::make_unique<double>(-1.1));
30+
queue.insert<MyQueue<Label1>, MyQueue<Label2>>(
31+
{1, {}}, std::make_unique<double>(1.1), std::make_unique<double>(-1.1));
3332

3433
CHECK(queue.next_ready_id() == std::optional{1});
3534
{

tests/Unit/Helpers/ControlSystem/Examples.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,9 @@ struct ExampleControlSystem
170170
auto& control_system_proxy = Parallel::get_parallel_component<
171171
ControlComponent<Metavariables, ExampleControlSystem>>(cache);
172172
Parallel::simple_action<::Actions::UpdateMessageQueue<
173-
ExampleSubmeasurementQueueTag, MeasurementQueue,
174-
SomeControlSystemUpdater>>(control_system_proxy, measurement_id,
175-
measurement_result);
173+
MeasurementQueue, SomeControlSystemUpdater,
174+
ExampleSubmeasurementQueueTag>>(control_system_proxy, measurement_id,
175+
measurement_result);
176176
}
177177
};
178178
};

tests/Unit/ParallelAlgorithms/Actions/Test_UpdateMessageQueue.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
118118
const LinkedMessageId<int>& id,
119119
auto data) -> decltype(auto) {
120120
ActionTesting::simple_action<
121-
component, Actions::UpdateMessageQueue<
122-
decltype(queue_v), LinkedMessageQueueTag, Processor>>(
121+
component, Actions::UpdateMessageQueue<LinkedMessageQueueTag, Processor,
122+
decltype(queue_v)>>(
123123
make_not_null(&runner), 0, id, std::move(data));
124124
return db::mutate<ProcessorCalls>(
125125
[](const gsl::not_null<ProcessorCalls::type*> calls) {
@@ -131,9 +131,20 @@ SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
131131
&ActionTesting::get_databox<component>(make_not_null(&runner), 0)));
132132
};
133133

134-
CHECK(processed_by_call(Queue1{}, {0, {}}, 1.23).empty());
135134
{
136-
const auto processed = processed_by_call(Queue2{}, {0, {}}, 2.34);
135+
// Test two tags at once
136+
ActionTesting::simple_action<
137+
component, Actions::UpdateMessageQueue<LinkedMessageQueueTag, Processor,
138+
Queue1, Queue2>>(
139+
make_not_null(&runner), 0, LinkedMessageId<int>{0, {}}, 1.23, 2.34);
140+
const auto processed = db::mutate<ProcessorCalls>(
141+
[](const gsl::not_null<ProcessorCalls::type*> calls) {
142+
auto ret = std::move(*calls);
143+
calls->clear();
144+
return ret;
145+
},
146+
make_not_null(
147+
&ActionTesting::get_databox<component>(make_not_null(&runner), 0)));
137148
CHECK(processed.size() == 1);
138149

139150
CHECK(processed[0].first == 0);

0 commit comments

Comments
 (0)