Skip to content

Commit

Permalink
[SPARK-51263][CORE][SQL][TESTS] Clean up unnecessary invokePrivate
Browse files Browse the repository at this point in the history
…method calls in test code

### What changes were proposed in this pull request?
This pr cleans up unnecessary calls to the `org.scalatest.PrivateMethodTester.Invoker#invokePrivate` method in the test code, replacing those cases with direct function calls.

### Why are the changes needed?
Due to changes in the function's access scope, some cases in the original tests that used `invokePrivate` to call private methods are no longer necessary.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#50012 from LuciferYang/PrivateMethod-cleanup.

Lead-authored-by: yangjie01 <yangjie01@baidu.com>
Co-authored-by: YangJie <yangjie01@baidu.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
LuciferYang authored and dongjoon-hyun committed Feb 21, 2025
1 parent 140a69b commit e1842c7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1911,14 +1911,6 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
| Helper methods for accessing private methods and fields |
* ------------------------------------------------------- */

private val _numExecutorsToAddPerResourceProfileId =
PrivateMethod[mutable.HashMap[Int, Int]](
Symbol("numExecutorsToAddPerResourceProfileId"))
private val _numExecutorsTargetPerResourceProfileId =
PrivateMethod[mutable.HashMap[Int, Int]](
Symbol("numExecutorsTargetPerResourceProfileId"))
private val _maxNumExecutorsNeededPerResourceProfile =
PrivateMethod[Int](Symbol("maxNumExecutorsNeededPerResourceProfile"))
private val _addTime = PrivateMethod[Long](Symbol("addTime"))
private val _schedule = PrivateMethod[Unit](Symbol("schedule"))
private val _doUpdateRequest = PrivateMethod[Unit](Symbol("doUpdateRequest"))
Expand Down Expand Up @@ -1946,7 +1938,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private def numExecutorsToAdd(
manager: ExecutorAllocationManager,
rp: ResourceProfile): Int = {
val nmap = manager invokePrivate _numExecutorsToAddPerResourceProfileId()
val nmap = manager.numExecutorsToAddPerResourceProfileId
nmap(rp.id)
}

Expand All @@ -1963,7 +1955,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private def numExecutorsTarget(
manager: ExecutorAllocationManager,
rpId: Int): Int = {
val numMap = manager invokePrivate _numExecutorsTargetPerResourceProfileId()
val numMap = manager.numExecutorsTargetPerResourceProfileId
numMap(rpId)
}

Expand All @@ -1982,7 +1974,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
rp: ResourceProfile
): Int = {
val maxNumExecutorsNeeded =
manager invokePrivate _maxNumExecutorsNeededPerResourceProfile(rp.id)
manager.maxNumExecutorsNeededPerResourceProfile(rp.id)
manager invokePrivate
_addExecutorsToTarget(maxNumExecutorsNeeded, rp.id, updatesNeeded)
}
Expand All @@ -2005,7 +1997,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private def maxNumExecutorsNeededPerResourceProfile(
manager: ExecutorAllocationManager,
rp: ResourceProfile): Int = {
manager invokePrivate _maxNumExecutorsNeededPerResourceProfile(rp.id)
manager.maxNumExecutorsNeededPerResourceProfile(rp.id)
}

private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,11 @@ trait MasterSuiteBase extends SparkFunSuite
private val _drivers = PrivateMethod[HashSet[DriverInfo]](Symbol("drivers"))
protected val _waitingDrivers =
PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers"))
private val _state = PrivateMethod[RecoveryState.Value](Symbol("state"))
protected val _newDriverId = PrivateMethod[String](Symbol("newDriverId"))
protected val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId"))
protected val _maybeUpdateAppName =
PrivateMethod[DriverDescription](Symbol("maybeUpdateAppName"))
protected val _createApplication = PrivateMethod[ApplicationInfo](Symbol("createApplication"))
protected val _persistenceEngine = PrivateMethod[PersistenceEngine](Symbol("persistenceEngine"))

protected val workerInfo = makeWorkerInfo(512, 10)
private val workerInfos = Array(workerInfo, workerInfo, workerInfo)
Expand Down Expand Up @@ -567,7 +565,7 @@ trait MasterSuiteBase extends SparkFunSuite
}

