From d6fe6f2e33dec7c4c6020f68649fb10893992925 Mon Sep 17 00:00:00 2001 From: jiang13021 Date: Fri, 19 Apr 2024 11:28:03 +0800 Subject: [PATCH] [CELEBORN-1391] Retry when MasterClient receiving a RpcTimeoutException ### What changes were proposed in this pull request? Retry when MasterClient receiving a RpcTimeoutException ### Why are the changes needed? When the MasterClient encounters an RpcTimeoutException, it may indicate that the current master is either busy or unavailable. In such cases, retrying with an alternative master endpoint could work. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test: org.apache.celeborn.common.client.MasterClientSuiteJ#testOneMasterTimeoutInHA Closes #2466 from jiang13021/celeborn-1391. Authored-by: jiang13021 Signed-off-by: zky.zhoukeyong --- .../celeborn/common/client/MasterClient.java | 7 +-- .../common/client/MasterClientSuiteJ.java | 62 +++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java b/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java index 94e4201a99..24fcb69a2e 100644 --- a/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java +++ b/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java @@ -41,10 +41,7 @@ import org.apache.celeborn.common.protocol.message.ControlMessages.OneWayMessageResponse$; import org.apache.celeborn.common.protocol.message.MasterRequestMessage; import org.apache.celeborn.common.protocol.message.Message; -import org.apache.celeborn.common.rpc.RpcAddress; -import org.apache.celeborn.common.rpc.RpcEndpointRef; -import org.apache.celeborn.common.rpc.RpcEnv; -import org.apache.celeborn.common.rpc.RpcTimeout; +import org.apache.celeborn.common.rpc.*; import org.apache.celeborn.common.util.ThreadUtils; public class MasterClient { @@ -185,7 +182,7 @@ private boolean shouldRetry(@Nullable RpcEndpointRef oldRef, Throwable e) { LOG.warn("Master leader is not present currently, please check masters' status!"); } return true; - } else if (e.getCause() instanceof IOException) { + } else if (e.getCause() instanceof IOException || e instanceof RpcTimeoutException) { resetRpcEndpointRef(oldRef); return true; } diff --git a/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java index dacb32c414..5d02fd62d3 100644 --- a/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; @@ -45,6 +46,7 @@ import org.apache.celeborn.common.rpc.RpcAddress; import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.rpc.RpcEnv; +import org.apache.celeborn.common.rpc.RpcTimeoutException; public class MasterClientSuiteJ { private static final Logger LOG = LoggerFactory.getLogger(MasterClientSuiteJ.class); @@ -219,6 +221,11 @@ public void testOneMasterDownCausedByRuntimeExceptionInHA() { checkOneMasterDownInHA(new RuntimeException("test")); } + @Test + public void testOneMasterTimeoutInHA() { + checkOneMasterAskFailedInHA(new RpcTimeoutException("test", new TimeoutException("test"))); + } + private void checkOneMasterDownInHA(Exception causedByException) { final CelebornConf conf = prepareForCelebornConfWithHA(); @@ -268,6 +275,61 @@ private void checkOneMasterDownInHA(Exception causedByException) { assertEquals(mockResponse, response); } + private void checkOneMasterAskFailedInHA(Exception exception) { + final CelebornConf conf = prepareForCelebornConfWithHA(); + + final RpcEndpointRef master1 = Mockito.mock(RpcEndpointRef.class); + final RpcEndpointRef master2 = Mockito.mock(RpcEndpointRef.class); + final RpcEndpointRef master3 = Mockito.mock(RpcEndpointRef.class); + + // master leader switch to host2 + Mockito.doReturn( + Future$.MODULE$.failed(new MasterNotLeaderException("host1:9097", "host2:9097", null))) + .when(master1) + .ask(Mockito.any(), Mockito.any(), Mockito.any()); + + // Assume master2 get exception. + Mockito.doReturn(Future$.MODULE$.failed(exception)) + .when(master2) + .ask(Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doReturn(Future$.MODULE$.successful(mockResponse)) + .when(master3) + .ask(Mockito.any(), Mockito.any(), Mockito.any()); + + Mockito.doAnswer( + (invocation) -> { + RpcAddress address = invocation.getArgument(0, RpcAddress.class); + switch (address.host()) { + case "host1": + return master1; + case "host2": + return master2; + case "host3": + return master3; + default: + fail( + "Should use master host1/host2/host3:" + masterPort + ", but use " + address); + } + return null; + }) + .when(rpcEnv) + .setupEndpointRef(Mockito.any(RpcAddress.class), Mockito.anyString()); + + MasterClient client = new MasterClient(rpcEnv, conf, false); + HeartbeatFromWorker message = Mockito.mock(HeartbeatFromWorker.class); + + HeartbeatFromWorkerResponse response = null; + try { + response = client.askSync(message, HeartbeatFromWorkerResponse.class); + } catch (Throwable t) { + LOG.error("It should be no exceptions when sending one-way message.", t); + fail("It should be no exceptions when sending one-way message."); + } + + assertEquals(mockResponse, response); + } + private void prepareForRpcEnvWithHA(final Supplier> supplier) { final RpcEndpointRef ref1 = Mockito.mock(RpcEndpointRef.class); final RpcEndpointRef ref2 = Mockito.mock(RpcEndpointRef.class);