diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
new file mode 100644
index 0000000..55629e8
--- /dev/null
+++ b/.github/workflows/ci.yaml
@@ -0,0 +1,50 @@
+name: Java CI
+
+on: [push]
+
+permissions:
+ id-token: write
+ contents: read
+ checks: write
+ pull-requests: write
+
+concurrency:
+ group: ${{ github.ref }}
+ cancel-in-progress: ${{ !contains(github.ref, 'main') }}
+
+jobs:
+ build-and-run-unit-tests:
+ runs-on: arc-4-cores-ondemand-staging-arm
+ timeout-minutes: 30
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Set up JDK 11
+ uses: actions/setup-java@v4
+ with:
+ java-version: '11'
+ distribution: 'temurin'
+ cache: maven
+
+ - name: Set up Maven
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y maven
+ mvn --version
+
+ - name: Build and Test Maven Artifacts
+ run: mvn -B -ntp clean package
+
+ - name: Publish Unit Test Results
+ uses: dorny/test-reporter@v2
+ if: success() || failure()
+ with:
+ name: Unit Test Results
+ path: 'target/surefire-reports/TEST-*.xml'
+ reporter: java-junit
+ fail-on-error: false
+ fail-on-empty: false
diff --git a/pom.xml b/pom.xml
index 4463f7d..17c5142 100644
--- a/pom.xml
+++ b/pom.xml
@@ -134,6 +134,35 @@
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <version>${maven-surefire-plugin.version}</version>
+        <configuration>
+          <excludes>
+            <exclude>**/*Test.class</exclude>
+            <exclude>**/*Suite.class</exclude>
+          </excludes>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>2.2.0</version>
+        <configuration>
+          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+          <junitxml>.</junitxml>
+          <filereports>TestSuite.txt</filereports>
+        </configuration>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
@@ -205,5 +234,23 @@
       <artifactId>jcommander</artifactId>
       <version>1.72</version>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_2.12</artifactId>
+      <version>3.2.15</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.module</groupId>
+      <artifactId>jackson-module-scala_2.12</artifactId>
+      <version>2.15.2</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <version>2.15.2</version>
+      <scope>test</scope>
+    </dependency>
diff --git a/src/main/scala/ai/onehouse/lakeloader/ChangeDataGenerator.scala b/src/main/scala/ai/onehouse/lakeloader/ChangeDataGenerator.scala
index badef94..76ef21c 100644
--- a/src/main/scala/ai/onehouse/lakeloader/ChangeDataGenerator.scala
+++ b/src/main/scala/ai/onehouse/lakeloader/ChangeDataGenerator.scala
@@ -452,17 +452,11 @@ object ChangeDataGenerator {
}
}
- private def genParallelRDD(
+ private[lakeloader] def genParallelRDD(
spark: SparkSession,
targetParallelism: Int,
start: Long,
end: Long): RDD[Long] = {
- val partitionSize = (end - start) / targetParallelism
- spark.sparkContext
- .parallelize(0 to targetParallelism, targetParallelism)
- .mapPartitions { it =>
- val partitionStart = it.next() * partitionSize
- (partitionStart to partitionStart + partitionSize).iterator
- }
+ spark.sparkContext.range(start, end, numSlices = targetParallelism)
}
}
diff --git a/src/test/scala/ai/onehouse/lakeloader/UnitTestSuite.scala b/src/test/scala/ai/onehouse/lakeloader/UnitTestSuite.scala
new file mode 100644
index 0000000..ea4b017
--- /dev/null
+++ b/src/test/scala/ai/onehouse/lakeloader/UnitTestSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.onehouse.lakeloader
+
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+class UnitTestSuite extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+ var spark: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ spark = SparkSession
+ .builder()
+ .appName("UnitTestSuite")
+ .master("local[*]")
+ .config("spark.sql.shuffle.partitions", "4")
+ .config("spark.default.parallelism", "4")
+ .getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (spark != null) {
+ spark.stop()
+ }
+ }
+
+ test("genParallelRDD should generate expected number of elements for range 0-100") {
+ val end = 100L
+ val targetParallelism = 4
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, 0, end)
+ val collected = rdd.collect()
+
+ collected.length shouldBe end
+ collected.distinct.length shouldBe collected.length
+ collected.sorted shouldBe (0L until end).toArray
+
+ // Each partition should have data
+ val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.size)).collect()
+ partitionSizes.foreach(_ should be > 0)
+ }
+
+ test("genParallelRDD should work with small ranges") {
+ val end = 10L
+ val targetParallelism = 2
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, 0, end)
+ val collected = rdd.collect()
+
+ collected.length shouldBe end
+ rdd.getNumPartitions shouldBe targetParallelism
+ collected.distinct.length shouldBe collected.length
+ collected.sorted shouldBe (0L until end).toArray
+
+ // Each partition should have data
+ val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.size)).collect()
+ partitionSizes.foreach(_ should be > 0)
+ }
+
+ test("genParallelRDD should work with large ranges") {
+ val end = 10000L
+ val targetParallelism = 8
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, 0, end)
+ val collected = rdd.collect()
+
+ collected.length shouldBe end
+ rdd.getNumPartitions shouldBe targetParallelism
+ collected.distinct.length shouldBe collected.length
+ collected.sorted shouldBe (0L until end).toArray
+
+ // Each partition should have data
+ val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.size)).collect()
+ partitionSizes.foreach(_ should be > 0)
+ }
+
+ test("genParallelRDD should generate exactly targetParallelism partitions") {
+ val targetParallelism = 5
+ val end = 100L
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, 0, end)
+
+ // Should have exactly targetParallelism partitions
+ rdd.getNumPartitions shouldBe targetParallelism
+ }
+
+ test("genParallelRDD should handle non-divisible end values correctly") {
+ val end = 97L
+ val targetParallelism = 7
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, 0, end)
+ val collected = rdd.collect().sorted
+
+ collected.length shouldBe end
+ collected shouldBe (0L until end).toArray
+ collected.distinct.length shouldBe collected.length
+ val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.size)).collect()
+ partitionSizes.foreach(_ should be > 0)
+ }
+
+ test("genParallelRDD should work with non-zero start and non-divisible range") {
+ val start = 50L
+ val end = 147L // 97 elements total
+ val targetParallelism = 7
+ val rdd = ChangeDataGenerator.genParallelRDD(spark, targetParallelism, start, end)
+ val collected = rdd.collect().sorted
+
+ // Verify correctness of collected
+ val expectedCount = end - start
+ collected.length shouldBe expectedCount
+ collected shouldBe (start until end).toArray
+ collected.distinct.length shouldBe collected.length
+
+ // Check partition size distribution
+ // First partition should have 13 elements and last 6 partitions should have 14 elements
+ rdd.getNumPartitions shouldBe targetParallelism
+ val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.size)).collect()
+ partitionSizes.length shouldBe targetParallelism
+ val expectedSizes = Array(13, 14, 14, 14, 14, 14, 14)
+ partitionSizes shouldBe expectedSizes
+ }
+}