Skip to content

Commit

Permalink
[CELEBORN-1071] Support stage rerun for shuffle data lost
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
If shuffle data is lost and enabled throw fetch failures, triggered stage rerun.

### Why are the changes needed?
Rerun stage for shuffle lost scenarios.

### Does this PR introduce _any_ user-facing change?
NO.

### How was this patch tested?
GA.

Closes #2894 from FMX/b1701.

Authored-by: mingji <fengmingxiao.fmx@alibaba-inc.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
(cherry picked from commit 42d5d42)
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
  • Loading branch information
FMX authored and RexXiong committed Nov 12, 2024
1 parent caa060b commit 321b4e3
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
Expand Down Expand Up @@ -104,8 +105,16 @@ class CelebornShuffleReader[K, C](
val localFetchEnabled = conf.enableReadLocalShuffleFile
val localHostAddress = Utils.localHostName(conf)
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
// startPartition is irrelevant
val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
var fileGroups: ReduceFileGroups = null
try {
// startPartition is irrelevant
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
handleFetchExceptions(shuffleId, 0, ce)
case e: Throwable => throw e
}

// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
String,
Expand Down Expand Up @@ -245,18 +254,7 @@ class CelebornShuffleReader[K, C](
if (exceptionRef.get() != null) {
exceptionRef.get() match {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
ce)
} else
throw ce
handleFetchExceptions(handle.shuffleId, partitionId, ce)
case e => throw e
}
}
Expand Down Expand Up @@ -289,18 +287,7 @@ class CelebornShuffleReader[K, C](
iter
} catch {
case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
e)
} else
throw e
handleFetchExceptions(handle.shuffleId, partitionId, e)
}
}

Expand Down Expand Up @@ -380,6 +367,22 @@ class CelebornShuffleReader[K, C](
}
}

private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = {
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
ce)
} else
throw ce
}

def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
dep.serializer.newInstance()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ protected Compressor initialValue() {

private final ReviveManager reviveManager;

protected static class ReduceFileGroups {
public static class ReduceFileGroups {
public Map<Integer, Set<PartitionLocation>> partitionGroups;
public int[] mapAttempts;
public Set<Integer> partitionIds;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class ReducePartitionCommitHandler(
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, Array[Int]]()
private val stageEndTimeout = conf.clientPushStageEndTimeout
private val mockShuffleLost = conf.testMockShuffleLost
private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle

private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
Expand All @@ -94,7 +96,11 @@ class ReducePartitionCommitHandler(
}

override def isStageDataLost(shuffleId: Int): Boolean = {
dataLostShuffleSet.contains(shuffleId)
if (mockShuffleLost) {
mockShuffleLostShuffle == shuffleId
} else {
dataLostShuffleSet.contains(shuffleId)
}
}

override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def testFetchFailure: Boolean = get(TEST_CLIENT_FETCH_FAILURE)
def testMockDestroySlotsFailure: Boolean = get(TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE)
def testMockCommitFilesFailure: Boolean = get(TEST_CLIENT_MOCK_COMMIT_FILES_FAILURE)
def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST)
def testMockShuffleLostShuffle: Int = get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE)
def testPushPrimaryDataTimeout: Boolean = get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT)
def testPushReplicaDataTimeout: Boolean = get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT)
def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE)
Expand Down Expand Up @@ -3716,6 +3718,26 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)

val TEST_CLIENT_MOCK_SHUFFLE_LOST: ConfigEntry[Boolean] =
buildConf("celeborn.test.client.mockShuffleLost")
.internal
.categories("test", "client")
.doc("Mock shuffle lost.")
.version("0.5.2")
.internal
.booleanConf
.createWithDefault(false)

val TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE: ConfigEntry[Int] =
buildConf("celeborn.test.client.mockShuffleLostShuffle")
.internal
.categories("test", "client")
.doc("Mock shuffle lost for shuffle")
.version("0.5.2")
.internal
.intConf
.createWithDefault(0)

val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.push.replicate.enabled")
.withAlternative("celeborn.push.replicate.enabled")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.tests.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.protocol.ShuffleMode

class CelebornShuffleLostSuite extends AnyFunSuite
with SparkTestBase
with BeforeAndAfterEach {

override def beforeEach(): Unit = {
ShuffleClient.reset()
}

override def afterEach(): Unit = {
System.gc()
}

test("celeborn shuffle data lost - hash") {
val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
val combineResult = combine(sparkSession)
val groupbyResult = groupBy(sparkSession)
val repartitionResult = repartition(sparkSession)
val sqlResult = runsql(sparkSession)

Thread.sleep(3000L)
sparkSession.stop()

val conf = updateSparkConf(sparkConf, ShuffleMode.HASH)
conf.set("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
conf.set("spark.celeborn.test.client.mockShuffleLost", "true")

val celebornSparkSession = SparkSession.builder()
.config(conf)
.getOrCreate()
val celebornCombineResult = combine(celebornSparkSession)
val celebornGroupbyResult = groupBy(celebornSparkSession)
val celebornRepartitionResult = repartition(celebornSparkSession)
val celebornSqlResult = runsql(celebornSparkSession)

assert(combineResult.equals(celebornCombineResult))
assert(groupbyResult.equals(celebornGroupbyResult))
assert(repartitionResult.equals(celebornRepartitionResult))
assert(combineResult.equals(celebornCombineResult))
assert(sqlResult.equals(celebornSqlResult))

celebornSparkSession.stop()
}
}

0 comments on commit 321b4e3

Please sign in to comment.