diff --git a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java index 6931247474..60a843f4a9 100644 --- a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java +++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java @@ -111,7 +111,7 @@ public void testPushDataByteBufHardSplit() throws IOException { @Test public void testPushDataByteBufFail() throws IOException { ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1); - when(client.pushData(any(), anyLong(), any(), any())) + when(client.pushData(any(), anyLong(), any(), any(), any())) .thenAnswer( t -> { RpcResponseCallback rpcResponseCallback = t.getArgument(1, RpcResponseCallback.class); diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java index fd64b6bd03..2c335b350e 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java @@ -216,6 +216,23 @@ public ChannelFuture pushData( long pushDataTimeout, RpcResponseCallback callback, Runnable rpcSendoutCallback) { + Runnable rpcFailureCallback = + () -> { + try { + pushData.body().release(); + } catch (Throwable e) { + logger.error("Error release buffer for PUSH_DATA request {}", pushData.requestId, e); + } + }; + return pushData(pushData, pushDataTimeout, callback, rpcSendoutCallback, rpcFailureCallback); + } + + public ChannelFuture pushData( + PushData pushData, + long pushDataTimeout, + RpcResponseCallback callback, + Runnable rpcSendoutCallback, + Runnable rpcFailureCallback) { if (logger.isTraceEnabled()) { logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel)); } @@ -225,7 +242,8 @@ public ChannelFuture pushData( PushRequestInfo info = new PushRequestInfo(dueTime, callback); handler.addPushRequest(requestId, info); pushData.requestId = requestId; - PushChannelListener listener = new PushChannelListener(requestId, rpcSendoutCallback); + PushChannelListener listener = + new PushChannelListener(requestId, rpcSendoutCallback, rpcFailureCallback); ChannelFuture channelFuture = channel.writeAndFlush(pushData).addListener(listener); info.setChannelFuture(channelFuture); return channelFuture; @@ -233,6 +251,26 @@ public ChannelFuture pushData( public ChannelFuture pushMergedData( PushMergedData pushMergedData, long pushDataTimeout, RpcResponseCallback callback) { + Runnable rpcFailureCallback = + () -> { + try { + pushMergedData.body().release(); + } catch (Throwable e) { + logger.error( + "Error release buffer for PUSH_MERGED_DATA request {}", + pushMergedData.requestId, + e); + } + }; + return pushMergedData(pushMergedData, pushDataTimeout, callback, null, rpcFailureCallback); + } + + public ChannelFuture pushMergedData( + PushMergedData pushMergedData, + long pushDataTimeout, + RpcResponseCallback callback, + Runnable rpcSendoutCallback, + Runnable rpcFailureCallback) { if (logger.isTraceEnabled()) { logger.trace("Pushing merged data to {}", NettyUtils.getRemoteAddress(channel)); } @@ -243,7 +281,8 @@ public ChannelFuture pushMergedData( handler.addPushRequest(requestId, info); pushMergedData.requestId = requestId; - PushChannelListener listener = new PushChannelListener(requestId); + PushChannelListener listener = + new PushChannelListener(requestId, rpcSendoutCallback, rpcFailureCallback); ChannelFuture channelFuture = channel.writeAndFlush(pushMergedData).addListener(listener); info.setChannelFuture(channelFuture); return channelFuture; @@ -417,14 +456,18 @@ private class PushChannelListener extends StdChannelListener { final long pushRequestId; Runnable rpcSendOutCallback; + Runnable rpcFailureCallback; + PushChannelListener(long pushRequestId) { - this(pushRequestId, null); + this(pushRequestId, null, null); } - PushChannelListener(long pushRequestId, Runnable rpcSendOutCallback) { + PushChannelListener( + long pushRequestId, Runnable rpcSendOutCallback, Runnable rpcFailureCallback) { super("PUSH " + pushRequestId); this.pushRequestId = pushRequestId; this.rpcSendOutCallback = rpcSendOutCallback; + this.rpcFailureCallback = rpcFailureCallback; } @Override @@ -438,6 +481,9 @@ public void operationComplete(Future future) throws Exception { @Override protected void handleFailure(String errorMsg, Throwable cause) { handler.handlePushFailure(pushRequestId, errorMsg, cause); + if (rpcFailureCallback != null) { + rpcFailureCallback.run(); + } } } }