Commit 9f9e46a
[spark] Add the fields in reservedFilters into the estimation of stats (
Zouxxyy authored Apr 24, 2024 · 1 parent 4b878fe
Showing 4 changed files with 20 additions and 7 deletions.
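Before this change, statistics estimation was driven by the projected read schema alone: a column referenced only by a reserved (pushed-down) filter, such as a partition column that is filtered on but not selected, carried no column statistics into Spark's estimation, so the filter's selectivity could not be applied. The commit introduces a wider requiredStatsSchema on PaimonBaseScan that also keeps filter-referenced columns, and threads it through PaimonStatistics and StatisticsHelperBase.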
@@ -59,6 +59,11 @@ abstract class PaimonBaseScan(
 
   lazy val statistics: Optional[stats.Statistics] = table.statistics()
 
+  lazy val requiredStatsSchema: StructType = {
+    val fieldNames = requiredSchema.fieldNames ++ reservedFilters.flatMap(_.references)
+    StructType(tableSchema.filter(field => fieldNames.contains(field.name)))
+  }
+
   lazy val readBuilder: ReadBuilder = {
     val _readBuilder = table.newReadBuilder()
 
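The new requiredStatsSchema unions the projected column names with every column a reserved (pushed-down) filter references, then filters tableSchema so the result keeps table field order and metadata. A minimal self-contained sketch of the same union-and-filter logic, with hypothetical schema contents standing in for the scan's requiredSchema, reservedFilters, and tableSchema:

import org.apache.spark.sql.sources.{Filter, LessThan}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object RequiredStatsSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical table: `id` is projected, `pt` appears only in a filter.
    val tableSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("pt", IntegerType)))
    val requiredSchema = StructType(Seq(StructField("id", IntegerType)))
    val reservedFilters: Seq[Filter] = Seq(LessThan("pt", 1))

    // Same logic as the patch: union the projected names with the columns the
    // reserved filters reference, then keep matching table fields in order.
    val fieldNames = requiredSchema.fieldNames ++ reservedFilters.flatMap(_.references)
    val requiredStatsSchema =
      StructType(tableSchema.filter(field => fieldNames.contains(field.name)))

    println(requiredStatsSchema.fieldNames.mkString(", ")) // prints: id, pt
  }
}

Note that duplicates in fieldNames are harmless: the result is built by filtering tableSchema, so each table field appears at most once.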
@@ -48,7 +48,7 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
     if (paimonStats.isPresent) paimonStats.get().mergedRecordCount() else OptionalLong.of(rowCount)
 
   override def columnStats(): java.util.Map[NamedReference, ColumnStatistics] = {
-    val requiredFields = scan.readSchema().fieldNames.toList.asJava
+    val requiredFields = scan.requiredStatsSchema.fieldNames.toList.asJava
     val resultMap = new java.util.HashMap[NamedReference, ColumnStatistics]()
     if (paimonStats.isPresent) {
       val paimonColStats = paimonStats.get().colStats()
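scan.readSchema() covers only the projected columns, so the previous lookup silently dropped statistics for columns that were pushed down as filters but never selected. Keying the lookup on requiredStatsSchema exposes Paimon's column stats for those filter-only columns to Spark's cost-based estimation as well.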
@@ -30,17 +30,17 @@ import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.read.Statistics
 import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
 import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 import java.util.OptionalLong
 
 trait StatisticsHelperBase extends SQLConfHelper {
 
-  val requiredSchema: StructType
+  val requiredStatsSchema: StructType
 
   def filterStatistics(v2Stats: Statistics, filters: Seq[Filter]): Statistics = {
     val attrs: Seq[AttributeReference] =
-      requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+      requiredStatsSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
     val condition = filterToCondition(filters, attrs)
 
     if (condition.isDefined && v2Stats.numRows().isPresent) {
@@ -56,14 +56,15 @@ trait StatisticsHelperBase extends SQLConfHelper {
     StructFilters.filterToExpression(filters.reduce(And), toRef).map {
       expression =>
         expression.transform {
-          case ref: BoundReference => attrs.find(_.name == requiredSchema(ref.ordinal).name).get
+          case ref: BoundReference =>
+            attrs.find(_.name == requiredStatsSchema(ref.ordinal).name).get
         }
     }
   }
 
   private def toRef(attr: String): Option[BoundReference] = {
-    val index = requiredSchema.fieldIndex(attr)
-    val field = requiredSchema(index)
+    val index = requiredStatsSchema.fieldIndex(attr)
+    val field = requiredStatsSchema(index)
     Option.apply(BoundReference(index, field.dataType, field.nullable))
   }
 
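Both toRef and the BoundReference rewrite must index into the same schema: toRef binds a filter attribute by its ordinal in requiredStatsSchema, and the transform later resolves that ordinal back through the same schema, which is why every requiredSchema reference in this trait had to move together. A minimal sketch of the round trip (schema contents are hypothetical):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object OrdinalConsistencySketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical stats schema: the filter-only column `pt` sits at ordinal 1.
    val requiredStatsSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("pt", IntegerType)))
    val attrs = requiredStatsSchema.map(
      f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

    // toRef binds a filter attribute by its ordinal in requiredStatsSchema...
    val ref = BoundReference(requiredStatsSchema.fieldIndex("pt"), IntegerType, nullable = true)

    // ...and the transform resolves that ordinal through the same schema, so
    // the round trip lands on the matching attribute. Binding against one
    // schema and resolving against another would shift the ordinals.
    val resolved = attrs.find(_.name == requiredStatsSchema(ref.ordinal).name).get
    println(resolved.name) // prints: pt
  }
}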
@@ -353,6 +353,13 @@ abstract class AnalyzeTableTestBase extends PaimonSparkTestBase {
       getScanStatistic(sql).rowCount.get.longValue())
     checkAnswer(spark.sql(sql), Nil)
 
+    // partition push down hit and select without it
+    sql = "SELECT id FROM T WHERE pt < 1"
+    Assertions.assertEquals(
+      if (supportsColStats()) 0L else 4L,
+      getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+
     // partition push down not hit
     sql = "SELECT * FROM T WHERE id < 1"
     Assertions.assertEquals(4L, getScanStatistic(sql).rowCount.get.longValue())
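The new case exercises exactly the gap this commit closes: pt appears only in the WHERE clause, so before the change its column statistics were absent from the scan's stats schema and the estimated row count stayed at the full table size. A table shape consistent with the assertions (4 rows, every partition value at least 1) might look like the following; the suite's actual setup DDL is not shown in this diff and is assumed here:

// Hypothetical setup, run inside the suite where `spark` is in scope. With
// column stats analyzed, `pt < 1` matches no partition, so the estimated
// row count for the filtered scan collapses to 0.
spark.sql("CREATE TABLE T (id INT, pt INT) PARTITIONED BY (pt)")
spark.sql("INSERT INTO T VALUES (1, 1), (2, 1), (3, 2), (4, 2)")
spark.sql("ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")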

0 comments on commit 9f9e46a
