22#include < faabric/mpi/MpiMessage.h>
33#include < faabric/mpi/MpiWorld.h>
44#include < faabric/planner/PlannerClient.h>
5+ #include < faabric/transport/PointToPointMessage.h>
56#include < faabric/transport/macros.h>
67#include < faabric/util/ExecGraph.h>
78#include < faabric/util/batch.h>
@@ -59,14 +60,16 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
5960 serializeMpiMsg (serialisedBuffer, msg);
6061
6162 try {
62- broker.sendMessage (
63- thisRankMsg->groupid (),
64- sendRank,
65- recvRank,
66- reinterpret_cast <const uint8_t *>(serialisedBuffer.data ()),
67- serialisedBuffer.size (),
68- dstHost,
69- true );
63+ // It is safe to send a pointer to a stack-allocated object
64+ // because the broker will make an additional copy (and so will NNG!)
65+ faabric::transport::PointToPointMessage msg (
66+ { .groupId = thisRankMsg->groupid (),
67+ .sendIdx = sendRank,
68+ .recvIdx = recvRank,
69+ .dataSize = serialisedBuffer.size (),
70+ .dataPtr = (void *)serialisedBuffer.data () });
71+
72+ broker.sendMessage (msg, dstHost, true );
7073 } catch (std::runtime_error& e) {
7174 SPDLOG_ERROR (" {}:{}:{} Timed out with: MPI - send {} -> {}" ,
7275 thisRankMsg->appid (),
@@ -80,10 +83,12 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
8083
8184MpiMessage MpiWorld::recvRemoteMpiMessage (int sendRank, int recvRank)
8285{
83- std::vector<uint8_t > msg;
86+ faabric::transport::PointToPointMessage msg (
87+ { .groupId = thisRankMsg->groupid (),
88+ .sendIdx = sendRank,
89+ .recvIdx = recvRank });
8490 try {
85- msg =
86- broker.recvMessage (thisRankMsg->groupid (), sendRank, recvRank, true );
91+ broker.recvMessage (msg, true );
8792 } catch (std::runtime_error& e) {
8893 SPDLOG_ERROR (" {}:{}:{} Timed out with: MPI - recv (remote) {} -> {}" ,
8994 thisRankMsg->appid (),
@@ -96,7 +101,8 @@ MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank)
96101
97102 // TODO(mpi-opt): make sure we minimize copies here
98103 MpiMessage parsedMsg;
99- parseMpiMsg (msg, &parsedMsg);
104+ std::vector<uint8_t > msgBytes ((uint8_t *) msg.dataPtr , (uint8_t *) msg.dataPtr + msg.dataSize );
105+ parseMpiMsg (msgBytes, &parsedMsg);
100106
101107 return parsedMsg;
102108}
0 commit comments