[CELEBORN-1388] Use finer grained locks in changePartitionManager #2462
Changes from 9 commits
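The core of the change is lock striping: the map-wide `requests.synchronized` / `requestsMap.synchronized` blocks are replaced by an array of monitor objects indexed by partition id, so only callers whose partition ids fall into the same bucket contend for the same monitor. A minimal sketch of the pattern, with illustrative names and a hard-coded bucket count standing in for `conf.batchHandleChangePartitionBuckets` (this is not the Celeborn code itself):

```scala
import java.util.concurrent.ConcurrentHashMap

object LockStripingSketch {
  // A fixed array of plain objects used only as monitors; 256 is a placeholder
  // for the configured bucket count.
  private val locks = Array.fill(256)(new AnyRef())

  private val requests = new ConcurrentHashMap[Integer, java.util.Set[String]]()

  def register(partitionId: Int, request: String): Unit = {
    // Only callers whose partition ids map to the same bucket contend here;
    // before the change, every caller serialized on one shared monitor.
    locks(partitionId % locks.length).synchronized {
      requests
        .computeIfAbsent(partitionId, (_: Integer) => new java.util.HashSet[String]())
        .add(request)
    }
  }
}
```

The diff below applies this idea to `changePartitionRequests`, `inBatchPartitions`, and the reply paths in `ChangePartitionManager`.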
```diff
@@ -47,8 +47,11 @@ class ChangePartitionManager(
   // shuffleId -> (partitionId -> set of ChangePartition)
   private val changePartitionRequests =
     JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()
+  private val locks = Array.fill(conf.batchHandleChangePartitionBuckets)(new AnyRef())
 
   // shuffleId -> set of partition id
-  private val inBatchPartitions = JavaUtils.newConcurrentHashMap[Int, JSet[Integer]]()
+  private val inBatchPartitions =
+    JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, Unit]]()
 
   private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled
   private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
```
```diff
@@ -79,14 +82,18 @@ class ChangePartitionManager(
         batchHandleChangePartitionExecutors.submit {
           new Runnable {
             override def run(): Unit = {
-              val distinctPartitions = requests.synchronized {
-                // For each partition only need handle one request
-                requests.asScala.filter { case (partitionId, _) =>
-                  !inBatchPartitions.get(shuffleId).contains(partitionId)
-                }.map { case (partitionId, request) =>
-                  inBatchPartitions.get(shuffleId).add(partitionId)
-                  request.asScala.toArray.maxBy(_.epoch)
-                }.toArray
+              val distinctPartitions = {
+                val requestSet = inBatchPartitions.get(shuffleId)
+                requests.asScala.map { case (partitionId, request) =>
+                  locks(partitionId % locks.length).synchronized {
+                    if (!requestSet.contains(partitionId)) {
+                      requestSet.put(partitionId, ())
+                      Some(request.asScala.toArray.maxBy(_.epoch))
+                    } else {
+                      None
+                    }
+                  }
+                }.filter(_.isDefined).map(_.get).toArray
```
**Review thread**

**Reviewer:** I am not so sure of this change. It feels like it will have a lot more contention, as each entry in the [...] If you have a test bed to validate perf, how about this? [...] Essentially, minimize the time spent within the synchronized block itself by removing unnecessary costs.

**Author:** Based on our observations in our production system, it won't bring more contention ... if we use [...] After this change, with a huge Spark application of 300 TB of shuffle data, I don't see such intensive lock contention anymore.

**Reviewer:** Fair enough. Thanks for sharing the stack trace.

**Reviewer:** To reduce the frequency of acquiring locks, I think we could calculate the lock bucket for each partition id first, group the partition ids by lock bucket, then acquire each lock and process its group (in random order). Though I'm not sure how beneficial this would be.
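A rough sketch of that grouping suggestion, with simplified types and hypothetical names (`pending`, `alreadyInBatch`); this is not code from the PR:

```scala
object BucketGroupingSketch {
  def selectByBucket[T](
      pending: Map[Int, T],                  // partitionId -> newest request
      locks: Array[AnyRef],
      alreadyInBatch: java.util.Set[Integer]): Seq[T] = {
    // Bucket every pending partition id once, then visit the buckets in
    // random order, taking each bucket's lock a single time rather than
    // once per partition id.
    val grouped = pending.groupBy { case (partitionId, _) => partitionId % locks.length }
    scala.util.Random.shuffle(grouped.toSeq).flatMap { case (bucket, group) =>
      locks(bucket).synchronized {
        group.collect {
          // Set.add returns true only the first time, so partition ids
          // already in the batch are skipped.
          case (partitionId, request) if alreadyInBatch.add(partitionId) => request
        }
      }
    }
  }
}
```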
```diff
               }
               if (distinctPartitions.nonEmpty) {
                 handleRequestPartitions(
```
```diff
@@ -123,9 +130,11 @@ class ChangePartitionManager(
       JavaUtils.newConcurrentHashMap()
     }
 
-  private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, util.Set[Integer]]() {
-    override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
-  }
+  private val inBatchShuffleIdRegisterFunc =
+    new util.function.Function[Int, ConcurrentHashMap[Integer, Unit]]() {
+      override def apply(s: Int): ConcurrentHashMap[Integer, Unit] =
+        new ConcurrentHashMap[Integer, Unit]()
+    }
 
   def handleRequestPartitionLocation(
       context: RequestLocationCallContext,
```
```diff
@@ -151,15 +160,21 @@ class ChangePartitionManager(
         oldPartition,
         cause)
 
-    requests.synchronized {
-      if (requests.containsKey(partitionId)) {
-        requests.get(partitionId).add(changePartition)
+    locks(partitionId % locks.length).synchronized {
```
**Review thread**

**Reviewer:** @CodingCat I think the partition id of different shuffles can be repeated. The lock covered a single shuffleId in the previous implementation, but in your new implementation the lock can be contended by the same partition id across different stages. Although a Spark application won't run too many stages concurrently, a Spark Thrift Server might run many stages. The `locks` variable could be changed to avoid lock contention between different stages.

**Author:** Excellent point! Just changed the code.
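One purely illustrative way to act on that point (the actual follow-up commit is not shown in this diff): fold the shuffle id into the bucket choice so identical partition ids from different shuffles usually map to different monitors. `lockFor` and its parameters are hypothetical names.

```scala
object ShuffleAwareLockSketch {
  def lockFor(shuffleId: Int, partitionId: Int, locks: Array[AnyRef]): AnyRef = {
    // Mix both ids into one hash so stages sharing a partition id rarely
    // share a monitor.
    val mixed = java.util.Objects.hash(Integer.valueOf(shuffleId), Integer.valueOf(partitionId))
    // floorMod keeps the index non-negative even if the combined hash is negative.
    locks(Math.floorMod(mixed, locks.length))
  }
}
```

Another option along the same lines would be a separate lock array per shuffle id; either way, the goal is that concurrently running stages stop sharing buckets by accident.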
```diff
+      var newEntry = false
+      val set = requests.computeIfAbsent(
+        partitionId,
+        new java.util.function.Function[Integer, util.Set[ChangePartitionRequest]] {
+          override def apply(t: Integer): util.Set[ChangePartitionRequest] = {
+            newEntry = true
+            new util.HashSet[ChangePartitionRequest]()
+          }
+        })
+
+      if (newEntry) {
         logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same partition" +
```
**Review thread**

**Reviewer:** We could replace this with a [...]

**Author:** If I understand the suggested code correctly, you essentially create a set in `requests` for each partition and keep adding requests to it. I thought the same when iterating on the PR; however, it turns out we cannot do it... it is basically not what the original code was doing. The original code always adds a new set containing a single request to the hash map, i.e. lines 178-179.

**Reviewer:** The original code is doing the same. It is probing the map multiple times though, which is something we can avoid (the [...]).

**Author:**

```scala
requests.putIfAbsent(partitionId, set)
requests.get(partitionId).synchronized {
  getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
    context.reply(
      partitionId,
      StatusCode.SUCCESS,
      Some(latestLoc),
      lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
    logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." +
      s" shuffleId: $shuffleId $latestLoc")
    return
  }
  requests.get(partitionId).add(changePartition)
}
```

This was my original code; somehow it makes the application get stuck. That's why I feel this `putIfAbsent` approach changed the original semantics in a stealthy way.

**Reviewer:** This is strictly not the same as what exists in the main branch. I have not analyzed it in greater detail, but the critical sections are different.

**Author:** I have updated the code and will run more tests in our environment.
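To make the "probing the map multiple times" remark concrete, here is a small sketch contrasting the two access patterns with simplified types (`String` requests, hypothetical object name); in the PR, both variants are assumed to run under the bucket lock shown in the diff:

```scala
import java.util.concurrent.ConcurrentHashMap

object SingleProbeSketch {
  private val requests = new ConcurrentHashMap[Integer, java.util.Set[String]]()

  // Multi-probe: containsKey, then put, then get -- up to three hash lookups
  // for one logical operation (safe only when an outer lock is held).
  def registerMultiProbe(partitionId: Int, req: String): Unit = {
    if (!requests.containsKey(partitionId)) {
      requests.put(partitionId, new java.util.HashSet[String]())
    }
    requests.get(partitionId).add(req)
  }

  // Single-probe: computeIfAbsent touches the map once, and a side flag
  // records whether this call created the entry, as in the diff above.
  def registerSingleProbe(partitionId: Int, req: String): Boolean = {
    var newEntry = false
    val set = requests.computeIfAbsent(
      partitionId,
      (_: Integer) => {
        newEntry = true
        new java.util.HashSet[String]()
      })
    set.add(req)
    newEntry
  }
}
```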
```diff
           s"$partitionId-$oldEpoch exists, register context.")
         return
       } else {
         // If new slot for the partition has been allocated, reply and return.
         // Else register and allocate for it.
         getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
           context.reply(
             partitionId,
```
```diff
@@ -170,10 +185,8 @@ class ChangePartitionManager(
             s" shuffleId: $shuffleId $latestLoc")
           return
         }
-        val set = new util.HashSet[ChangePartitionRequest]()
-        set.add(changePartition)
-        requests.put(partitionId, set)
       }
+      set.add(changePartition)
     }
     if (!batchHandleChangePartitionEnabled) {
       handleRequestPartitions(shuffleId, Array(changePartition))
```
```diff
@@ -216,14 +229,15 @@ class ChangePartitionManager(
 
     // remove together to reduce lock time
     def replySuccess(locations: Array[PartitionLocation]): Unit = {
-      requestsMap.synchronized {
-        locations.map { location =>
+      locations.map { location =>
+        locks(location.getId % locks.length).synchronized {
+          val ret = requestsMap.remove(location.getId)
           if (batchHandleChangePartitionEnabled) {
             inBatchPartitions.get(shuffleId).remove(location.getId)
           }
           // Here one partition id can be remove more than once,
           // so need to filter null result before reply.
-          location -> Option(requestsMap.remove(location.getId))
+          location -> Option(ret)
         }
       }.foreach { case (newLocation, requests) =>
         requests.map(_.asScala.toList.foreach(req =>
```
```diff
@@ -237,12 +251,13 @@ class ChangePartitionManager(
 
     // remove together to reduce lock time
     def replyFailure(status: StatusCode): Unit = {
-      requestsMap.synchronized {
-        changePartitions.map { changePartition =>
+      changePartitions.map { changePartition =>
+        locks(changePartition.partitionId % locks.length).synchronized {
+          val r = requestsMap.remove(changePartition.partitionId)
           if (batchHandleChangePartitionEnabled) {
             inBatchPartitions.get(shuffleId).remove(changePartition.partitionId)
           }
-          Option(requestsMap.remove(changePartition.partitionId))
+          Option(r)
         }
       }.foreach { requests =>
         requests.map(_.asScala.toList.foreach(req =>
```
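The "Here one partition id can be remove more than once" comment in `replySuccess` relies on `ConcurrentHashMap.remove` returning `null` once the entry is gone, which is why the result is wrapped in `Option` and filtered before replying. A tiny standalone illustration (not PR code):

```scala
import java.util.concurrent.ConcurrentHashMap

object RemoveOnceSketch {
  def main(args: Array[String]): Unit = {
    val requestsMap = new ConcurrentHashMap[Integer, String]()
    requestsMap.put(7, "req")

    // The first removal returns the value; a second removal of the same key
    // returns null, so Option lets callers drop it instead of replying twice.
    val first = Option(requestsMap.remove(7))  // Some("req")
    val second = Option(requestsMap.remove(7)) // None
    println(Seq(first, second).flatten)        // List(req)
  }
}
```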
**Review thread**

**Reviewer:** `contains` for `ConcurrentHashMap` is actually `containsValue`. It's better to use `ConcurrentHashMap.newKeySet()` instead of `ConcurrentHashMap[Integer, Unit]`.

**Author:** Oops, fixed.
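A standalone illustration of the pitfall the reviewer describes, with simplified types (not code from the PR): the legacy `contains(Object)` method on `ConcurrentHashMap` checks values, whereas a key-set view behaves like an ordinary set.

```scala
import java.util.concurrent.ConcurrentHashMap

object ContainsPitfallSketch {
  def main(args: Array[String]): Unit = {
    val map = new ConcurrentHashMap[Integer, Unit]()
    map.put(42, ())

    // contains(Object) is inherited Hashtable-style behavior: it checks
    // *values*, so this is false even though 42 is present as a key.
    println(map.contains(42))    // false
    println(map.containsKey(42)) // true

    // A concurrent key-set view avoids the key/value confusion entirely.
    val inBatch = ConcurrentHashMap.newKeySet[Integer]()
    inBatch.add(42)
    println(inBatch.contains(42)) // true
  }
}
```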