[CELEBORN-1626] group mapTask #2771

Draft
wants to merge 17 commits into base: main
@@ -21,6 +21,8 @@ import java.io.IOException
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader}
@@ -88,8 +90,29 @@ class CelebornShuffleReader[K, C](
}
}

var partitionIdList = new ArrayBuffer[Int]()
if (!conf.groupMapTaskEnabled) {
partitionIdList = ArrayBuffer[Int]() ++ (startPartition until endPartition)
} else {
val numPartitions = dep.partitioner.numPartitions
val numMappers = handle.numMaps
val partitionGroupCnt =
if (conf.groupMapTaskEnabled)
math.ceil(numMappers.toDouble / conf.groupMapTaskGroupSize).toInt
else 1
val groupNumPartitions = numPartitions * partitionGroupCnt
(startPartition until endPartition).foreach { originalPartitionId =>
(0 until partitionGroupCnt).foreach { groupCnt =>
val tmpPartitionId = {
originalPartitionId + groupCnt * (groupNumPartitions / partitionGroupCnt)
}
partitionIdList += tmpPartitionId
}
}
}

val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]()
(startPartition until endPartition).map(partitionId => {
partitionIdList.foreach(partitionId => {
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
if (exceptionRef.get() == null) {
@@ -115,7 +138,7 @@
})
})

val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
val recordIter = partitionIdList.iterator.map(partitionId => {
if (handle.numMaps > 0) {
val startFetchWait = System.nanoTime()
var inputStream: CelebornInputStream = streams.get(partitionId)
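
For readability, here is a self-contained sketch of the grouped-partition-id expansion built in the hunk above. The numbers and the helper are illustrative only; the logic mirrors the diff, where `groupNumPartitions / partitionGroupCnt` is simply `numPartitions` (and the inner `groupMapTaskEnabled` check is already known to be true in that branch, so the sketch drops both). Each reduce partition `p` is assembled from grouped partitions `p, p + numPartitions, p + 2 * numPartitions, ...`, one per map-task group:

```scala
import scala.collection.mutable.ArrayBuffer

object GroupedPartitionIds {
  // Expand [startPartition, endPartition) into the grouped partition ids a
  // reducer fetches when map-task grouping is enabled.
  def expand(
      startPartition: Int,
      endPartition: Int,
      numPartitions: Int,
      numMappers: Int,
      groupSize: Int): List[Int] = {
    val partitionGroupCnt = math.ceil(numMappers.toDouble / groupSize).toInt
    val ids = new ArrayBuffer[Int]()
    (startPartition until endPartition).foreach { originalPartitionId =>
      (0 until partitionGroupCnt).foreach { groupCnt =>
        // Copies of the same reduce partition are numPartitions apart.
        ids += originalPartitionId + groupCnt * numPartitions
      }
    }
    ids.toList
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical shuffle: 4 reduce partitions, 10 map tasks, groups of 4 mappers,
    // so ceil(10 / 4) = 3 map-task groups and 4 * 3 = 12 grouped partitions in total.
    println(expand(1, 2, numPartitions = 4, numMappers = 10, groupSize = 4))
    // List(1, 5, 9): reduce partition 1 is read from three grouped partitions.
  }
}
```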
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
@@ -104,16 +105,36 @@ class CelebornShuffleReader[K, C](
val localFetchEnabled = conf.enableReadLocalShuffleFile
val localHostAddress = Utils.localHostName(conf)
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
// startPartition is irrelevant
// startPartition is irrelevant to the file group; it is only used in error log messages
val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
String,
(TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()

var partCnt = 0
var groupPartitionIdList = new ArrayBuffer[Int]()
if (!conf.groupMapTaskEnabled) {
groupPartitionIdList = ArrayBuffer[Int]() ++ (startPartition until endPartition)
} else {
val numPartitions = dep.partitioner.numPartitions
val numMappers = handle.numMappers
val partitionGroupCnt =
if (conf.groupMapTaskEnabled)
math.ceil(numMappers.toDouble / conf.groupMapTaskGroupSize).toInt
else 1
val groupNumPartitions = numPartitions * partitionGroupCnt
(startPartition until endPartition).foreach { originalPartitionId =>
(0 until partitionGroupCnt).foreach { groupCnt =>
val tmpPartitionId = {
originalPartitionId + groupCnt * (groupNumPartitions / partitionGroupCnt)
}
groupPartitionIdList += tmpPartitionId
}
}
}

(startPartition until endPartition).foreach { partitionId =>
groupPartitionIdList.foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
partCnt += 1
@@ -227,20 +248,20 @@ }
}

val inputStreamCreationWindow = conf.clientInputStreamCreationWindow
(startPartition until Math.min(
startPartition + inputStreamCreationWindow,
endPartition)).foreach(partitionId => {

(0 until Math.min(inputStreamCreationWindow, groupPartitionIdList.size)).foreach(listIndex => {
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
createInputStream(partitionId)
createInputStream(groupPartitionIdList(listIndex))
}
})
})

val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
val recordIter = groupPartitionIdList.iterator.map(partitionId => {
if (handle.numMappers > 0) {
val startFetchWait = System.nanoTime()
var inputStream: CelebornInputStream = streams.get(partitionId)
// TODO bug fix: inputStream stays null when a revive happens
while (inputStream == null) {
if (exceptionRef.get() != null) {
exceptionRef.get() match {
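
The creation-window loop above now indexes into the grouped id list instead of the raw partition range. A minimal sketch with hypothetical values (the eager window only covers the first entries; the remaining streams are presumably created later as the record iterator advances, which is outside this hunk):

```scala
object CreationWindowSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical: reduce partitions 1 and 2 of a 4-partition shuffle with
    // 3 map-task groups, and clientInputStreamCreationWindow = 2.
    val groupPartitionIdList = Seq(1, 5, 9, 2, 6, 10)
    val inputStreamCreationWindow = 2

    // Mirrors the submit loop above: only the first `inputStreamCreationWindow`
    // entries of the grouped list get an eagerly created input stream.
    val eagerlyCreated =
      (0 until math.min(inputStreamCreationWindow, groupPartitionIdList.size))
        .map(listIndex => groupPartitionIdList(listIndex))
    println(eagerlyCreated) // Vector(1, 5)
  }
}
```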
@@ -965,11 +965,19 @@ public int pushOrMergeData(
return 0;
}

final PartitionLocation loc = map.get(partitionId);
int tmpGroupTaskPartitionId = partitionId;
if (conf.groupMapTaskEnabled()) {
int mapTaskGroupId = mapId / conf.groupMapTaskGroupSize();
tmpGroupTaskPartitionId = partitionId + numPartitions * mapTaskGroupId;
}
final int groupTaskPartitionId = tmpGroupTaskPartitionId;
PartitionLocation loc = map.get(groupTaskPartitionId);

if (loc == null) {
throw new CelebornIOException(
String.format(
"Partition location for shuffle %s partition %d is NULL!", shuffleId, partitionId));
"Partition location for shuffle %s partition %d groupPartition %d is NULL!",
shuffleId, partitionId, groupTaskPartitionId));
}

PushState pushState = getPushState(mapKey);
@@ -1017,21 +1025,28 @@ public void onSuccess(ByteBuffer response) {
.add(mapId);
}
logger.debug(
"Push data to {} success for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} success for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
loc.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
}

@Override
public void onFailure(Throwable e) {
String errorMsg =
String.format(
"Push data to %s failed for shuffle %d map %d attempt %d partition %d batch %d.",
loc, shuffleId, mapId, attemptId, partitionId, nextBatchId);
"Push data to %s failed for shuffle %d map %d attempt %d partition %d groupPartition %d batch %d.",
loc,
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
}
};
@@ -1054,21 +1069,25 @@ public void onSuccess(ByteBuffer response) {
byte reason = response.get();
if (reason == StatusCode.SOFT_SPLIT.getValue()) {
logger.debug(
"Push data to {} soft split required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} soft split required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
if (!newerPartitionLocationExists(
reducePartitionMap.get(shuffleId), partitionId, latest.getEpoch(), false)) {
reducePartitionMap.get(shuffleId),
groupTaskPartitionId,
latest.getEpoch(),
false)) {
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
latest.getEpoch(),
latest,
StatusCode.SOFT_SPLIT);
@@ -1079,19 +1098,20 @@ public void onSuccess(ByteBuffer response) {
callback.onSuccess(response);
} else if (reason == StatusCode.HARD_SPLIT.getValue()) {
logger.debug(
"Push data to {} hard split required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} hard split required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
latest.getEpoch(),
latest,
StatusCode.HARD_SPLIT);
@@ -1114,24 +1134,26 @@ public void onSuccess(ByteBuffer response) {
dueTime));
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
logger.debug(
"Push data to {} primary congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} primary congestion required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
logger.debug(
"Push data to {} replica congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} replica congestion required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
@@ -1166,12 +1188,13 @@ public void onFailure(Throwable e) {
}

logger.error(
"Push data to {} failed for shuffle {} map {} attempt {} partition {} batch {}, remain revive times {}.",
"Push data to {} failed for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}, remain revive times {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId,
remainReviveTimes,
e);
@@ -1180,7 +1203,13 @@ public void onFailure(Throwable e) {
remainReviveTimes = remainReviveTimes - 1;
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId, mapId, attemptId, partitionId, latest.getEpoch(), latest, cause);
shuffleId,
mapId,
attemptId,
groupTaskPartitionId,
latest.getEpoch(),
latest,
cause);
reviveManager.addRequest(reviveRequest);
long dueTime =
System.currentTimeMillis()
@@ -1217,7 +1246,8 @@ public void onFailure(Throwable e) {
if (!testRetryRevive) {
assert dataClientFactory != null;
TransportClient client =
dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId);
dataClientFactory.createClient(
loc.getHost(), loc.getPushPort(), groupTaskPartitionId);
client.pushData(pushData, pushDataTimeout, wrappedCallback);
} else {
wrappedCallback.onFailure(
@@ -1228,11 +1258,12 @@
}
} catch (Exception e) {
logger.error(
"Exception raised while pushing data for shuffle {} map {} attempt {} partition {} batch {} location {}.",
"Exception raised while pushing data for shuffle {} map {} attempt {} partition {} groupPartition {} batch {} location {}.",
shuffleId,
mapId,
attemptId,
partitionId,
groupTaskPartitionId,
nextBatchId,
loc,
e);
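
On the push side, the hunk above derives the grouped partition id from the map task id rather than by looping over groups. A standalone sketch of that arithmetic with hypothetical counts (only the formula is taken from the diff):

```scala
object GroupTaskPartitionIdSketch {
  // Mirrors pushOrMergeData: a map task writes its records for reduce partition
  // `partitionId` into the grouped partition owned by its map-task group.
  def groupTaskPartitionId(partitionId: Int, mapId: Int, numPartitions: Int, groupSize: Int): Int = {
    val mapTaskGroupId = mapId / groupSize
    partitionId + numPartitions * mapTaskGroupId
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical shuffle: 4 reduce partitions, map-task groups of 4.
    // Map task 6 is in group 6 / 4 = 1, so its data for reduce partition 1 lands
    // in grouped partition 1 + 4 * 1 = 5, one of the ids the reader-side
    // expansion produces for reduce partition 1.
    println(groupTaskPartitionId(partitionId = 1, mapId = 6, numPartitions = 4, groupSize = 4)) // 5
  }
}
```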
@@ -76,6 +76,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private val slotsAssignMaxWorkers = conf.clientSlotAssignMaxWorkers
private val pushReplicateEnabled = conf.clientPushReplicateEnabled
private val pushRackAwareEnabled = conf.clientReserveSlotsRackAwareEnabled
private val groupMapTaskEnabled = conf.groupMapTaskEnabled
private val groupMapTaskGroupSize = conf.groupMapTaskGroupSize
private val partitionSplitThreshold = conf.shufflePartitionSplitThreshold
private val partitionSplitMode = conf.shufflePartitionSplitMode
// shuffle id -> partition type
@@ -98,7 +100,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)]]()
private val shuffleIdGenerator = new AtomicInteger(0)
// app shuffle id -> whether shuffle is determinate, rerun of a indeterminate shuffle gets different result
private val appShuffleDeterminateMap = JavaUtils.newConcurrentHashMap[Int, Boolean]();
private val appShuffleDeterminateMap = JavaUtils.newConcurrentHashMap[Int, Boolean]()

private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
@@ -648,9 +650,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}

// First, request to get allocated slots from Primary
val ids = new util.ArrayList[Integer](numPartitions)
(0 until numPartitions).foreach(idx => ids.add(Integer.valueOf(idx)))
val res = requestMasterRequestSlotsWithRetry(shuffleId, ids)
var numGroupTask = 1
var groupNumPartitions = numPartitions
if (partitionType.getValue.equals(PartitionType.REDUCE.getValue) && groupMapTaskEnabled) {
numGroupTask = math.ceil(numMappers.toDouble / groupMapTaskGroupSize).toInt
groupNumPartitions = numPartitions * numGroupTask
}

val ids = new util.ArrayList[Integer](groupNumPartitions)
(0 until groupNumPartitions).foreach(idx => ids.add(Integer.valueOf(idx)))
val res = requestMasterRequestSlotsWithRetry(shuffleId, ids, numGroupTask)

res.status match {
case StatusCode.REQUEST_FAILED =>
@@ -1624,7 +1633,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends

def requestMasterRequestSlotsWithRetry(
shuffleId: Int,
ids: util.ArrayList[Integer]): RequestSlotsResponse = {
ids: util.ArrayList[Integer],
numGroupTask: Int = 1): RequestSlotsResponse = {
val excludedWorkerSet =
if (excludedWorkersFilter) {
workerStatusTracker.excludedWorkers.asScala.keys.toSet
@@ -1646,7 +1656,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
availableStorageTypes,
excludedWorkerSet,
true,
clientTagsExpr)
clientTagsExpr,
numGroupTask)
val res = requestMasterRequestSlots(req)
if (res.status != StatusCode.SUCCESS) {
requestMasterRequestSlots(req)
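
The slot request above now asks the master for one copy of every reduce partition per map-task group and passes `numGroupTask` along (presumably serialized into the new `numGroupTask` field added to `PbRequestSlots` below). A small sketch of the sizing with hypothetical counts:

```scala
import java.util

object GroupedSlotIdsSketch {
  // Mirrors the slot-request path above: when grouping is enabled, request slots
  // for numPartitions * ceil(numMappers / groupSize) partition ids.
  def slotIds(
      numPartitions: Int,
      numMappers: Int,
      groupSize: Int,
      groupingEnabled: Boolean): util.ArrayList[Integer] = {
    val numGroupTask =
      if (groupingEnabled) math.ceil(numMappers.toDouble / groupSize).toInt else 1
    val groupNumPartitions = numPartitions * numGroupTask
    val ids = new util.ArrayList[Integer](groupNumPartitions)
    (0 until groupNumPartitions).foreach(idx => ids.add(Integer.valueOf(idx)))
    ids
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical job: 4 reduce partitions, 10 map tasks, groups of 4 mappers,
    // so 3 groups and 12 requested slot ids (0 through 11) instead of 4.
    println(slotIds(numPartitions = 4, numMappers = 10, groupSize = 4, groupingEnabled = true).size()) // 12
  }
}
```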
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
@@ -288,6 +288,7 @@ message PbRequestSlots {
repeated PbWorkerInfo excludedWorkerSet = 12;
bool packed = 13;
string tagsExpr = 14;
int32 numGroupTask = 15;
}

message PbSlotInfo {