[SPARK] Use StateCache to manage DatasetRefCache instances (#3682)
## Description

`DatasetRefCache` instances are currently untracked, which makes it hard to
discard or invalidate them once they are no longer needed. This change starts
to address the issue by tracking them with `StateCache`, so that
`Snapshot.uncache()` can clean them up -- the same way `CachedDS`
instances are already tracked and cleaned up.
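
A minimal, self-contained sketch of the tracking-and-invalidation pattern, with placeholder names (`RefCache`, `CacheRegistry`, `refCache`) and the Dataset-specific parts stripped out -- illustrative only, not the actual Delta classes:

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable.ArrayBuffer

// `RefCache` stands in for DatasetRefCache with the Dataset replaced by a
// plain value; `CacheRegistry` stands in for the tracking role StateCache
// now plays.
class RefCache[T](creator: () => T) {
  private val holder = new AtomicReference[T]()

  // Drop the cached value; the next `get` re-runs the creator.
  def invalidate(): Unit = holder.set(null.asInstanceOf[T])

  def get: T = Option(holder.get()).getOrElse {
    val value = creator()
    holder.set(value)
    value
  }
}

trait CacheRegistry {
  // Every cache built through the factory below is remembered here, so
  // uncache() can invalidate all of them at once.
  private val cachedRefs = ArrayBuffer[RefCache[_]]()

  def refCache[T](creator: () => T): RefCache[T] = {
    val cache = new RefCache(creator)
    cachedRefs += cache
    cache
  }

  def uncache(): Unit = cachedRefs.foreach(_.invalidate())
}
```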

## How was this patch tested?
Existing unit tests exercise the `Snapshot.uncache` path.

## Does this PR introduce _any_ user-facing changes?
No

Co-authored-by: Ryan Johnson <ryan.johnson@databricks.com>
scovich and ryan-johnson-databricks committed Sep 20, 2024
1 parent 80457b0 commit 2547e91
Showing 2 changed files with 14 additions and 5 deletions.
DatasetRefCache.scala
@@ -42,10 +42,12 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
  *
  * @param creator a function to create [[Dataset]].
  */
-class DatasetRefCache[T](creator: () => Dataset[T]) {
+class DatasetRefCache[T] private[util](creator: () => Dataset[T]) {
 
   private val holder = new AtomicReference[Dataset[T]]
 
+  private[delta] def invalidate() = holder.set(null)
+
   def get: Dataset[T] = Option(holder.get())
     .filter(_.sparkSession eq SparkSession.active)
     .getOrElse {
@@ -54,4 +56,3 @@ class DatasetRefCache[T](creator: () => Dataset[T]) {
       df
     }
 }
-
StateCache.scala
@@ -40,14 +40,15 @@ trait StateCache extends DeltaLogging {
   private var _isCached = true
   /** A list of RDDs that we need to uncache when we are done with this snapshot. */
   private val cached = ArrayBuffer[RDD[_]]()
+  private val cached_refs = ArrayBuffer[DatasetRefCache[_]]()
 
   /** Method to expose the value of _isCached for testing. */
   private[delta] def isCached: Boolean = _isCached
 
   private val storageLevel = StorageLevel.fromString(
     spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_SNAPSHOT_CACHE_STORAGE_LEVEL))
 
-  class CachedDS[A](ds: Dataset[A], name: String) {
+  class CachedDS[A] private[StateCache](ds: Dataset[A], name: String) {
     // While we cache RDD to avoid re-computation in different spark sessions, `Dataset` can only be
     // reused by the session that created it to avoid session pollution. So we use `DatasetRefCache`
     // to re-create a new `Dataset` when the active session is changed. This is an optimization for
@@ -64,10 +65,10 @@ trait StateCache extends DeltaLogging {
         rdd.persist(storageLevel)
       }
       cached += rdd
-      val dsCache = new DatasetRefCache(() => {
+      val dsCache = datasetRefCache { () =>
        val logicalRdd = LogicalRDD(qe.analyzed.output, rdd)(spark)
        Dataset.ofRows(spark, logicalRdd)
-      })
+      }
       Some(dsCache)
     } else {
       None
@@ -110,11 +111,18 @@ trait StateCache extends DeltaLogging {
     new CachedDS[A](ds, name)
   }
 
+  def datasetRefCache[A](creator: () => Dataset[A]): DatasetRefCache[A] = {
+    val dsCache = new DatasetRefCache(creator)
+    cached_refs += dsCache
+    dsCache
+  }
+
   /** Drop any cached data for this [[Snapshot]]. */
   def uncache(): Unit = cached.synchronized {
     if (isCached) {
       _isCached = false
       cached.foreach(_.unpersist(blocking = false))
+      cached_refs.foreach(_.invalidate())
     }
   }
 }
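
A hedged usage sketch of the new `datasetRefCache` factory, assuming a local `SparkSession` and a hypothetical `snapshot` value that mixes in `StateCache` (the `snapshot` object and the example Dataset are illustrative, not part of this change):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// `snapshot` is assumed to be some object that mixes in StateCache.
// Register a Dataset factory; the returned DatasetRefCache is now tracked
// in `cached_refs` alongside the CachedDS instances in `cached`.
val refCache = snapshot.datasetRefCache { () =>
  spark.range(10).toDF("id") // stand-in for rebuilding Delta state
}

val df1: DataFrame = refCache.get // runs the creator and memoizes the Dataset
snapshot.uncache()                // unpersists cached RDDs and invalidates refCache
val df2: DataFrame = refCache.get // creator runs again after invalidation
```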
