Skip to content

Commit

Permalink
MPI GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
OsKnoth committed Oct 15, 2023
1 parent e632f4b commit 2bc6f01
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions src/Parallel/Exchange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,51 @@ function ExchangeData3DSend(U,Exchange)
@views MPI.Isend(SendBuffer3[iP][1:nz,:,1:nT], iP - 1, tag, MPI.COMM_WORLD, sreq[i])
end
end
function ExchangeData3DSendGPU(U,Exchange)

IndSendBuffer = Exchange.IndSendBuffer
IndRecvBuffer = Exchange.IndRecvBuffer
NeiProc = Exchange.NeiProc
Proc = Exchange.Proc
ProcNumber = Exchange.ProcNumber
nz = size(U,1)
nT = size(U,3)
RecvBuffer3 = Exchange.RecvBuffer3
SendBuffer3 = Exchange.SendBuffer3
rreq = Exchange.rreq
sreq = Exchange.sreq

group = (Nz,5,1)
KExchangeData3DSendKernel! = ExchangeData3DSendKernel!(group)
@inbounds for iP in NeiProc
ndrange = (Nz,length(IndSendBuffer[iP]),nT)
KExchangeData3DSendKernel!(U,SendBuffer3[iP],IndSendBuffer[iP],ndrange)
end

i = 0
@inbounds for iP in NeiProc
tag = Proc + ProcNumber*iP
i += 1
@views MPI.Irecv!(RecvBuffer3[iP][1:nz,:,1:nT], iP - 1, tag, MPI.COMM_WORLD, rreq[i])
end
i = 0
@inbounds for iP in NeiProc
tag = iP + ProcNumber*Proc
i += 1
@views MPI.Isend(SendBuffer3[iP][1:nz,:,1:nT], iP - 1, tag, MPI.COMM_WORLD, sreq[i])
end
end

@kernel function ExchangeData3DSendKernel!(U,SendBuffer,IndSendBuffer)

Iz,I,IT = @index(Global, NTuple)
NumInd = @uniform @ndrange()[2]
NT = @uniform @ndrange()[2]
if I <= NumInd && IT <= NT
@inbounds Ind = IndSendBuffer[I]
@inbounds SendBuffer[Iz,I,IT] = U[Iz,Ind,IT]
end
end

function ExchangeData3DRecv!(U,p,Exchange)

Expand Down Expand Up @@ -829,14 +874,14 @@ function ExchangeData3DRecvGPU!(U,Exchange)
end
end

@kernel function ExchangeData3DRecvKernel!(U,Exchange)
@kernel function ExchangeData3DRecvKernel!(U,RecvBuffer,IndRecvBuffer)

Iz,I,IT = @index(Global, NTuple)
NumInd = @uniform @ndrange()[2]
NT = @uniform @ndrange()[2]
if I <= NumInd && IT <= NT
@inbounds Ind = IndRecvBuffer[I]
U[Iz,Ind,IT] += RecvBuffer3[iP][Iz,I,IT]
@inbounds @atomic U[Iz,Ind,IT] += RecvBuffer[Iz,I,IT]
end
end

Expand Down

0 comments on commit 2bc6f01

Please sign in to comment.