Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spark] Pass sparkSession to commitOwnerBuilder #3112

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,7 @@ trait OptimisticTransactionImpl extends TransactionalWrite
var newManagedCommitTableConf: Option[Map[String, String]] = None
if (finalMetadata.configuration != snapshot.metadata.configuration || snapshot.version == -1L) {
val newCommitOwnerClientOpt =
ManagedCommitUtils.getCommitOwnerClient(finalMetadata, finalProtocol)
ManagedCommitUtils.getCommitOwnerClient(spark, finalMetadata, finalProtocol)
(newCommitOwnerClientOpt, readSnapshotTableCommitOwnerClientOpt) match {
case (Some(newCommitOwnerClient), None) =>
// FS -> MC conversion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class Snapshot(
*/
val tableCommitOwnerClientOpt: Option[TableCommitOwnerClient] = initializeTableCommitOwner()
protected def initializeTableCommitOwner(): Option[TableCommitOwnerClient] = {
ManagedCommitUtils.getTableCommitOwner(this)
ManagedCommitUtils.getTableCommitOwner(spark, this)
}

/** Number of columns to collect stats on for data skipping */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import org.apache.spark.sql.delta.storage.LogStore
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

/** Representation of a commit file */
case class Commit(
private val version: Long,
Expand Down Expand Up @@ -199,7 +201,7 @@ trait CommitOwnerBuilder {
def getName: String

/** Returns a commit-owner client based on the given conf */
def build(conf: Map[String, String]): CommitOwnerClient
def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient
}

/** Factory to get the correct [[CommitOwnerClient]] for a table */
Expand All @@ -218,10 +220,12 @@ object CommitOwnerProvider {
}
}

/** Returns a [[CommitOwnerClient]] for the given `name` and `conf` */
/** Returns a [[CommitOwnerClient]] for the given `name`, `conf`, and `spark` */
def getCommitOwnerClient(
name: String, conf: Map[String, String]): CommitOwnerClient = synchronized {
nameToBuilderMapping.get(name).map(_.build(conf)).getOrElse {
name: String,
conf: Map[String, String],
spark: SparkSession): CommitOwnerClient = synchronized {
nameToBuilderMapping.get(name).map(_.build(spark, conf)).getOrElse {
throw new IllegalArgumentException(s"Unknown commit-owner: $name")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.sql.delta.storage.LogStore
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

class InMemoryCommitOwner(val batchSize: Long)
extends AbstractBatchBackfillingCommitOwnerClient {

Expand Down Expand Up @@ -206,7 +208,7 @@ case class InMemoryCommitOwnerBuilder(batchSize: Long) extends CommitOwnerBuilde
def getName: String = "in-memory"

/** Returns a commit-owner based on the given conf */
def build(conf: Map[String, String]): CommitOwnerClient = {
def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
inMemoryStore
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.apache.spark.sql.delta.util.FileNames.{DeltaFile, UnbackfilledDeltaFi
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

object ManagedCommitUtils extends DeltaLogging {

/**
Expand Down Expand Up @@ -111,16 +113,19 @@ object ManagedCommitUtils extends DeltaLogging {
*/
def getTablePath(logPath: Path): Path = logPath.getParent

def getCommitOwnerClient(metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = {
def getCommitOwnerClient(
spark: SparkSession, metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = {
metadata.managedCommitOwnerName.map { commitOwnerStr =>
assert(protocol.isFeatureSupported(ManagedCommitTableFeature))
CommitOwnerProvider.getCommitOwnerClient(commitOwnerStr, metadata.managedCommitOwnerConf)
CommitOwnerProvider.getCommitOwnerClient(
commitOwnerStr, metadata.managedCommitOwnerConf, spark)
}
}

def getTableCommitOwner(
spark: SparkSession,
snapshotDescriptor: SnapshotDescriptor): Option[TableCommitOwnerClient] = {
getCommitOwnerClient(snapshotDescriptor.metadata, snapshotDescriptor.protocol).map {
getCommitOwnerClient(spark, snapshotDescriptor.metadata, snapshotDescriptor.protocol).map {
commitOwner =>
TableCommitOwnerClient(
commitOwner,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.scalatest.Tag

import org.apache.spark.{DebugFilesystem, SparkException, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class DeltaLogSuite extends QueryTest
// For a Managed Commit table with a commit that is not backfilled, we can't use
// 00000000002.json yet. Contact the commit store to get the UUID file path of the JSON file to malform.
val oc = CommitOwnerProvider.getCommitOwnerClient(
"tracking-in-memory", Map.empty[String, String])
"tracking-in-memory", Map.empty[String, String], spark)
val commitResponse = oc.getCommits(deltaLog.logPath, Map.empty, Some(2))
if (!commitResponse.getCommits.isEmpty) {
val path = commitResponse.getCommits.last.getFileStatus.getPath
Expand Down Expand Up @@ -602,7 +602,7 @@ class DeltaLogSuite extends QueryTest
// For a Managed Commit table with a commit that is not backfilled, we can't use
// 00000000001.json yet. Contact the commit store to get the UUID file path of the JSON file to malform.
val oc = CommitOwnerProvider.getCommitOwnerClient(
"tracking-in-memory", Map.empty[String, String])
"tracking-in-memory", Map.empty[String, String], spark)
val commitResponse = oc.getCommits(log.logPath, Map.empty, Some(1))
if (!commitResponse.getCommits.isEmpty) {
commitFilePath = commitResponse.getCommits.head.getFileStatus.getPath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.functions.lit
Expand Down Expand Up @@ -520,7 +519,8 @@ class OptimisticTransactionSuite
}
}
}
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}

CommitOwnerProvider.registerBuilder(RetryableNonConflictCommitOwnerBuilder$)
Expand Down Expand Up @@ -569,7 +569,8 @@ class OptimisticTransactionSuite
}
}
}
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}

CommitOwnerProvider.registerBuilder(FileAlreadyExistsCommitOwnerBuilder)
Expand Down Expand Up @@ -878,7 +879,8 @@ class OptimisticTransactionSuite
object RetryableConflictCommitOwnerBuilder$ extends CommitOwnerBuilder {
lazy val commitOwnerClient = new RetryableConflictCommitOwnerClient()
override def getName: String = commitOwnerName
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}
CommitOwnerProvider.registerBuilder(RetryableConflictCommitOwnerBuilder$)
val conf = Map(DeltaConfigs.MANAGED_COMMIT_OWNER_NAME.key -> commitOwnerName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -587,7 +588,7 @@ object ConcurrentBackfillCommitOwnerBuilder extends CommitOwnerBuilder {
private lazy val concurrentBackfillCommitOwnerClient =
ConcurrentBackfillCommitOwnerClient(synchronousBackfillThreshold = 2, batchSize)
override def getName: String = "awaiting-commit-owner"
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
concurrentBackfillCommitOwnerClient
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.{QueryTest, SparkSession}
import org.apache.spark.sql.test.SharedSparkSession

class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with SharedSparkSession
Expand Down Expand Up @@ -72,15 +72,15 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share

test("registering multiple commit-owner builders with same name") {
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-1"
}
object BuilderWithSameName extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-1"
}
object Builder3 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-3"
}
CommitOwnerProvider.registerBuilder(Builder1)
Expand All @@ -94,7 +94,7 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
object Builder1 extends CommitOwnerBuilder {
val cs1 = new TestCommitOwnerClient1()
val cs2 = new TestCommitOwnerClient2()
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
conf.getOrElse("url", "") match {
case "url1" => cs1
case "url2" => cs2
Expand All @@ -104,21 +104,22 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
override def getName: String = "cs-x"
}
CommitOwnerProvider.registerBuilder(Builder1)
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"))
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark)
assert(cs1.isInstanceOf[TestCommitOwnerClient1])
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"))
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark)
assert(cs1 eq cs1_again)
val cs2 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b"))
val cs2 =
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b"), spark)
assert(cs2.isInstanceOf[TestCommitOwnerClient2])
// If the builder receives a config which doesn't have the expected params, it can throw an exception.
intercept[IllegalArgumentException] {
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3"))
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3"), spark)
}
}

test("getCommitOwnerClient - builder returns new object each time") {
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
conf.getOrElse("url", "") match {
case "url1" => new TestCommitOwnerClient1()
case _ => throw new IllegalArgumentException("Invalid url")
Expand All @@ -127,9 +128,9 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
override def getName: String = "cs-name"
}
CommitOwnerProvider.registerBuilder(Builder1)
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"))
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark)
assert(cs1.isInstanceOf[TestCommitOwnerClient1])
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"))
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark)
assert(cs1 ne cs1_again)
}

Expand Down Expand Up @@ -202,21 +203,21 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
other.asInstanceOf[TestCommitOwnerClient].key == key
}
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
new TestCommitOwnerClient(conf("key"))
}
override def getName: String = "cs-name"
}
CommitOwnerProvider.registerBuilder(Builder1)

// Different CommitOwner with same keys should be semantically equal.
val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"))
val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"))
val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark)
val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark)
assert(obj1 != obj2)
assert(obj1.semanticEquals(obj2))

