Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1388] Use finer grained locks in changePartitionManager #2462

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ class ChangePartitionManager(
// shuffleId -> (partitionId -> set of ChangePartition)
private val changePartitionRequests =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()
private val locks = Array.fill(conf.batchHandleChangePartitionBuckets)(new AnyRef())

// shuffleId -> set of partition id
private val inBatchPartitions = JavaUtils.newConcurrentHashMap[Int, JSet[Integer]]()
private val inBatchPartitions =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, Unit]]()

private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled
private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
Expand Down Expand Up @@ -79,14 +82,18 @@ class ChangePartitionManager(
batchHandleChangePartitionExecutors.submit {
new Runnable {
override def run(): Unit = {
val distinctPartitions = requests.synchronized {
// For each partition only need handle one request
requests.asScala.filter { case (partitionId, _) =>
!inBatchPartitions.get(shuffleId).contains(partitionId)
}.map { case (partitionId, request) =>
inBatchPartitions.get(shuffleId).add(partitionId)
request.asScala.toArray.maxBy(_.epoch)
}.toArray
val distinctPartitions = {
val requestSet = inBatchPartitions.get(shuffleId)
requests.asScala.map { case (partitionId, request) =>
locks(partitionId % locks.length).synchronized {
if (!requestSet.contains(partitionId)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contains for ConcurrentHashMap is actually containsValue. It's better to use ConcurrentHashMap.newKeySet() instead of ConcurrentHashMap[Integer, Unit]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, fixed

requestSet.put(partitionId, ())
Some(request.asScala.toArray.maxBy(_.epoch))
} else {
None
}
}
}.filter(_.isDefined).map(_.get).toArray
Copy link
Contributor

@mridulm mridulm Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not so sure of this change - it feels like it will have a lot more contention - as each entry in the requests map will need to acquire a lock - while uncontented locks are cheap to acquire, it is not zero cost still.

If you have a test bed to validate perf, how about this ?

                      val batchPartitions = inBatchPartitions.get(shuffleId)
                      val distinctPartitions = requests.synchronized {
                        // For each partition only need handle one request
                        requests.asScala.filter { case (partitionId, _) =>
                          !batchPartitions.contains(partitionId)
                        }.map { case (partitionId, request) =>
                          batchPartitions.add(partitionId)
                          request.asScala.maxBy(_.epoch)
                        }.toArray
                      }

Essentially minimize the time within the synchronized block itself by removing unnecessary costs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on our observation in our production system, it won't bring more competition ...

if we use requests.synchronized, all celeborn-dispatcher threads + all celeborn-client-life-cycle-manager-change-partition-executor will all compete for the same object for locking, even they are likely working on different partitions ... check the following screenshots

image

image

after this change, with a huge spark application of 300TB shuffle data, I don't see such intensive locking competition anymore

Copy link
Contributor

@mridulm mridulm Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. Thanks for sharing the stack trace.
I would suggest that the changes I gave are relevant irrespective of the locking strategy - as it will minimize the time within a critical section.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it feels like it will have a lot more contention - as each entry in the requests map will need to acquire a lock

To reduce the frequency of acquiring locks, I think we can calculate the lock buckets for each partition ids first, then group the partition ids by the lock bucket, then acquire lock and process each group (in random order). Though I'm not sure how beneficial this will be.

}
if (distinctPartitions.nonEmpty) {
handleRequestPartitions(
Expand Down Expand Up @@ -123,9 +130,11 @@ class ChangePartitionManager(
JavaUtils.newConcurrentHashMap()
}

private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, util.Set[Integer]]() {
override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
}
private val inBatchShuffleIdRegisterFunc =
new util.function.Function[Int, ConcurrentHashMap[Integer, Unit]]() {
override def apply(s: Int): ConcurrentHashMap[Integer, Unit] =
new ConcurrentHashMap[Integer, Unit]()
}

def handleRequestPartitionLocation(
context: RequestLocationCallContext,
Expand All @@ -151,15 +160,21 @@ class ChangePartitionManager(
oldPartition,
cause)

requests.synchronized {
if (requests.containsKey(partitionId)) {
requests.get(partitionId).add(changePartition)
locks(partitionId % locks.length).synchronized {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CodingCat I think the "partition Id" of different shuffles can be repeated. The lock is for the same "shuffleId" in the previous implementation but the lock can be contended by the same "partition Id" of different stages in your new implementation. Although a spark application won't run too many stages concurrently, but the spark thrift server might run many stages.

The locks variable can be changed to avoid the lock contention of different stages.
private val locks = JavaUtils.newConcurrentHashMap[Int,Array[AnyRef]]()
I think creating an array of AnyRef won't cost more than the contended locks. 256 AnyRef objects will consume 2 kb of memory, this suggestion won't introduce memory pressure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

excellent point! just changed the code

var newEntry = false
val set = requests.computeIfAbsent(
partitionId,
new java.util.function.Function[Integer, util.Set[ChangePartitionRequest]] {
override def apply(t: Integer): util.Set[ChangePartitionRequest] = {
newEntry = true
new util.HashSet[ChangePartitionRequest]()
}
})

if (newEntry) {
logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same partition" +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could replace this with a computeIfAbsent ?
Something like:

    requests.synchronized {
      var newEntry = false
      val set = requests.computeIfAbsent(partitionId, (v1: Integer) => {
        // If new slot for the partition has been allocated, reply and return.
        // Else register and allocate for it.
        getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
          context.reply(
            partitionId,
            StatusCode.SUCCESS,
            Some(latestLoc),
            lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
          logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." +
            s" shuffleId: $shuffleId $latestLoc")
          return
        }
        newEntry = true
        new util.HashSet[ChangePartitionRequest]()
      })

      set.add(changePartition)
      if (!newEntry) {
        logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same partition" +
          s"$partitionId-$oldEpoch exists, register context.")
      }
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I understand the suggest code correctly, you essentially create a set in requests for each partition and keep adding a request to it,

I thought the same when iterating on the PR, however it turns out we cannot do it ....

basically it is not what the original code was doing... the original code always add a new set containing a single request to the hash map, i.e. line 178 - 179

Copy link
Contributor

@mridulm mridulm Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code is doing the same.
If partition exists, it will add to the set - else create new and add with the entry.
(Removing other parts of the code, it is essentially)

 if (requests.containsKey(partitionId)) {
     requests.get(partitionId).add(changePartition)
 } else {
    // an early exit condition, followed by:
    val set = new util.HashSet[ChangePartitionRequest]()
    set.add(changePartition)
    requests.put(partitionId, set)
 }

It is probing the map multiple times though, which is something we can avoid.

(the return in the getLatestPartition case I suggested looks wrong though - we should return null and exit if set is null)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requests.putIfAbsent(partitionId, set)
requests.get(partitionId).synchronized {
  getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
    context.reply(
      partitionId,
      StatusCode.SUCCESS,
      Some(latestLoc),
      lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
    logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." +
      s" shuffleId: $shuffleId $latestLoc")
    return
  }
  requests.get(partitionId).add(changePartition)
}

this was my original code, somehow this makes the application stuck , that's why I feel somehow this putIfAbsent approach changed the original semantics in a stealthy way

Copy link
Contributor

@mridulm mridulm Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is strictly not the same as what exists in main branch - I have not analyzed it greater detail, but the critical sections are different.
Note that the changes I proposed above are to ensure we remove avoidable probes into the map, and improve performance while not changing the critical sections ... but if the version I proposed does cause deadlocks/hangs, I would be very curious to know why ! (stack trace would definitely help) thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have updated the code , will run more test in our env

s"$partitionId-$oldEpoch exists, register context.")
return
} else {
// If new slot for the partition has been allocated, reply and return.
// Else register and allocate for it.
getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
context.reply(
partitionId,
Expand All @@ -170,10 +185,8 @@ class ChangePartitionManager(
s" shuffleId: $shuffleId $latestLoc")
return
}
val set = new util.HashSet[ChangePartitionRequest]()
set.add(changePartition)
requests.put(partitionId, set)
}
set.add(changePartition)
}
if (!batchHandleChangePartitionEnabled) {
handleRequestPartitions(shuffleId, Array(changePartition))
Expand Down Expand Up @@ -216,14 +229,15 @@ class ChangePartitionManager(

// remove together to reduce lock time
def replySuccess(locations: Array[PartitionLocation]): Unit = {
requestsMap.synchronized {
locations.map { location =>
locations.map { location =>
locks(location.getId % locks.length).synchronized {
val ret = requestsMap.remove(location.getId)
if (batchHandleChangePartitionEnabled) {
inBatchPartitions.get(shuffleId).remove(location.getId)
}
// Here one partition id can be remove more than once,
// so need to filter null result before reply.
location -> Option(requestsMap.remove(location.getId))
location -> Option(ret)
}
}.foreach { case (newLocation, requests) =>
requests.map(_.asScala.toList.foreach(req =>
Expand All @@ -237,12 +251,13 @@ class ChangePartitionManager(

// remove together to reduce lock time
def replyFailure(status: StatusCode): Unit = {
requestsMap.synchronized {
changePartitions.map { changePartition =>
changePartitions.map { changePartition =>
locks(changePartition.partitionId % locks.length).synchronized {
val r = requestsMap.remove(changePartition.partitionId)
if (batchHandleChangePartitionEnabled) {
inBatchPartitions.get(shuffleId).remove(changePartition.partitionId)
}
Option(requestsMap.remove(changePartition.partitionId))
Option(r)
}
}.foreach { requests =>
requests.map(_.asScala.toList.foreach(req =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
PartitionSplitMode.valueOf(get(SHUFFLE_PARTITION_SPLIT_MODE))
def shufflePartitionSplitThreshold: Long = get(SHUFFLE_PARTITION_SPLIT_THRESHOLD)
def batchHandleChangePartitionEnabled: Boolean = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED)
def batchHandleChangePartitionBuckets: Int =
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS)
def batchHandleChangePartitionNumThreads: Int = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS)
def batchHandleChangePartitionRequestInterval: Long =
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
Expand Down Expand Up @@ -3899,6 +3901,14 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(true)

val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS: ConfigEntry[Int] =
buildConf("celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets")
.categories("client")
.doc("Max number of change partition requests which can be concurrently processed ")
.version("0.5.0")
.intConf
.createWithDefault(256)

val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS: ConfigEntry[Int] =
buildConf("celeborn.client.shuffle.batchHandleChangePartition.threads")
.withAlternative("celeborn.shuffle.batchHandleChangePartition.threads")
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ license: |
| celeborn.client.rpc.reserveSlots.askTimeout | <value of celeborn.rpc.askTimeout> | false | Timeout for LifecycleManager request reserve slots. | 0.3.0 | |
| celeborn.client.rpc.shared.threads | 16 | false | Number of shared rpc threads in LifecycleManager. | 0.3.2 | |
| celeborn.client.shuffle.batchHandleChangePartition.interval | 100ms | false | Interval for LifecycleManager to schedule handling change partition requests in batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.interval |
| celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets | 256 | false | Max number of change partition requests which can be concurrently processed | 0.5.0 | |
| celeborn.client.shuffle.batchHandleChangePartition.threads | 8 | false | Threads number for LifecycleManager to handle change partition request in batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.threads |
| celeborn.client.shuffle.batchHandleCommitPartition.interval | 5s | false | Interval for LifecycleManager to schedule handling commit partition requests in batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.interval |
| celeborn.client.shuffle.batchHandleCommitPartition.threads | 8 | false | Threads number for LifecycleManager to handle commit partition request in batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.threads |
Expand Down
Loading