[CELEBORN-1388] Use finer-grained locks in ChangePartitionManager
### What changes were proposed in this pull request?

This PR proposes to use finer-grained locks in `ChangePartitionManager` when handling requests for different partitions.

### Why are the changes needed?

We observed intense lock contention when many partitions get split: most change-partition-executor threads end up competing for the `ConcurrentHashMap` used in `ChangePartitionManager`. That map holds the pending requests per partition, yet the code locks the whole map instead of locking at the per-partition level.

With this change, the driver memory footprint is significantly reduced due to the increased processing throughput.
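
The idea is standard lock striping: hash each partition id into a fixed-size array of lock objects and synchronize only on the bucket that partition falls into. Below is a minimal, self-contained sketch of that idea, not the actual Celeborn code; the object name, the `handle` method, and the `String` request type are illustrative.

```scala
import java.util.concurrent.ConcurrentHashMap

// A minimal lock-striping sketch (not the actual Celeborn code): each partition
// id hashes into one of a fixed number of lock objects, so requests for
// different partitions rarely block each other, while requests for the same
// partition are still serialized.
object LockStripingSketch {
  private val lockBucketSize = 256
  private val locks: Array[AnyRef] = Array.fill(lockBucketSize)(new AnyRef)

  // partitionId -> pending requests (the request type is just String here).
  private val requests = new ConcurrentHashMap[Int, java.util.Set[String]]()

  def handle(partitionId: Int, request: String): Unit = {
    // Lock only the bucket this partition hashes into, not the whole map.
    locks(partitionId % locks.length).synchronized {
      var set = requests.get(partitionId)
      if (set == null) {
        set = new java.util.HashSet[String]()
        requests.put(partitionId, set)
      }
      set.add(request)
    }
  }
}
```

In the patch itself the lock arrays are additionally keyed by shuffle id (a `ConcurrentHashMap[Int, Array[AnyRef]]` registered via `locksRegisterFunc`), so each shuffle gets its own buckets and they can be cleaned up in `removeExpiredShuffle`.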

### Does this PR introduce _any_ user-facing change?

One new config: `celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets` (default 256).
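
For illustration, a hypothetical way to tune the new bucket count from a Spark application, assuming the usual Spark client integration where Celeborn settings are passed through `SparkConf` with a `spark.` prefix; the value 512 is purely illustrative (the patch defaults to 256):

```scala
import org.apache.spark.SparkConf

// Illustrative only: raise the lock-bucket count above the default of 256 if
// change-partition handling still shows lock contention on very large jobs.
val conf = new SparkConf()
  .set("spark.celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets", "512")
```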

### How was this patch tested?

Verified in production.

Closes #2462 from CodingCat/finer_grained_locks.

Authored-by: CodingCat <zhunansjtu@gmail.com>
Signed-off-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
CodingCat authored and waitinfuture committed Apr 30, 2024
1 parent 29b5586 commit 9f30479
Showing 3 changed files with 66 additions and 26 deletions.
@@ -47,8 +47,14 @@ class ChangePartitionManager(
// shuffleId -> (partitionId -> set of ChangePartition)
private val changePartitionRequests =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()

// shuffleId -> locks
private val locks = JavaUtils.newConcurrentHashMap[Int, Array[AnyRef]]()
private val lockBucketSize = conf.batchHandleChangePartitionBuckets

// shuffleId -> set of partition id
private val inBatchPartitions = JavaUtils.newConcurrentHashMap[Int, JSet[Integer]]()
private val inBatchPartitions =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]]()

private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled
private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
@@ -79,14 +85,19 @@
batchHandleChangePartitionExecutors.submit {
new Runnable {
override def run(): Unit = {
val distinctPartitions = requests.synchronized {
// For each partition only need handle one request
requests.asScala.filter { case (partitionId, _) =>
!inBatchPartitions.get(shuffleId).contains(partitionId)
}.map { case (partitionId, request) =>
inBatchPartitions.get(shuffleId).add(partitionId)
request.asScala.toArray.maxBy(_.epoch)
}.toArray
val distinctPartitions = {
val requestSet = inBatchPartitions.get(shuffleId)
val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
requests.asScala.map { case (partitionId, request) =>
locksForShuffle(partitionId % locksForShuffle.length).synchronized {
if (!requestSet.contains(partitionId)) {
requestSet.add(partitionId)
Some(request.asScala.toArray.maxBy(_.epoch))
} else {
None
}
}
}.filter(_.isDefined).map(_.get).toArray
}
if (distinctPartitions.nonEmpty) {
handleRequestPartitions(
@@ -123,8 +134,16 @@
JavaUtils.newConcurrentHashMap()
}

private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, util.Set[Integer]]() {
override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
private val inBatchShuffleIdRegisterFunc =
new util.function.Function[Int, ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]]() {
override def apply(s: Int): ConcurrentHashMap.KeySetView[Int, java.lang.Boolean] =
ConcurrentHashMap.newKeySet[Int]()
}

private val locksRegisterFunc = new util.function.Function[Int, Array[AnyRef]] {
override def apply(t: Int): Array[AnyRef] = {
Array.fill(lockBucketSize)(new AnyRef())
}
}

def handleRequestPartitionLocation(
@@ -151,15 +170,22 @@
oldPartition,
cause)

requests.synchronized {
if (requests.containsKey(partitionId)) {
requests.get(partitionId).add(changePartition)
val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
locksForShuffle(partitionId % locksForShuffle.length).synchronized {
var newEntry = false
val set = requests.computeIfAbsent(
partitionId,
new java.util.function.Function[Integer, util.Set[ChangePartitionRequest]] {
override def apply(t: Integer): util.Set[ChangePartitionRequest] = {
newEntry = true
new util.HashSet[ChangePartitionRequest]()
}
})

if (newEntry) {
logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same partition" +
s"$partitionId-$oldEpoch exists, register context.")
return
} else {
// If new slot for the partition has been allocated, reply and return.
// Else register and allocate for it.
getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
context.reply(
partitionId,
@@ -170,10 +196,8 @@
s" shuffleId: $shuffleId $latestLoc")
return
}
val set = new util.HashSet[ChangePartitionRequest]()
set.add(changePartition)
requests.put(partitionId, set)
}
set.add(changePartition)
}
if (!batchHandleChangePartitionEnabled) {
handleRequestPartitions(shuffleId, Array(changePartition))
@@ -216,14 +240,16 @@

// remove together to reduce lock time
def replySuccess(locations: Array[PartitionLocation]): Unit = {
requestsMap.synchronized {
locations.map { location =>
val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
locations.map { location =>
locksForShuffle(location.getId % locksForShuffle.length).synchronized {
val ret = requestsMap.remove(location.getId)
if (batchHandleChangePartitionEnabled) {
inBatchPartitions.get(shuffleId).remove(location.getId)
}
// Here one partition id can be remove more than once,
// so need to filter null result before reply.
location -> Option(requestsMap.remove(location.getId))
location -> Option(ret)
}
}.foreach { case (newLocation, requests) =>
requests.map(_.asScala.toList.foreach(req =>
@@ -237,12 +263,14 @@

// remove together to reduce lock time
def replyFailure(status: StatusCode): Unit = {
requestsMap.synchronized {
changePartitions.map { changePartition =>
changePartitions.map { changePartition =>
val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
locksForShuffle(changePartition.partitionId % locksForShuffle.length).synchronized {
val r = requestsMap.remove(changePartition.partitionId)
if (batchHandleChangePartitionEnabled) {
inBatchPartitions.get(shuffleId).remove(changePartition.partitionId)
}
Option(requestsMap.remove(changePartition.partitionId))
Option(r)
}
}.foreach { requests =>
requests.map(_.asScala.toList.foreach(req =>
@@ -325,5 +353,6 @@ class ChangePartitionManager(
def removeExpiredShuffle(shuffleId: Int): Unit = {
changePartitionRequests.remove(shuffleId)
inBatchPartitions.remove(shuffleId)
locks.remove(shuffleId)
}
}
@@ -1008,6 +1008,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
PartitionSplitMode.valueOf(get(SHUFFLE_PARTITION_SPLIT_MODE))
def shufflePartitionSplitThreshold: Long = get(SHUFFLE_PARTITION_SPLIT_THRESHOLD)
def batchHandleChangePartitionEnabled: Boolean = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED)
def batchHandleChangePartitionBuckets: Int =
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS)
def batchHandleChangePartitionNumThreads: Int = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS)
def batchHandleChangePartitionRequestInterval: Long =
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
@@ -4038,6 +4040,14 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(true)

val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS: ConfigEntry[Int] =
buildConf("celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets")
.categories("client")
.doc("Max number of change partition requests which can be concurrently processed ")
.version("0.5.0")
.intConf
.createWithDefault(256)

val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS: ConfigEntry[Int] =
buildConf("celeborn.client.shuffle.batchHandleChangePartition.threads")
.withAlternative("celeborn.shuffle.batchHandleChangePartition.threads")
1 change: 1 addition & 0 deletions docs/configuration/client.md
@@ -81,6 +81,7 @@ license: |
| celeborn.client.rpc.reserveSlots.askTimeout | &lt;value of celeborn.rpc.askTimeout&gt; | false | Timeout for LifecycleManager request reserve slots. | 0.3.0 | |
| celeborn.client.rpc.shared.threads | 16 | false | Number of shared rpc threads in LifecycleManager. | 0.3.2 | |
| celeborn.client.shuffle.batchHandleChangePartition.interval | 100ms | false | Interval for LifecycleManager to schedule handling change partition requests in batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.interval |
| celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets | 256 | false | Max number of change partition requests which can be concurrently processed | 0.5.0 | |
| celeborn.client.shuffle.batchHandleChangePartition.threads | 8 | false | Threads number for LifecycleManager to handle change partition request in batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.threads |
| celeborn.client.shuffle.batchHandleCommitPartition.interval | 5s | false | Interval for LifecycleManager to schedule handling commit partition requests in batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.interval |
| celeborn.client.shuffle.batchHandleCommitPartition.threads | 8 | false | Threads number for LifecycleManager to handle commit partition request in batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.threads |
