@@ -23,6 +23,7 @@ import scala.collection.JavaConverters
 import scala.concurrent.{Future, Promise}
 import scala.concurrent.duration.Duration
 
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.SparkThreadUtils
@@ -61,6 +62,8 @@ class Observation(val name: String) {
 
   @volatile private var sparkSession: Option[SparkSession] = None
 
+  @volatile private var dataframeId: Option[Long] = None
+
   private val promise = Promise[Row]()
 
   /**
@@ -78,11 +81,11 @@ class Observation(val name: String) {
    * @return observed dataset
    * @throws IllegalArgumentException If this is a streaming Dataset (ds.isStreaming == true)
    */
-  private[spark] def on[T](ds: Dataset[T], expr: Column, exprs: Column*): Dataset[T] = {
+  private[spark] def on[T](ds: Dataset[T], dataframeId: Long, expr: Column, exprs: Column*): Dataset[T] = {
     if (ds.isStreaming) {
       throw new IllegalArgumentException("Observation does not support streaming Datasets")
     }
-    register(ds.sparkSession)
+    register(ds.sparkSession, dataframeId)
     ds.observe(name, expr, exprs: _*)
   }
 
@@ -91,20 +94,31 @@ class Observation(val name: String) {
    * its first action. Only the result of the first action is available. Subsequent actions do not
    * modify the result.
    *
+   * Note that if no metrics were recorded, an empty map may be returned.
+   * This can happen when the operators used for the observation are
+   * optimized away.
+   *
    * @return the observed metrics as a `Map[String, Any]`
    * @throws InterruptedException interrupted while waiting
    */
   @throws[InterruptedException]
   def get: Map[String, _] = {
-    val row = getRow
-    row.getValuesMap(row.schema.map(_.name))
+    val row = getRow
+    if (row == null || row.schema == null) {
+      Map.empty
+    } else {
+      row.getValuesMap(row.schema.map(_.name))
+    }
   }
 
   /**
    * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
    * its first action. Only the result of the first action is available. Subsequent actions do not
    * modify the result.
    *
+   * Note that if no metrics were recorded, an empty map may be returned. This can happen
+   * when the operators used for the observation are optimized away.
+   *
    * @return the observed metrics as a `java.util.Map[String, Object]`
    * @throws InterruptedException interrupted while waiting
    */
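
For context, a minimal sketch of how the `get` behavior documented above surfaces through the public API. The session, schema, and data below are illustrative and not part of this change:

```scala
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions.{count, sum}

// Illustrative local session and data; any SparkSession and DataFrame would do.
val spark = SparkSession.builder().master("local[*]").appName("observation-demo").getOrCreate()
import spark.implicits._
val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value")

// Attach named metrics to the DataFrame through the Observation API.
val observation = Observation("stats")
val observed = df.observe(observation, count($"id").as("rows"), sum($"id").as("id_sum"))

// `get` blocks until the first action on the observed Dataset finishes.
observed.collect()
val metrics: Map[String, _] = observation.get
// Normally something like Map(rows -> 3, id_sum -> 6); per the doc above, this may be
// an empty map if the observation operators were optimized away.
println(metrics)
```
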
@@ -115,7 +129,7 @@ class Observation(val name: String) {
     )
   }
 
-  private def register(sparkSession: SparkSession): Unit = {
+  private def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
     // all other threads will see the exception, as it is only allowed to do this once
@@ -124,6 +138,7 @@ class Observation(val name: String) {
         throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
       }
       this.sparkSession = Some(sparkSession)
+      this.dataframeId = Some(dataframeId)
     }
 
     sparkSession.listenerManager.register(this.listener)
@@ -134,9 +149,24 @@ class Observation(val name: String) {
   }
 
   private[spark] def onFinish(qe: QueryExecution): Unit = {
-    qe.observedMetrics.get(name).foreach { metrics =>
-      promise.trySuccess(metrics)
-      unregister()
+    if (qe.logical.exists {
+      case CollectMetrics(name, _, _, dataframeId) =>
+        name == this.name && dataframeId == this.dataframeId.get
+      case _ => false
+    }) {
+      val metrics = qe.observedMetrics.get(name)
+      if (metrics.isEmpty) {
+        // The plan contains the matching CollectMetrics node, but no metrics were collected,
+        // which can happen e.g. when the CollectMetricsExec node was optimized away. Complete
+        // the promise with an empty row so that `get` returns an empty map.
+        promise.trySuccess(Row.empty)
+        unregister()
+      } else {
+        metrics.foreach { metrics =>
+          promise.trySuccess(metrics)
+          unregister()
+        }
+      }
     }
   }
 
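
The `register` change keeps the existing single-use guarantee while additionally recording the `dataframeId`. A short sketch of that constraint, reusing the illustrative `spark` and `df` from the sketch above:

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.{col, count}

val once = Observation("once")
val observedOnce = df.observe(once, count(col("id")).as("rows"))

// Attaching the same Observation to another Dataset fails fast: `register` only allows
// `sparkSession` (and now `dataframeId`) to be set a single time.
try {
  df.select(col("id")).observe(once, count(col("id")).as("rows"))
} catch {
  case e: IllegalArgumentException =>
    println(e.getMessage) // "An Observation can be used with a Dataset only once"
}
```
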