// Different CommitOwner with different keys should be semantically unequal.
val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2"))
val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2"), spark)
assert(obj1 != obj3)
assert(!obj1.semanticEquals(obj3))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien

override protected def createTableCommitOwnerClient(
deltaLog: DeltaLog): TableCommitOwnerClient = {
val cs = InMemoryCommitOwnerBuilder(batchSize).build(Map.empty)
val cs = InMemoryCommitOwnerBuilder(batchSize).build(spark, Map.empty)
TableCommitOwnerClient(cs, deltaLog, Map.empty[String, String])
}

Expand Down Expand Up @@ -65,22 +65,22 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien

test("InMemoryCommitOwnerBuilder works as expected") {
val builder1 = InMemoryCommitOwnerBuilder(5)
val cs1 = builder1.build(Map.empty)
val cs1 = builder1.build(spark, Map.empty)
assert(cs1.isInstanceOf[InMemoryCommitOwner])
assert(cs1.asInstanceOf[InMemoryCommitOwner].batchSize == 5)

val cs1_again = builder1.build(Map.empty)
val cs1_again = builder1.build(spark, Map.empty)
assert(cs1_again.isInstanceOf[InMemoryCommitOwner])
assert(cs1 == cs1_again)

val builder2 = InMemoryCommitOwnerBuilder(10)
val cs2 = builder2.build(Map.empty)
val cs2 = builder2.build(spark, Map.empty)
assert(cs2.isInstanceOf[InMemoryCommitOwner])
assert(cs2.asInstanceOf[InMemoryCommitOwner].batchSize == 10)
assert(cs2 ne cs1)

val builder3 = InMemoryCommitOwnerBuilder(10)
val cs3 = builder3.build(Map.empty)
val cs3 = builder3.build(spark, Map.empty)
assert(cs3.isInstanceOf[InMemoryCommitOwner])
assert(cs3.asInstanceOf[InMemoryCommitOwner].batchSize == 10)
assert(cs3 ne cs2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ManualClock

Expand Down Expand Up @@ -71,7 +71,7 @@ class ManagedCommitSuite

override def getName: String = commitOwnerName

override def build(conf: Map[String, String]): CommitOwnerClient =
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient =
new InMemoryCommitOwner(batchSize = 5) {
override def commit(
logStore: LogStore,
Expand Down Expand Up @@ -125,7 +125,7 @@ class ManagedCommitSuite

test("cold snapshot initialization") {
val builder = TrackingInMemoryCommitOwnerBuilder(batchSize = 10)
val commitOwnerClient = builder.build(Map.empty).asInstanceOf[TrackingCommitOwnerClient]
val commitOwnerClient = builder.build(spark, Map.empty).asInstanceOf[TrackingCommitOwnerClient]
CommitOwnerProvider.registerBuilder(builder)
withTempDir { tempDir =>
val tablePath = tempDir.getAbsolutePath
Expand Down Expand Up @@ -221,7 +221,7 @@ class ManagedCommitSuite
name: String,
commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder {
var numBuildCalled = 0
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
numBuildCalled += 1
commitOwnerClient
}
Expand Down Expand Up @@ -361,7 +361,8 @@ class ManagedCommitSuite
case class TrackingInMemoryCommitOwnerClientBuilder(
name: String,
commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def getName: String = name
}
val builder1 = TrackingInMemoryCommitOwnerClientBuilder(name = "in-memory-1", cs1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SharedSparkSession

trait ManagedCommitTestUtils
Expand Down Expand Up @@ -116,7 +117,7 @@ case class TrackingInMemoryCommitOwnerBuilder(
}

override def getName: String = "tracking-in-memory"
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
trackingInMemoryCommitOwnerClient
}
}
Expand Down
Loading