[CELEBORN-1636] Client supports dynamic update of Worker resources on the server

### What changes were proposed in this pull request?
Currently, ChangePartitionManager retrieves candidate workers from the LifecycleManager's workerSnapshot. During revive handling in reallocateChangePartitionRequestSlotsFromCandidates, however, it does not account for newly available workers added through elastic scale-out and scale-in. This PR fixes that by updating ChangePartitionManager's candidate workers to use the available workers reported by the master in the application heartbeat response.
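
In short, revive candidate selection gains a second path, gated by the new `clientShuffleDynamicResourceEnabled` setting. The sketch below only illustrates the idea; `ReviveCandidatesSketch` / `reviveCandidates` are hypothetical names, and the endpoint setup and failure bookkeeping from the real `ChangePartitionManager` code are omitted:

```scala
import java.{util => ju}

import scala.collection.JavaConverters._

import org.apache.celeborn.client.LifecycleManager
import org.apache.celeborn.common.meta.WorkerInfo

object ReviveCandidatesSketch {
  // Simplified view of how candidate workers are picked for a revive.
  def reviveCandidates(
      lifecycleManager: LifecycleManager,
      shuffleId: Int,
      dynamicResourceEnabled: Boolean): ju.Set[WorkerInfo] = {
    val tracker = lifecycleManager.workerStatusTracker
    val candidates = new ju.HashSet[WorkerInfo]()
    if (dynamicResourceEnabled) {
      // New path: start from the available workers reported in the master's
      // heartbeat response, dropping workers the client has excluded or that
      // are shutting down.
      candidates.addAll(
        tracker.availableWorkersWithEndpoint.values().asScala
          .filter(tracker.workerAvailable).toList.asJava)
    } else {
      // Old path: only workers already recorded in this shuffle's worker snapshot.
      candidates.addAll(
        lifecycleManager.workerSnapshots(shuffleId).keySet().asScala
          .filter(tracker.workerAvailable).asJava)
    }
    candidates
  }
}
```

The committed code additionally sets up RPC endpoints for available workers that have not been used yet, records connection failures back into WorkerStatusTracker, and removes the resulting candidates from the excluded-worker list.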

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT

Closes #2835 from zaynt4606/clbdev.

Authored-by: szt <zaynt4606@163.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
zaynt4606 authored and RexXiong committed Oct 28, 2024
1 parent 59029a0 commit 7685fa7
Showing 16 changed files with 589 additions and 77 deletions.

@@ -47,6 +47,7 @@ class ApplicationHeartbeater(
// Use independent app heartbeat threads to avoid being blocked by other operations.
private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs
private val applicationUnregisterEnabled = conf.applicationUnregisterEnabled
private val clientShuffleDynamicResourceEnabled = conf.clientShuffleDynamicResourceEnabled
private val appHeartbeatHandlerThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"celeborn-client-lifecycle-manager-app-heartbeater")
@@ -69,6 +70,7 @@ class ApplicationHeartbeater(
tmpTotalWritten,
tmpTotalFileCount,
workerStatusTracker.getNeedCheckedWorkers().toList.asJava,
clientShuffleDynamicResourceEnabled,
ZERO_UUID,
true)
val response = requestHeartbeat(appHeartbeat)
@@ -129,6 +131,7 @@ class ApplicationHeartbeater(
List.empty.asJava,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava)
}
}

@@ -18,14 +18,15 @@
package org.apache.celeborn.client

import java.util
import java.util.{Set => JSet}
import java.util.{function, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit}

import scala.collection.JavaConverters._

import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.WorkerInfo
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.protocol.PartitionLocation
import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -45,7 +46,8 @@ class ChangePartitionManager(

private val pushReplicateEnabled = conf.clientPushReplicateEnabled
// shuffleId -> (partitionId -> set of ChangePartition)
private val changePartitionRequests =
val changePartitionRequests
: ConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]] =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()

// shuffleId -> locks
@@ -74,6 +76,8 @@ class ChangePartitionManager(

private val testRetryRevive = conf.testRetryRevive

private val clientShuffleDynamicResourceEnabled = conf.clientShuffleDynamicResourceEnabled
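// When clientShuffleDynamicResourceEnabled is true, revive candidates are taken from the
// master-reported available workers rather than only this shuffle's worker snapshot.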

def start(): Unit = {
batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map {
// noinspection ConvertExpressionToSAM
@@ -128,7 +132,8 @@ class ChangePartitionManager(
batchHandleChangePartitionSchedulerThread.foreach(ThreadUtils.shutdown(_))
}

private val rpcContextRegisterFunc =
val rpcContextRegisterFunc
: function.Function[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]] =
new util.function.Function[
Int,
ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]]() {
@@ -148,6 +153,13 @@ class ChangePartitionManager(
}
}

private val updateWorkerSnapshotsFunc =
new util.function.Function[WorkerInfo, ShufflePartitionLocationInfo] {
override def apply(w: WorkerInfo): ShufflePartitionLocationInfo = {
new ShufflePartitionLocationInfo()
}
}

def handleRequestPartitionLocation(
context: RequestLocationCallContext,
shuffleId: Int,
@@ -186,7 +198,7 @@ class ChangePartitionManager(
partitionId,
StatusCode.SUCCESS,
Some(latestLoc),
lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition))
logDebug(s"[handleRequestPartitionLocation]: For shuffle: $shuffleId," +
s" old partition: $partitionId-$oldEpoch, new partition: $latestLoc found, return it")
return
@@ -254,7 +266,7 @@ class ChangePartitionManager(
req.partitionId,
StatusCode.SUCCESS,
Option(newLocation),
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
}
}

@@ -274,18 +286,49 @@ class ChangePartitionManager(
req.partitionId,
status,
None,
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
}
}

// Get candidate workers that are not in the excluded worker list of this shuffleId
val candidates =
lifecycleManager
.workerSnapshots(shuffleId)
.keySet()
val candidates = new util.HashSet[WorkerInfo]()
if (clientShuffleDynamicResourceEnabled) {
// availableWorkers from the heartbeat are not filtered against excludedWorkers, so filter them here.
candidates.addAll(lifecycleManager
.workerStatusTracker
.availableWorkersWithEndpoint
.values()
.asScala
.toSet
.filter(lifecycleManager.workerStatusTracker.workerAvailable)
.toList
.asJava)

// Set up endpoints for available workers that do not have one yet
val workersRequireEndpoints = new util.HashSet[WorkerInfo](
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.asScala.filter(
lifecycleManager.workerStatusTracker.workerAvailable).asJava)
val connectFailedWorkers = new ShuffleFailedWorkers()
lifecycleManager.setupEndpoints(
workersRequireEndpoints,
shuffleId,
connectFailedWorkers)
workersRequireEndpoints.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
candidates.addAll(workersRequireEndpoints)

// Update worker status
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(workersRequireEndpoints)
lifecycleManager.workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
lifecycleManager.workerStatusTracker.removeFromExcludedWorkers(candidates)
} else {
val snapshotCandidates =
lifecycleManager
.workerSnapshots(shuffleId)
.keySet()
.asScala
.filter(lifecycleManager.workerStatusTracker.workerAvailable)
.asJava
candidates.addAll(snapshotCandidates)
}

if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) {
logError("[Update partition] failed for not enough candidates for revive.")
replyFailure(StatusCode.SLOT_NOT_AVAILABLE)
@@ -294,45 +337,46 @@ class ChangePartitionManager(

// PartitionSplit requests all contain oldPartition
val newlyAllocatedLocations =
reallocateChangePartitionRequestSlotsFromCandidates(changePartitions.toList, candidates)
reallocateChangePartitionRequestSlotsFromCandidates(
changePartitions.toList,
candidates.asScala.toList)

if (!lifecycleManager.reserveSlotsWithRetry(
shuffleId,
new util.HashSet(candidates.toSet.asJava),
candidates,
newlyAllocatedLocations,
isSegmentGranularityVisible = isSegmentGranularityVisible)) {
logError(s"[Update partition] failed for $shuffleId.")
replyFailure(StatusCode.RESERVE_SLOTS_FAILED)
return
}

val newPrimaryLocations =
newlyAllocatedLocations.asScala.flatMap {
case (workInfo, (primaryLocations, replicaLocations)) =>
// Add all re-allocated slots to worker snapshots.
lifecycleManager.workerSnapshots(shuffleId).asScala
.get(workInfo)
.foreach { partitionLocationInfo =>
partitionLocationInfo.addPrimaryPartitions(primaryLocations)
lifecycleManager.updateLatestPartitionLocations(shuffleId, primaryLocations)
partitionLocationInfo.addReplicaPartitions(replicaLocations)
}
// partition location can be null when calling reserveSlotsWithRetry().
val locations = (primaryLocations.asScala ++ replicaLocations.asScala.map(_.getPeer))
.distinct.filter(_ != null)
if (locations.nonEmpty) {
val changes = locations.map { partition =>
s"(partition ${partition.getId} epoch from ${partition.getEpoch - 1} to ${partition.getEpoch})"
}.mkString("[", ", ", "]")
logInfo(s"[Update partition] success for " +
s"shuffle $shuffleId, succeed partitions: " +
s"$changes.")
}
val newPrimaryLocations = newlyAllocatedLocations.asScala.flatMap {
case (workInfo, (primaryLocations, replicaLocations)) =>
// Add all re-allocated slots to worker snapshots.
val partitionLocationInfo = lifecycleManager.workerSnapshots(shuffleId).computeIfAbsent(
workInfo,
updateWorkerSnapshotsFunc)
partitionLocationInfo.addPrimaryPartitions(primaryLocations)
partitionLocationInfo.addReplicaPartitions(replicaLocations)
lifecycleManager.updateLatestPartitionLocations(shuffleId, primaryLocations)

// partition location can be null when calling reserveSlotsWithRetry().
val locations = (primaryLocations.asScala ++ replicaLocations.asScala.map(_.getPeer))
.distinct.filter(_ != null)
if (locations.nonEmpty) {
val changes = locations.map { partition =>
s"(partition ${partition.getId} epoch from ${partition.getEpoch - 1} to ${partition.getEpoch})"
}.mkString("[", ", ", "]")
logInfo(s"[Update partition] success for " +
s"shuffle $shuffleId, succeed partitions: " +
s"$changes.")
}

// TODO: should record the new partition locations and acknowledge them to the downstream task,
// in the scenario where the downstream task starts before the upstream task.
locations
}
// TODO: should record the new partition locations and acknowledge them to the downstream task,
// in the scenario where the downstream task starts before the upstream task.
locations
}
replySuccess(newPrimaryLocations.toArray)
}


@@ -449,11 +449,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}

def setupEndpoints(
slots: WorkerResource,
workers: util.Set[WorkerInfo],
shuffleId: Int,
connectFailedWorkers: ShuffleFailedWorkers): Unit = {
val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]()
slots.asScala foreach { case (workerInfo, _) =>
workers.asScala foreach { workerInfo =>
val future = workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
RpcAddress.apply(workerInfo.host, workerInfo.rpcPort),
WORKER_EP))
@@ -676,8 +676,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
val connectFailedWorkers = new ShuffleFailedWorkers()

// Second, for each worker, try to initialize the endpoint.
setupEndpoints(slots, shuffleId, connectFailedWorkers)

setupEndpoints(slots.keySet(), shuffleId, connectFailedWorkers)
candidatesWorkers.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
// If newly allocated from primary and can setup endpoint success, LifecycleManager should remove worker from
@@ -713,6 +712,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
allocatedWorkers.put(workerInfo, partitionLocationInfo)
}
shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
registeredShuffle.add(shuffleId)
commitManager.registerShuffle(
shuffleId,

@@ -30,17 +30,27 @@ import org.apache.celeborn.common.meta.WorkerInfo
import org.apache.celeborn.common.protocol.PartitionLocation
import org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.Utils
import org.apache.celeborn.common.util.{JavaUtils, Utils}

class WorkerStatusTracker(
conf: CelebornConf,
lifecycleManager: LifecycleManager) extends Logging {
private val excludedWorkerExpireTimeout = conf.clientExcludedWorkerExpireTimeout
private val workerStatusListeners = ConcurrentHashMap.newKeySet[WorkerStatusListener]()
private val clientShuffleDynamicResourceEnabled = conf.clientShuffleDynamicResourceEnabled

val excludedWorkers = new ShuffleFailedWorkers()
val shuttingWorkers: JSet[WorkerInfo] = new JHashSet[WorkerInfo]()

// Workers that already have an endpoint set up can skip the setupEndpoint step in changePartition when reviving.
// key: WorkerInfo.toUniqueId, value: WorkerInfo
val availableWorkersWithEndpoint: ConcurrentHashMap[String, WorkerInfo] =
JavaUtils.newConcurrentHashMap[String, WorkerInfo]()

// Workers that may be available but have not been used yet (no endpoint set up).
// availableWorkersWithoutEndpoint stays empty until appHeartbeatWithAvailableWorkers is set to true.
val availableWorkersWithoutEndpoint = ConcurrentHashMap.newKeySet[WorkerInfo]()
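// A worker moves from availableWorkersWithoutEndpoint to availableWorkersWithEndpoint
// once its RPC endpoint has been set up (see addWorkersWithEndpoint).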

def registerWorkerStatusListener(workerStatusListener: WorkerStatusListener): Unit = {
workerStatusListeners.add(workerStatusListener)
}
@@ -61,7 +71,7 @@ class WorkerStatusTracker(
!excludedWorkers.containsKey(worker) && !shuttingWorkers.contains(worker)
}

def workerAvailable(loc: PartitionLocation): Boolean = {
def workerAvailableByLocation(loc: PartitionLocation): Boolean = {
if (loc == null) {
false
} else {
@@ -131,13 +141,16 @@ class WorkerStatusTracker(
failedWorkers.asScala.foreach {
case (worker, (StatusCode.WORKER_SHUTDOWN, _)) =>
shuttingWorkers.add(worker)
removeFromAvailableWorkers(worker)
case (worker, (statusCode, registerTime)) if !excludedWorkers.containsKey(worker) =>
excludedWorkers.put(worker, (statusCode, registerTime))
removeFromAvailableWorkers(worker)
case (worker, (statusCode, _))
if statusCode == StatusCode.NO_AVAILABLE_WORKING_DIR ||
statusCode == StatusCode.RESERVE_SLOTS_FAILED ||
statusCode == StatusCode.WORKER_UNKNOWN =>
excludedWorkers.put(worker, (statusCode, excludedWorkers.get(worker)._2))
removeFromAvailableWorkers(worker)
case _ => // Not cover
}
}
@@ -147,10 +160,22 @@ class WorkerStatusTracker(
excludedWorkers.keySet.removeAll(workers)
}

private def removeFromAvailableWorkers(worker: WorkerInfo): Unit = {
availableWorkersWithEndpoint.remove(worker.toUniqueId())
availableWorkersWithoutEndpoint.remove(worker)
}

def addWorkersWithEndpoint(workers: JHashSet[WorkerInfo]): Unit = {
availableWorkersWithoutEndpoint.removeAll(workers)
workers.asScala.foreach { workerInfo =>
availableWorkersWithEndpoint.put(workerInfo.toUniqueId(), workerInfo)
}
}

def handleHeartbeatResponse(res: HeartbeatFromApplicationResponse): Unit = {
if (res.statusCode == StatusCode.SUCCESS) {
logDebug(s"Received Worker status from Primary, excluded workers: ${res.excludedWorkers} " +
s"unknown workers: ${res.unknownWorkers}, shutdown workers: ${res.shuttingWorkers}")
s"unknown workers: ${res.unknownWorkers}, shutdown workers: ${res.shuttingWorkers}, available workers from heartbeat: ${res.availableWorkers}")
val current = System.currentTimeMillis()
var statusChanged = false

@@ -188,9 +213,33 @@ class WorkerStatusTracker(
statusChanged = true
}
}
val retainResult = shuttingWorkers.retainAll(res.shuttingWorkers)
val addResult = shuttingWorkers.addAll(res.shuttingWorkers)
statusChanged = statusChanged || retainResult || addResult

val retainShuttingWorkersResult = shuttingWorkers.retainAll(res.shuttingWorkers)
val addShuttingWorkersResult = shuttingWorkers.addAll(res.shuttingWorkers)

if (clientShuffleDynamicResourceEnabled) {
// Consumers of availableWorkers must still filter out the client's excludedWorkers and shuttingWorkers;
// res.availableWorkers is already filtered against res.excludedWorkers and res.shuttingWorkers by the master.
val resAvailableWorkers: JSet[WorkerInfo] = new JHashSet[WorkerInfo](res.availableWorkers)
// Update the locally tracked available workers.
// These sets are not filtered against excludedWorkers,
// so we have to filter excludedWorkers before using them.
availableWorkersWithoutEndpoint.retainAll(resAvailableWorkers)
availableWorkersWithEndpoint.keySet().retainAll(
resAvailableWorkers.asScala.map(_.toUniqueId()).asJava)
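// Workers reported as available but without an endpoint yet are tracked in
// availableWorkersWithoutEndpoint; once they have an endpoint they are dropped from it.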
resAvailableWorkers.asScala.foreach { workerInfo: WorkerInfo =>
if (!availableWorkersWithEndpoint.keySet.contains(workerInfo.toUniqueId())) {
availableWorkersWithoutEndpoint.add(workerInfo)
} else {
if (availableWorkersWithoutEndpoint.contains(workerInfo)) {
availableWorkersWithoutEndpoint.remove(workerInfo)
}
}
}
}

statusChanged =
statusChanged || retainShuttingWorkersResult || addShuttingWorkersResult
// Always trigger commit files for shutting down workers from HeartbeatFromApplicationResponse
// See details in CELEBORN-696
if (!res.unknownWorkers.isEmpty || !res.shuttingWorkers.isEmpty) {