
Commit 3ef18f4

fix
1 parent 09d7385 commit 3ef18f4

File tree

2 files changed: +39 -9 lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 1 deletion
@@ -2216,7 +2216,7 @@ class Dataset[T] private[sql](
    */
   @varargs
   def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
-    observation.on(this, expr, exprs: _*)
+    observation.on(this, id, expr, exprs: _*)
   }
 
   /**
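
For context, this is the caller side of the change: observe now hands the Dataset's internal id to Observation.on, so the observation can later be matched to the exact DataFrame it was attached to. A minimal usage sketch of the API involved (the metric names and local session setup are illustrative, not part of this commit):

    import org.apache.spark.sql.{Observation, SparkSession}
    import org.apache.spark.sql.functions._

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Attach named metrics to a Dataset; they are collected as a side effect
    // of the first action executed on it.
    val observation = Observation("stats")
    val df = spark.range(100)
      .observe(observation, count(lit(1)).as("rows"), sum(col("id")).as("total"))

    df.collect()      // the action runs and the metrics are recorded
    observation.get   // Map(rows -> 100, total -> 4950)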

sql/core/src/main/scala/org/apache/spark/sql/Observation.scala

Lines changed: 38 additions & 8 deletions
@@ -23,6 +23,7 @@ import scala.collection.JavaConverters
 import scala.concurrent.{Future, Promise}
 import scala.concurrent.duration.Duration
 
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.SparkThreadUtils
@@ -61,6 +62,8 @@ class Observation(val name: String) {
 
   @volatile private var sparkSession: Option[SparkSession] = None
 
+  @volatile private var dataframeId: Option[Long] = None
+
   private val promise = Promise[Row]()
 
   /**
@@ -78,11 +81,11 @@ class Observation(val name: String) {
    * @return observed dataset
    * @throws IllegalArgumentException If this is a streaming Dataset (ds.isStreaming == true)
    */
-  private[spark] def on[T](ds: Dataset[T], expr: Column, exprs: Column*): Dataset[T] = {
+  private[spark] def on[T](ds: Dataset[T], dataframeId: Long, expr: Column, exprs: Column*): Dataset[T] = {
     if (ds.isStreaming) {
       throw new IllegalArgumentException("Observation does not support streaming Datasets")
     }
-    register(ds.sparkSession)
+    register(ds.sparkSession, dataframeId)
     ds.observe(name, expr, exprs: _*)
   }
 
@@ -91,20 +94,31 @@ class Observation(val name: String) {
    * its first action. Only the result of the first action is available. Subsequent actions do not
    * modify the result.
    *
+   *
+   * Note that an empty map may be returned if no metrics were recorded, e.g. when the operators
+   * used for the observation were optimized away.
+   *
    * @return the observed metrics as a `Map[String, Any]`
    * @throws InterruptedException interrupted while waiting
    */
   @throws[InterruptedException]
   def get: Map[String, _] = {
     val row = getRow
-    row.getValuesMap(row.schema.map(_.name))
+    if (row == null || row.schema == null) {
+      Map.empty
+    } else {
+      row.getValuesMap(row.schema.map(_.name))
+    }
   }
 
   /**
    * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
    * its first action. Only the result of the first action is available. Subsequent actions do not
    * modify the result.
    *
+   * Note that an empty map may be returned if no metrics were recorded, e.g. when the operators
+   * used for the observation were optimized away.
+   *
    * @return the observed metrics as a `java.util.Map[String, Object]`
    * @throws InterruptedException interrupted while waiting
    */
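
The null check in get matters because, when no metrics were collected, onFinish (below) completes the promise with Row.empty, whose schema is null; the check turns that into an empty map instead of a NullPointerException. A hedged sketch of the caller-side behavior, reusing the session and imports from the first sketch (the scenario is illustrative):

    val maybeEmpty = Observation("maybe_empty")
    val observed = spark.range(10).observe(maybeEmpty, count(lit(1)).as("cnt"))

    observed.collect()
    maybeEmpty.get match {
      case m if m.isEmpty => println("no metrics recorded (operator may have been optimized away)")
      case m => println(s"metrics: $m")
    }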
@@ -115,7 +129,7 @@ class Observation(val name: String) {
     )
   }
 
-  private def register(sparkSession: SparkSession): Unit = {
+  private def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
     // all other threads will see the exception, as it is only allowed to do this once
@@ -124,6 +138,7 @@ class Observation(val name: String) {
         throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
       }
       this.sparkSession = Some(sparkSession)
+      this.dataframeId = Some(dataframeId)
     }
 
     sparkSession.listenerManager.register(this.listener)
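
register is also where the one-use contract is enforced; the change simply records the DataFrame id alongside the session under the same synchronization. An illustrative consequence of that contract, using the setup from the first sketch:

    val once = Observation("once")
    val first = spark.range(5).observe(once, count(lit(1)).as("n"))

    // A second observe() with the same Observation instance calls register()
    // again and fails fast:
    // java.lang.IllegalArgumentException: An Observation can be used with a Dataset only once
    val second = spark.range(5).observe(once, count(lit(1)).as("n"))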
@@ -134,9 +149,24 @@ class Observation(val name: String) {
   }
 
   private[spark] def onFinish(qe: QueryExecution): Unit = {
-    qe.observedMetrics.get(name).foreach { metrics =>
-      promise.trySuccess(metrics)
-      unregister()
+    if (qe.logical.exists {
+      case CollectMetrics(name, _, _, dataframeId) =>
+        name == this.name && dataframeId == this.dataframeId.get
+      case _ => false
+    }) {
+      val metrics = qe.observedMetrics.get(name)
+      if (metrics.isEmpty) {
+        // The plan contains our CollectMetrics node, but no metrics were collected. This can
+        // happen when the metrics could not be collected for some reason, e.g. because the
+        // CollectMetricsExec node was optimized away.
+        promise.trySuccess(Row.empty)
+        unregister()
+      } else {
+        metrics.foreach { metricsRow =>
+          promise.trySuccess(metricsRow)
+          unregister()
+        }
+      }
     }
   }
 
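
To see why onFinish now checks the dataframeId recorded in the plan's CollectMetrics node instead of matching on the metric name alone, consider this hypothetical repro (names are illustrative): two independent Observations that happen to share a name, each attached to its own DataFrame.

    val obsA = Observation("metrics")
    val obsB = Observation("metrics")
    val dfA = spark.range(10).observe(obsA, count(lit(1)).as("cnt"))
    val dfB = spark.range(20).observe(obsB, count(lit(1)).as("cnt"))

    dfB.collect()
    // Before this change, obsA's listener could match dfB's "metrics" entry by
    // name and complete obsA's promise with dfB's row. With the dataframeId
    // check, obsA.get keeps waiting until an action runs on dfA itself.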
