diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 9cf1dca528..573e4d94e0 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -47,6 +47,9 @@ class ChangePartitionManager(
   // shuffleId -> (partitionId -> set of ChangePartition)
   private val changePartitionRequests =
     JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()
+  // Lock stripes: a partition id always maps to locks(partitionId % locks.length)
+  private val locks = Array.fill(conf.batchHandleChangePartitionParallelism)(new AnyRef())
+
   // shuffleId -> set of partition id
   private val inBatchPartitions = JavaUtils.newConcurrentHashMap[Int, JSet[Integer]]()
 
@@ -79,14 +82,20 @@ class ChangePartitionManager(
                 batchHandleChangePartitionExecutors.submit {
                   new Runnable {
                     override def run(): Unit = {
-                      val distinctPartitions = requests.synchronized {
-                        // For each partition only need handle one request
-                        requests.asScala.filter { case (partitionId, _) =>
-                          !inBatchPartitions.get(shuffleId).contains(partitionId)
-                        }.map { case (partitionId, request) =>
-                          inBatchPartitions.get(shuffleId).add(partitionId)
-                          request.asScala.toArray.maxBy(_.epoch)
-                        }.toArray
+                      val distinctPartitions = {
+                        // For each partition only need handle one request
+                        requests.asScala.map { case (partitionId, request) =>
+                          locks(partitionId % locks.length).synchronized {
+                            // inBatchPartitions is keyed by shuffleId: check this shuffle's set,
+                            // not the map itself, for partitions already being handled
+                            if (!inBatchPartitions.get(shuffleId).contains(partitionId)) {
+                              inBatchPartitions.get(shuffleId).add(partitionId)
+                              Some(request.asScala.toArray.maxBy(_.epoch))
+                            } else {
+                              None
+                            }
+                          }
+                        }.filter(_.isDefined).map(_.get).toArray
                       }
                       if (distinctPartitions.nonEmpty) {
                         handleRequestPartitions(
@@ -151,7 +160,8 @@ class ChangePartitionManager(
       oldPartition,
       cause)
 
-    requests.synchronized {
+    // Same lock stripe the batch handler takes for this partition
+    locks(partitionId % locks.length).synchronized {
       if (requests.containsKey(partitionId)) {
         requests.get(partitionId).add(changePartition)
         logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same partition" +
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index bb6b96e93a..6ad7dab240 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -914,6 +914,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     PartitionSplitMode.valueOf(get(SHUFFLE_PARTITION_SPLIT_MODE))
   def shufflePartitionSplitThreshold: Long = get(SHUFFLE_PARTITION_SPLIT_THRESHOLD)
   def batchHandleChangePartitionEnabled: Boolean = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED)
+  def batchHandleChangePartitionParallelism: Int =
+    get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_PARALLELISM)
   def batchHandleChangePartitionNumThreads: Int = get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS)
   def batchHandleChangePartitionRequestInterval: Long =
     get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
@@ -3899,6 +3901,17 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(true)
 
+  // Must be positive: the value sizes a lock array that is indexed by partitionId % length.
+  val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_PARALLELISM: ConfigEntry[Int] =
+    buildConf("celeborn.client.shuffle.batchHandleChangePartition.parallelism")
+      .categories("client")
+      .internal
+      .doc("max number of change partition requests which can be concurrently processed ")
+      .version("0.5.0")
+      .intConf
+      .checkValue(v => v > 0, "Parallelism of batch handle change partition must be positive.")
+      .createWithDefault(256)
+
   val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS: ConfigEntry[Int] =
     buildConf("celeborn.client.shuffle.batchHandleChangePartition.threads")
       .withAlternative("celeborn.shuffle.batchHandleChangePartition.threads")