diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/DatasetRefCache.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/DatasetRefCache.scala
index 3d9bdbdbeb..879a101bb3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/util/DatasetRefCache.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/DatasetRefCache.scala
@@ -42,10 +42,12 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
  *
  * @param creator a function to create [[Dataset]].
  */
-class DatasetRefCache[T](creator: () => Dataset[T]) {
+class DatasetRefCache[T] private[util](creator: () => Dataset[T]) {
 
   private val holder = new AtomicReference[Dataset[T]]
 
+  private[delta] def invalidate() = holder.set(null)
+
   def get: Dataset[T] = Option(holder.get())
     .filter(_.sparkSession eq SparkSession.active)
     .getOrElse {
@@ -54,4 +56,3 @@ class DatasetRefCache[T](creator: () => Dataset[T]) {
       df
     }
 }
-
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/StateCache.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/StateCache.scala
index 391cf56710..86784bfde3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/util/StateCache.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/StateCache.scala
@@ -40,6 +40,7 @@ trait StateCache extends DeltaLogging {
   private var _isCached = true
   /** A list of RDDs that we need to uncache when we are done with this snapshot. */
   private val cached = ArrayBuffer[RDD[_]]()
+  private val cached_refs = ArrayBuffer[DatasetRefCache[_]]()
 
   /** Method to expose the value of _isCached for testing. */
   private[delta] def isCached: Boolean = _isCached
@@ -47,7 +48,7 @@ trait StateCache extends DeltaLogging {
   private val storageLevel = StorageLevel.fromString(
     spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_SNAPSHOT_CACHE_STORAGE_LEVEL))
 
-  class CachedDS[A](ds: Dataset[A], name: String) {
+  class CachedDS[A] private[StateCache](ds: Dataset[A], name: String) {
     // While we cache RDD to avoid re-computation in different spark sessions, `Dataset` can only be
     // reused by the session that created it to avoid session pollution. So we use `DatasetRefCache`
     // to re-create a new `Dataset` when the active session is changed. This is an optimization for
@@ -64,10 +65,10 @@ trait StateCache extends DeltaLogging {
         rdd.persist(storageLevel)
       }
       cached += rdd
-      val dsCache = new DatasetRefCache(() => {
+      val dsCache = datasetRefCache { () =>
         val logicalRdd = LogicalRDD(qe.analyzed.output, rdd)(spark)
         Dataset.ofRows(spark, logicalRdd)
-      })
+      }
       Some(dsCache)
     } else {
       None
@@ -110,11 +111,18 @@ trait StateCache extends DeltaLogging {
     new CachedDS[A](ds, name)
   }
 
+  def datasetRefCache[A](creator: () => Dataset[A]): DatasetRefCache[A] = {
+    val dsCache = new DatasetRefCache(creator)
+    cached_refs += dsCache
+    dsCache
+  }
+
   /** Drop any cached data for this [[Snapshot]]. */
   def uncache(): Unit = cached.synchronized {
     if (isCached) {
       _isCached = false
       cached.foreach(_.unpersist(blocking = false))
+      cached_refs.foreach(_.invalidate())
     }
   }
 }
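
Taken together, the two changes route every `DatasetRefCache` through a registering factory and clear each cached `Dataset` reference when the snapshot is uncached. Below is a minimal, self-contained sketch of that pattern; `RefCacheSketch` and `OwnerSketch` are hypothetical names standing in for `DatasetRefCache` and `StateCache`, and the bodies are condensed from the diff rather than copied from the Delta sources.

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical stand-in for DatasetRefCache: a cached Dataset reference
// that is rebuilt whenever the active SparkSession changes.
class RefCacheSketch[T](creator: () => Dataset[T]) {

  private val holder = new AtomicReference[Dataset[T]]

  // New in this PR: the owner drops the reference on uncache, so the
  // next get() re-runs creator() instead of returning a stale Dataset.
  def invalidate(): Unit = holder.set(null)

  def get: Dataset[T] = Option(holder.get())
    // A Dataset is only reusable inside the session that created it, so a
    // reference built under a different session is discarded, not returned.
    .filter(_.sparkSession eq SparkSession.active)
    .getOrElse {
      val df = creator()
      holder.set(df)
      df
    }
}

// Hypothetical stand-in for the owner side (StateCache): every ref cache
// is created through a factory that records it, so uncache() can
// invalidate all of them at once.
class OwnerSketch {

  private val cachedRefs = ArrayBuffer[RefCacheSketch[_]]()

  def refCache[A](creator: () => Dataset[A]): RefCacheSketch[A] = {
    val cache = new RefCacheSketch(creator)
    cachedRefs += cache
    cache
  }

  def uncache(): Unit = cachedRefs.foreach(_.invalidate())
}
```

Restricting the constructors (`private[util]` on `DatasetRefCache`, `private[StateCache]` on `CachedDS`) is what enforces the factory path: code outside those scopes can no longer `new` a reference cache that `uncache()` does not know about, so a cached `Dataset` reference cannot outlive its unpersisted RDD.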