Skip to content

Commit 5b7667e

Browse files
committed
mpi: #379
1 parent d9810b0 commit 5b7667e

File tree

13 files changed

+364
-207
lines changed

13 files changed

+364
-207
lines changed

include/faabric/mpi/MpiMessage.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <vector>
5+
6+
namespace faabric::mpi {

// Discriminator stored in MpiMessage::messageType. The underlying type is
// pinned to int32_t, and serializeMpiMsg copies the enclosing struct as raw
// bytes, so the numeric values below are part of the serialized form: append
// new values rather than renumbering existing ones.
enum MpiMessageType : int32_t
{
    NORMAL = 0,
    BARRIER_JOIN = 1,
    BARRIER_DONE = 2,
    SCATTER = 3,
    GATHER = 4,
    ALLGATHER = 5,
    REDUCE = 6,
    SCAN = 7,
    ALLREDUCE = 8,
    ALLTOALL = 9,
    SENDRECV = 10,
    BROADCAST = 11,
};

// Fixed-size message header plus a pointer to an out-of-line payload.
// The struct is copied with memcpy by the (de)serialization helpers, so it
// must stay trivially copyable.
struct MpiMessage
{
    // Message identifier
    int32_t id;
    // Identifier of the MPI world the message belongs to
    int32_t worldId;
    // Rank sending the message
    int32_t sendRank;
    // Rank receiving the message
    int32_t recvRank;
    // Size of one element; typeSize * count is the payload length in bytes
    int32_t typeSize;
    // Number of elements in the payload
    int32_t count;
    // What kind of message this is (see MpiMessageType)
    MpiMessageType messageType;
    // Payload bytes, or nullptr when there is no payload. The pointer is only
    // meaningful inside the process that set it.
    void* buffer;
};

// Payload length in bytes: typeSize * count.
//
// Each operand is widened to size_t *before* multiplying. Multiplying the two
// int32_t fields directly is done in int arithmetic and overflows (undefined
// behaviour) for payloads of 2 GiB or more. For every input where the int
// product did not overflow, the result is identical to the previous one.
inline size_t payloadSize(const MpiMessage& msg)
{
    return static_cast<size_t>(msg.typeSize) * static_cast<size_t>(msg.count);
}

// Number of bytes needed to hold the serialized message: the fixed-size
// header followed by the payload.
inline size_t msgSize(const MpiMessage& msg)
{
    return sizeof(MpiMessage) + payloadSize(msg);
}

// Writes the header of msg, followed by its payload (if any), into buffer.
void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);

// Rebuilds *msg from bytes previously produced by serializeMpiMsg.
void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg);
}

include/faabric/mpi/MpiMessageBuffer.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
#include <faabric/mpi/MpiMessage.h>
12
#include <faabric/mpi/mpi.h>
2-
#include <faabric/mpi/mpi.pb.h>
33

44
#include <iterator>
55
#include <list>
6+
#include <memory>
67

78
namespace faabric::mpi {
89
/* The MPI message buffer (MMB) keeps track of the asynchronous
@@ -25,17 +26,20 @@ class MpiMessageBuffer
2526
{
2627
public:
2728
int requestId = -1;
28-
std::shared_ptr<MPIMessage> msg = nullptr;
29+
std::shared_ptr<MpiMessage> msg = nullptr;
2930
int sendRank = -1;
3031
int recvRank = -1;
3132
uint8_t* buffer = nullptr;
3233
faabric_datatype_t* dataType = nullptr;
3334
int count = -1;
34-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;
35+
MpiMessageType messageType = MpiMessageType::NORMAL;
3536

3637
bool isAcknowledged() { return msg != nullptr; }
3738

38-
void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
39+
void acknowledge(const MpiMessage& msgIn)
40+
{
41+
msg = std::make_shared<MpiMessage>(msgIn);
42+
}
3943
};
4044

4145
/* Interface to query the buffer size */

include/faabric/mpi/MpiWorld.h

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

3+
#include <faabric/mpi/MpiMessage.h>
34
#include <faabric/mpi/MpiMessageBuffer.h>
45
#include <faabric/mpi/mpi.h>
5-
#include <faabric/mpi/mpi.pb.h>
66
#include <faabric/proto/faabric.pb.h>
77
#include <faabric/scheduler/InMemoryMessageQueue.h>
88
#include <faabric/transport/PointToPointBroker.h>
@@ -26,10 +26,9 @@ namespace faabric::mpi {
2626
// -----------------------------------
2727
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
2828
// as the broker already has mocking capabilities
29-
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
29+
std::vector<MpiMessage> getMpiMockedMessages(int sendRank);
3030

31-
typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
32-
InMemoryMpiQueue;
31+
typedef faabric::util::FixedCapacityQueue<MpiMessage> InMemoryMpiQueue;
3332

3433
class MpiWorld
3534
{
@@ -73,36 +72,36 @@ class MpiWorld
7372
const uint8_t* buffer,
7473
faabric_datatype_t* dataType,
7574
int count,
76-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
75+
MpiMessageType messageType = MpiMessageType::NORMAL);
7776

7877
int isend(int sendRank,
7978
int recvRank,
8079
const uint8_t* buffer,
8180
faabric_datatype_t* dataType,
8281
int count,
83-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
82+
MpiMessageType messageType = MpiMessageType::NORMAL);
8483

8584
void broadcast(int rootRank,
8685
int thisRank,
8786
uint8_t* buffer,
8887
faabric_datatype_t* dataType,
8988
int count,
90-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
89+
MpiMessageType messageType = MpiMessageType::NORMAL);
9190

9291
void recv(int sendRank,
9392
int recvRank,
9493
uint8_t* buffer,
9594
faabric_datatype_t* dataType,
9695
int count,
9796
MPI_Status* status,
98-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
97+
MpiMessageType messageType = MpiMessageType::NORMAL);
9998

10099
int irecv(int sendRank,
101100
int recvRank,
102101
uint8_t* buffer,
103102
faabric_datatype_t* dataType,
104103
int count,
105-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
104+
MpiMessageType messageType = MpiMessageType::NORMAL);
106105

107106
void awaitAsyncRequest(int requestId);
108107

@@ -240,29 +239,36 @@ class MpiWorld
240239
void sendRemoteMpiMessage(std::string dstHost,
241240
int sendRank,
242241
int recvRank,
243-
const std::shared_ptr<MPIMessage>& msg);
242+
const MpiMessage& msg);
244243

245-
std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
246-
int recvRank);
244+
MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank);
247245

