11#pragma once
22
3+ #include < faabric/mpi/MpiMessage.h>
34#include < faabric/mpi/MpiMessageBuffer.h>
45#include < faabric/mpi/mpi.h>
5- #include < faabric/mpi/mpi.pb.h>
66#include < faabric/proto/faabric.pb.h>
77#include < faabric/scheduler/InMemoryMessageQueue.h>
88#include < faabric/transport/PointToPointBroker.h>
@@ -26,10 +26,9 @@ namespace faabric::mpi {
2626// -----------------------------------
2727// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
2828// as the broker already has mocking capabilities
29- std::vector<std::shared_ptr<MPIMessage> > getMpiMockedMessages (int sendRank);
29+ std::vector<MpiMessage > getMpiMockedMessages (int sendRank);
3030
31- typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
32- InMemoryMpiQueue;
31+ typedef faabric::util::FixedCapacityQueue<MpiMessage> InMemoryMpiQueue;
3332
3433class MpiWorld
3534{
@@ -73,36 +72,36 @@ class MpiWorld
7372 const uint8_t * buffer,
7473 faabric_datatype_t * dataType,
7574 int count,
76- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
75+ MpiMessageType messageType = MpiMessageType ::NORMAL);
7776
7877 int isend (int sendRank,
7978 int recvRank,
8079 const uint8_t * buffer,
8180 faabric_datatype_t * dataType,
8281 int count,
83- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
82+ MpiMessageType messageType = MpiMessageType ::NORMAL);
8483
8584 void broadcast (int rootRank,
8685 int thisRank,
8786 uint8_t * buffer,
8887 faabric_datatype_t * dataType,
8988 int count,
90- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
89+ MpiMessageType messageType = MpiMessageType ::NORMAL);
9190
9291 void recv (int sendRank,
9392 int recvRank,
9493 uint8_t * buffer,
9594 faabric_datatype_t * dataType,
9695 int count,
9796 MPI_Status* status,
98- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
97+ MpiMessageType messageType = MpiMessageType ::NORMAL);
9998
10099 int irecv (int sendRank,
101100 int recvRank,
102101 uint8_t * buffer,
103102 faabric_datatype_t * dataType,
104103 int count,
105- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
104+ MpiMessageType messageType = MpiMessageType ::NORMAL);
106105
107106 void awaitAsyncRequest (int requestId);
108107
@@ -240,29 +239,36 @@ class MpiWorld
240239 void sendRemoteMpiMessage (std::string dstHost,
241240 int sendRank,
242241 int recvRank,
243- const std::shared_ptr<MPIMessage> & msg);
242+ const MpiMessage & msg);
244243
245- std::shared_ptr<MPIMessage> recvRemoteMpiMessage (int sendRank,
246- int recvRank);
244+ MpiMessage recvRemoteMpiMessage (int sendRank, int recvRank);
247245
248246 // Support for asyncrhonous communications
249247 std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer (int sendRank,
250248 int recvRank);
251249
252- std::shared_ptr<MPIMessage> recvBatchReturnLast (int sendRank,
253- int recvRank,
254- int batchSize = 0 );
250+ MpiMessage recvBatchReturnLast (int sendRank,
251+ int recvRank,
252+ int batchSize = 0 );
255253
256254 /* Helper methods */
257255
258256 void checkRanksRange (int sendRank, int recvRank);
259257
260258 // Abstraction of the bulk of the recv work, shared among various functions
261- void doRecv (std::shared_ptr<MPIMessage> & m,
259+ void doRecv (const MpiMessage & m,
262260 uint8_t * buffer,
263261 faabric_datatype_t * dataType,
264262 int count,
265263 MPI_Status* status,
266- MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
264+ MpiMessageType messageType = MpiMessageType::NORMAL);
265+
266+ // Abstraction of the bulk of the recv work, shared among various functions
267+ void doRecv (std::unique_ptr<MpiMessage> m,
268+ uint8_t * buffer,
269+ faabric_datatype_t * dataType,
270+ int count,
271+ MPI_Status* status,
272+ MpiMessageType messageType = MpiMessageType::NORMAL);
267273};
268274}
0 commit comments