Skip to content

Commit 25382fd

Browse files
committed
Add priorReport
1 parent 656353c commit 25382fd

File tree

3 files changed

+36
-26
lines changed

3 files changed

+36
-26
lines changed

src/main/scala/sbttestshards/ShardingAlgorithm.scala

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package sbttestshards
22

3-
import sbttestshards.parsers.JUnitReportParser
3+
import sbttestshards.parsers.{FullTestReport, JUnitReportParser}
44

55
import java.nio.charset.StandardCharsets
66
import java.nio.file.Path
@@ -10,6 +10,8 @@ import scala.util.hashing.MurmurHash3
1010
// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
1111
trait ShardingAlgorithm {
1212

13+
def priorReport: Option[FullTestReport]
14+
1315
def check(suiteName: String, shardContext: ShardContext): ShardResult
1416

1517
/** Determines whether the specified suite will run on this shard or not. */
@@ -31,6 +33,8 @@ object ShardingAlgorithm {
3133

3234
ShardResult(Some(testShard))
3335
}
36+
37+
def priorReport: Option[FullTestReport] = None
3438
}
3539

3640
/** Will always mark the test to run on this shard. Useful for debugging or
@@ -40,6 +44,8 @@ object ShardingAlgorithm {
4044

4145
def check(suiteName: String, shardContext: ShardContext): ShardResult =
4246
ShardResult(Some(shardContext.testShard))
47+
48+
def priorReport: Option[FullTestReport] = None
4349
}
4450

4551
/** Will never mark the test to run on this shard. Useful for debugging or for
@@ -49,6 +55,8 @@ object ShardingAlgorithm {
4955

5056
def check(suiteName: String, shardContext: ShardContext): ShardResult =
5157
ShardResult(None)
58+
59+
def priorReport: Option[FullTestReport] = None
5260
}
5361

5462
object Balance {
@@ -57,14 +65,18 @@ object ShardingAlgorithm {
5765
reportDirectories: Seq[Path],
5866
shardsInfo: ShardingInfo,
5967
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
60-
): Balance =
68+
): Balance = {
69+
val priorReport = JUnitReportParser.parseDirectoriesRecursively(reportDirectories)
70+
6171
ShardingAlgorithm.Balance(
62-
JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r =>
72+
priorReport.testReports.map { r =>
6373
SuiteInfo(r.name, Some(Duration.ofMillis((r.timeTaken * 1000).toLong)))
6474
},
6575
shardsInfo,
66-
fallbackShardingAlgorithm
76+
fallbackShardingAlgorithm,
77+
Some(priorReport)
6778
)
79+
}
6880
}
6981

7082
/** Attempts to balance the shards by execution time so that no one shard
@@ -73,7 +85,8 @@ object ShardingAlgorithm {
7385
final case class Balance(
7486
suites: Seq[SuiteInfo],
7587
shardsInfo: ShardingInfo,
76-
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
88+
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName,
89+
priorReport: Option[FullTestReport] = None
7790
) extends ShardingAlgorithm {
7891

7992
// TODO: Median might be better here?

src/main/scala/sbttestshards/TestShardsPlugin.scala

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package sbttestshards
22

33
import sbt.*
44
import sbt.Keys.*
5-
import sbttestshards.parsers.JUnitReportParser
6-
7-
import java.nio.file.Paths
5+
import sbttestshards.parsers.FullTestReport
86

97
object TestShardsPlugin extends AutoPlugin {
108

@@ -20,11 +18,6 @@ object TestShardsPlugin extends AutoPlugin {
2018

2119
override def trigger = allRequirements
2220

23-
def stringConfig(key: String, default: String): String = {
24-
val propertyKey = key.replace('_', '.').toLowerCase
25-
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
26-
}
27-
2821
override lazy val projectSettings: Seq[Def.Setting[?]] =
2922
Seq(
3023
testShard := stringConfig("TEST_SHARD", "0").toInt,
@@ -50,24 +43,19 @@ object TestShardsPlugin extends AutoPlugin {
5043
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
5144
val logger = shardContext.logger
5245
val algorithm = shardingAlgorithm.value
53-
54-
// TODO:: Make path customizable
55-
val fullTestReport = JUnitReportParser.parseDirectoriesRecursively(
56-
Seq(Paths.get(s"test-reports/main/resources/test-reports", moduleName.value))
57-
)
58-
46+
val priorReport = algorithm.priorReport.getOrElse(FullTestReport.empty)
5947
val sbtSuiteNames = (Test / definedTestNames).value.toSet
60-
val missingSuiteNames = sbtSuiteNames diff fullTestReport.testReports.map(_.name).toSet
48+
val missingSuiteNames = sbtSuiteNames diff priorReport.testReports.map(_.name).toSet
6149

62-
val results = fullTestReport.testReports.map { suiteReport =>
50+
val results = priorReport.testReports.map { suiteReport =>
6351
val shardResult = algorithm.check(suiteReport.name, shardContext)
6452

6553
shardResult.testShard -> suiteReport
6654
}.collect { case (Some(shard), report) => shard -> report }
6755
.groupBy(_._1)
6856

6957
results.toSeq.sortBy(_._1).foreach { case (k, v) =>
70-
val totalTime = v.map(_._2.timeTaken).sum
58+
val totalTime = BigDecimal(v.map(_._2.timeTaken).sum).setScale(3, BigDecimal.RoundingMode.HALF_UP)
7159

7260
logger.info(s"[${moduleName.value}] Shard $k expected to take $totalTime s")
7361

@@ -76,7 +64,7 @@ object TestShardsPlugin extends AutoPlugin {
7664
}
7765
}
7866

79-
if(missingSuiteNames.nonEmpty) {
67+
if (missingSuiteNames.nonEmpty) {
8068
logger.warn(s"Detected ${missingSuiteNames.size} suites that don't have a test report")
8169

8270
missingSuiteNames.foreach { s =>
@@ -86,4 +74,9 @@ object TestShardsPlugin extends AutoPlugin {
8674
}
8775
)
8876

77+
private def stringConfig(key: String, default: String): String = {
78+
val propertyKey = key.replace('_', '.').toLowerCase
79+
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
80+
}
81+
8982
}

src/main/scala/sbttestshards/parsers/JUnitReportParser.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sbttestshards.parsers
22

3-
import java.nio.file.{Files, Path, Paths}
3+
import java.nio.file.Files
4+
import java.nio.file.Path
45
import scala.jdk.CollectionConverters.*
56
import scala.xml.XML
67

@@ -18,6 +19,10 @@ final case class FullTestReport(testReports: Seq[SuiteReport]) {
1819
def ++(other: FullTestReport): FullTestReport = FullTestReport(testReports ++ other.testReports)
1920
}
2021

22+
object FullTestReport {
23+
def empty: FullTestReport = FullTestReport(Seq.empty)
24+
}
25+
2126
final case class SuiteReport(
2227
name: String,
2328
testCount: Int,
@@ -83,9 +88,8 @@ object JUnitReportParser {
8388
val testName = (node \@ "name").trim
8489

8590
Some(testName)
86-
} else {
91+
} else
8792
None
88-
}
8993
}.collect { case Some(testName) =>
9094
testName
9195
}

0 commit comments

Comments
 (0)