protected def getState(master: Master): RecoveryState.Value = {
master.invokePrivate(_state())
master.state
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ class RecoverySuite extends MasterSuiteBase {
var master: Master = null
try {
master = makeAliveMaster(conf)
val e = master.invokePrivate(_persistenceEngine()).asInstanceOf[FileSystemPersistenceEngine]
val e = master.persistenceEngine.asInstanceOf[FileSystemPersistenceEngine]
assert(e.codec.isEmpty)
} finally {
if (master != null) {
Expand All @@ -502,7 +502,7 @@ class RecoverySuite extends MasterSuiteBase {
var master: Master = null
try {
master = makeAliveMaster(conf)
val e = master.invokePrivate(_persistenceEngine()).asInstanceOf[FileSystemPersistenceEngine]
val e = master.persistenceEngine.asInstanceOf[FileSystemPersistenceEngine]
assert(e.codec.get.isInstanceOf[LZ4CompressionCodec])
} finally {
if (master != null) {
Expand All @@ -521,7 +521,7 @@ class RecoverySuite extends MasterSuiteBase {
var master: Master = null
try {
master = makeAliveMaster(conf)
val e = master.invokePrivate(_persistenceEngine()).asInstanceOf[RocksDBPersistenceEngine]
val e = master.persistenceEngine.asInstanceOf[RocksDBPersistenceEngine]
assert(e.serializer.isInstanceOf[JavaSerializer])
} finally {
if (master != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,8 +851,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3))

val blockManager = makeBlockManager(128, "exec", bmMaster)
val sortLocations = PrivateMethod[Seq[BlockManagerId]](Symbol("sortLocations"))
val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test"))
val locations = blockManager.sortLocations(bmMaster.getLocations("test"))
assert(locations.map(_.host) === Seq(localHost, localHost, otherHost))
}

Expand All @@ -874,8 +873,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
val blockManager = makeBlockManager(128, "exec", bmMaster)
blockManager.blockManagerId =
BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack))
val sortLocations = PrivateMethod[Seq[BlockManagerId]](Symbol("sortLocations"))
val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test"))
val locations = blockManager.sortLocations(bmMaster.getLocations("test"))
assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost))
assert(locations.flatMap(_.topologyInfo)
=== Seq(localRack, localRack, localRack, otherRack, otherRack))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import org.mockito.Mockito.{doThrow, mock, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.roaringbitmap.RoaringBitmap
import org.scalatest.PrivateMethodTester

import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
Expand All @@ -51,7 +50,7 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator._
import org.apache.spark.util.Utils


class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {

private var transfer: BlockTransferService = _
private var mapOutputTracker: MapOutputTracker = _
Expand Down Expand Up @@ -159,8 +158,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
verify(buffer, times(0)).release()
val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
var in = wrappedInputStream.invokePrivate(delegateAccess())
var in = wrappedInputStream.delegate
in match {
case stream: CheckedInputStream =>
val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ package org.apache.spark.sql.catalyst.expressions
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampNTZType, TimestampType}

class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper {

test("time window is unevaluable") {
intercept[SparkException] {
Expand Down Expand Up @@ -96,35 +94,33 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
}
}

private val parseExpression = PrivateMethod[Long](Symbol("parseExpression"))

test("parse sql expression for duration in microseconds - string") {
val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds")))
val dur = TimeWindow.parseExpression(Literal("5 seconds"))
assert(dur.isInstanceOf[Long])
assert(dur === 5000000)
}

test("parse sql expression for duration in microseconds - integer") {
val dur = TimeWindow.invokePrivate(parseExpression(Literal(100)))
val dur = TimeWindow.parseExpression(Literal(100))
assert(dur.isInstanceOf[Long])
assert(dur === 100)
}

test("parse sql expression for duration in microseconds - long") {
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
val dur = TimeWindow.parseExpression(Literal.create(2L << 52, LongType))
assert(dur.isInstanceOf[Long])
assert(dur === (2L << 52))
}

test("parse sql expression for duration in microseconds - invalid interval") {
intercept[AnalysisException] {
TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
TimeWindow.parseExpression(Literal("2 apples"))
}
}

test("parse sql expression for duration in microseconds - invalid expression") {
intercept[AnalysisException] {
TimeWindow.invokePrivate(parseExpression(Rand(123)))
TimeWindow.parseExpression(Rand(123))
}
}

Expand Down

0 comments on commit e1842c7

Please sign in to comment.