248246
// Support for asynchronous communications
249247
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
250248
int recvRank);
251249

252-
std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
253-
int recvRank,
254-
int batchSize = 0);
250+
MpiMessage recvBatchReturnLast(int sendRank,
251+
int recvRank,
252+
int batchSize = 0);
255253

256254
/* Helper methods */
257255

258256
void checkRanksRange(int sendRank, int recvRank);
259257

260258
// Abstraction of the bulk of the recv work, shared among various functions
261-
void doRecv(std::shared_ptr<MPIMessage>& m,
259+
void doRecv(const MpiMessage& m,
262260
uint8_t* buffer,
263261
faabric_datatype_t* dataType,
264262
int count,
265263
MPI_Status* status,
266-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
264+
MpiMessageType messageType = MpiMessageType::NORMAL);
265+
266+
// Abstraction of the bulk of the recv work, shared among various functions
267+
void doRecv(std::unique_ptr<MpiMessage> m,
268+
uint8_t* buffer,
269+
faabric_datatype_t* dataType,
270+
int count,
271+
MPI_Status* status,
272+
MpiMessageType messageType = MpiMessageType::NORMAL);
267273
};
268274
}

src/mpi/CMakeLists.txt

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,12 @@ endif()
3838
# -----------------------------------------------
3939

4040
if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi"))
41-
# Generate protobuf headers
42-
set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h")
43-
44-
protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto)
45-
46-
# Copy the generated headers into place
47-
add_custom_command(
48-
OUTPUT "${MPI_PB_HEADER_COPIED}"
49-
DEPENDS "${MPI_PB_HEADER}"
50-
COMMAND ${CMAKE_COMMAND}
51-
ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/
52-
)
53-
54-
add_custom_target(
55-
mpi_pbh_copied
56-
DEPENDS ${MPI_PB_HEADER_COPIED}
57-
)
58-
59-
add_dependencies(faabric_common_dependencies mpi_pbh_copied)
60-
6141
faabric_lib(mpi
6242
MpiContext.cpp
43+
MpiMessage.cpp
6344
MpiMessageBuffer.cpp
6445
MpiWorld.cpp
6546
MpiWorldRegistry.cpp
66-
${MPI_PB_SRC}
6747
)
6848

6949
target_link_libraries(mpi PRIVATE

src/mpi/MpiMessage.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <faabric/mpi/MpiMessage.h>
2+
#include <faabric/util/memory.h>
3+
4+
#include <cassert>
5+
#include <cstdint>
6+
#include <cstring>
7+
8+
namespace faabric::mpi {
9+
10+
void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
11+
{
12+
assert(msg != nullptr);
13+
assert(bytes.size() >= sizeof(MpiMessage));
14+
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
15+
size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
16+
assert(thisPayloadSize == payloadSize(*msg));
17+
18+
if (thisPayloadSize == 0) {
19+
msg->buffer = nullptr;
20+
return;
21+
}
22+
23+
msg->buffer = faabric::util::malloc(thisPayloadSize);
24+
std::memcpy(
25+
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
26+
}
27+
28+
void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
29+
{
30+
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
31+
size_t payloadSz = payloadSize(msg);
32+
if (payloadSz > 0 && msg.buffer != nullptr) {
33+
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)