Skip to content

Commit

Permalink
Genericalize schema utils to support non-struct root level access (#3716
Browse files Browse the repository at this point in the history
)

<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
Improving some schema utils to allow path index into non-struct root
data structures.

## How was this patch tested?
New UT.

## Does this PR introduce _any_ user-facing changes?
No.
  • Loading branch information
jackierwzhang authored Sep 24, 2024
1 parent 36995d9 commit 19374e2
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 63 deletions.
29 changes: 19 additions & 10 deletions spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1634,14 +1634,14 @@ trait DeltaErrorsBase
messageParameters = Array(option, operation))
}

def foundMapTypeColumnException(key: String, value: String, schema: StructType): Throwable = {
def foundMapTypeColumnException(key: String, value: String, schema: DataType): Throwable = {
new DeltaAnalysisException(
errorClass = "DELTA_FOUND_MAP_TYPE_COLUMN",
messageParameters = Array(key, value, schema.treeString)
messageParameters = Array(key, value, dataTypeToString(schema))
)
}
def columnNotInSchemaException(column: String, schema: StructType): Throwable = {
nonExistentColumnInSchema(column, schema.treeString)
def columnNotInSchemaException(column: String, schema: DataType): Throwable = {
nonExistentColumnInSchema(column, dataTypeToString(schema))
}

def metadataAbsentException(): Throwable = {
Expand Down Expand Up @@ -2690,25 +2690,29 @@ trait DeltaErrorsBase
def incorrectArrayAccessByName(
rightName: String,
wrongName: String,
schema: StructType): Throwable = {
schema: DataType): Throwable = {
new DeltaAnalysisException(
errorClass = "DELTA_INCORRECT_ARRAY_ACCESS_BY_NAME",
messageParameters = Array(rightName, wrongName, schema.treeString)
messageParameters = Array(
rightName,
wrongName,
dataTypeToString(schema)
)
)
}

def columnPathNotNested(
columnPath: String,
other: DataType,
column: Seq[String],
schema: StructType): Throwable = {
schema: DataType): Throwable = {
new DeltaAnalysisException(
errorClass = "DELTA_COLUMN_PATH_NOT_NESTED",
messageParameters = Array(
s"$columnPath",
s"$other",
s"${SchemaUtils.prettyFieldName(column)}",
schema.treeString
dataTypeToString(schema)
)
)
}
Expand Down Expand Up @@ -3445,11 +3449,11 @@ trait DeltaErrorsBase
}

def errorFindingColumnPosition(
columnPath: Seq[String], schema: StructType, extraErrMsg: String): Throwable = {
columnPath: Seq[String], schema: DataType, extraErrMsg: String): Throwable = {
new DeltaAnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_DELTA_0008",
messageParameters = Array(
UnresolvedAttribute(columnPath).name, schema.treeString, extraErrMsg))
UnresolvedAttribute(columnPath).name, dataTypeToString(schema), extraErrMsg))
}

def alterTableClusterByOnPartitionedTableException(): Throwable = {
Expand Down Expand Up @@ -3481,6 +3485,11 @@ trait DeltaErrorsBase
errorClass = "DELTA_UNSUPPORTED_WRITES_WITHOUT_COORDINATOR",
messageParameters = Array(coordinatorName))
}

private def dataTypeToString(dt: DataType): String = dt match {
case s: StructType => s.treeString
case other => other.simpleString
}
}

object DeltaErrors extends DeltaErrorsBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ object SchemaMergingUtils {
* @param tf function to apply.
* @return the transformed schema.
*/
def transformColumns(
schema: StructType)(
tf: (Seq[String], StructField, Resolver) => StructField): StructType = {
def transformColumns[T <: DataType](
schema: T)(
tf: (Seq[String], StructField, Resolver) => StructField): T = {
def transform[E <: DataType](path: Seq[String], dt: E): E = {
val newDt = dt match {
case StructType(fields) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object SchemaUtils extends DeltaLogging {
* defines whether we should recurse into ArrayType and MapType.
*/
def filterRecursively(
schema: StructType,
schema: DataType,
checkComplexTypes: Boolean)(f: StructField => Boolean): Seq[(Seq[String], StructField)] = {
def recurseIntoComplexTypes(
complexType: DataType,
Expand Down Expand Up @@ -699,7 +699,7 @@ def normalizeColumnNamesInDataType(
*/
def findColumnPosition(
column: Seq[String],
schema: StructType,
schema: DataType,
resolver: Resolver = DELTA_COL_RESOLVER): Seq[Int] = {
def findRecursively(
searchPath: Seq[String],
Expand Down Expand Up @@ -803,7 +803,7 @@ def normalizeColumnNamesInDataType(
* @param position A list of ordinals (0-based) representing the path to the nested field in
* `parent`.
*/
def getNestedTypeFromPosition(schema: StructType, position: Seq[Int]): DataType =
def getNestedTypeFromPosition(schema: DataType, position: Seq[Int]): DataType =
getNestedFieldFromPosition(StructField("schema", schema), position).dataType

/**
Expand All @@ -814,7 +814,34 @@ def normalizeColumnNamesInDataType(
}

/**
* Add `column` to the specified `position` in `schema`.
* Add a column to its child.
* @param parent The parent data type.
* @param column The column to add.
* @param position The position to add the column.
*/
def addColumn[T <: DataType](parent: T, column: StructField, position: Seq[Int]): T = {
if (position.isEmpty) {
throw DeltaErrors.addColumnParentNotStructException(column, parent)
}
parent match {
case struct: StructType =>
addColumnToStruct(struct, column, position).asInstanceOf[T]
case map: MapType if position.head == MAP_KEY_INDEX =>
map.copy(keyType = addColumn(map.keyType, column, position.tail)).asInstanceOf[T]
case map: MapType if position.head == MAP_VALUE_INDEX =>
map.copy(valueType = addColumn(map.valueType, column, position.tail)).asInstanceOf[T]
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
array.copy(elementType =
addColumn(array.elementType, column, position.tail)).asInstanceOf[T]
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.addColumnParentNotStructException(column, other)
}
}

/**
* Add `column` to the specified `position` in a struct `schema`.
* @param position A Seq of ordinals on where this column should go. It is a Seq to denote
* positions in nested columns (0-based). For example:
*
Expand All @@ -824,26 +851,10 @@ def normalizeColumnNamesInDataType(
* will return
* result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,**c2**,c3>>
*/
def addColumn(schema: StructType, column: StructField, position: Seq[Int]): StructType = {
def addColumnInChild(parent: DataType, column: StructField, position: Seq[Int]): DataType = {
if (position.isEmpty) {
throw DeltaErrors.addColumnParentNotStructException(column, parent)
}
parent match {
case struct: StructType =>
addColumn(struct, column, position)
case map: MapType if position.head == MAP_KEY_INDEX =>
map.copy(keyType = addColumnInChild(map.keyType, column, position.tail))
case map: MapType if position.head == MAP_VALUE_INDEX =>
map.copy(valueType = addColumnInChild(map.valueType, column, position.tail))
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
array.copy(elementType = addColumnInChild(array.elementType, column, position.tail))
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.addColumnParentNotStructException(column, other)
}
}
private def addColumnToStruct(
schema: StructType,
column: StructField,
position: Seq[Int]): StructType = {
// If the proposed new column includes a default value, return a specific "not supported" error.
// The rationale is that such operations require the data source scan operator to implement
// support for filling in the specified default value when the corresponding field is not
Expand Down Expand Up @@ -877,13 +888,42 @@ def normalizeColumnNamesInDataType(
if (!column.nullable && field.nullable) {
throw DeltaErrors.nullableParentWithNotNullNestedField
}
val mid = field.copy(dataType = addColumnInChild(field.dataType, column, position.tail))
val mid = field.copy(dataType = addColumn(field.dataType, column, position.tail))
StructType(pre ++ Seq(mid) ++ post.tail)
} else {
StructType(pre ++ Seq(column) ++ post)
}
}

/**
* Drop a column from its child.
* @param parent The parent data type.
* @param position The position to drop the column.
*/
def dropColumn[T <: DataType](parent: T, position: Seq[Int]): (T, StructField) = {
if (position.isEmpty) {
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(parent)
}
parent match {
case struct: StructType =>
val (t, s) = dropColumnInStruct(struct, position)
(t.asInstanceOf[T], s)
case map: MapType if position.head == MAP_KEY_INDEX =>
val (newKeyType, droppedColumn) = dropColumn(map.keyType, position.tail)
map.copy(keyType = newKeyType).asInstanceOf[T] -> droppedColumn
case map: MapType if position.head == MAP_VALUE_INDEX =>
val (newValueType, droppedColumn) = dropColumn(map.valueType, position.tail)
map.copy(valueType = newValueType).asInstanceOf[T] -> droppedColumn
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
val (newElementType, droppedColumn) = dropColumn(array.elementType, position.tail)
array.copy(elementType = newElementType).asInstanceOf[T] -> droppedColumn
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(other)
}
}

/**
* Drop from the specified `position` in `schema` and return with the original column.
* @param position A Seq of ordinals on where this column should go. It is a Seq to denote
Expand All @@ -894,30 +934,9 @@ def normalizeColumnNamesInDataType(
* will return
* result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>>
*/
def dropColumn(schema: StructType, position: Seq[Int]): (StructType, StructField) = {
def dropColumnInChild(parent: DataType, position: Seq[Int]): (DataType, StructField) = {
if (position.isEmpty) {
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(parent)
}
parent match {
case struct: StructType =>
dropColumn(struct, position)
case map: MapType if position.head == MAP_KEY_INDEX =>
val (newKeyType, droppedColumn) = dropColumnInChild(map.keyType, position.tail)
map.copy(keyType = newKeyType) -> droppedColumn
case map: MapType if position.head == MAP_VALUE_INDEX =>
val (newValueType, droppedColumn) = dropColumnInChild(map.valueType, position.tail)
map.copy(valueType = newValueType) -> droppedColumn
case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
val (newElementType, droppedColumn) = dropColumnInChild(array.elementType, position.tail)
array.copy(elementType = newElementType) -> droppedColumn
case _: ArrayType =>
throw DeltaErrors.incorrectArrayAccess()
case other =>
throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(other)
}
}

private def dropColumnInStruct(
schema: StructType,
position: Seq[Int]): (StructType, StructField) = {
require(position.nonEmpty, "Don't know where to drop the column")
val slicePosition = position.head
if (slicePosition < 0) {
Expand All @@ -930,7 +949,7 @@ def normalizeColumnNamesInDataType(
val (pre, post) = schema.splitAt(slicePosition)
val field = post.head
if (position.length > 1) {
val (newType, droppedColumn) = dropColumnInChild(field.dataType, position.tail)
val (newType, droppedColumn) = dropColumn(field.dataType, position.tail)
val mid = field.copy(dataType = newType)

StructType(pre ++ Seq(mid) ++ post.tail) -> droppedColumn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,35 @@ class SchemaUtilsSuite extends QueryTest
}
}

test("addColumn - top level array") {
val a = StructField("a", IntegerType)
val b = StructField("b", StringType)
val schema = ArrayType(new StructType().add(a).add(b))

val x = StructField("x", LongType)
assert(SchemaUtils.addColumn(schema, x, Seq(0, 1)) ===
ArrayType(new StructType().add(a).add(x).add(b)))
}

test("addColumn - top level map") {
val k = StructField("k", IntegerType)
val v = StructField("v", StringType)
val schema = MapType(
keyType = new StructType().add(k),
valueType = new StructType().add(v))

val x = StructField("x", LongType)
assert(SchemaUtils.addColumn(schema, x, Seq(0, 1)) ===
MapType(
keyType = new StructType().add(k).add(x),
valueType = new StructType().add(v)))

assert(SchemaUtils.addColumn(schema, x, Seq(1, 1)) ===
MapType(
keyType = new StructType().add(k),
valueType = new StructType().add(v).add(x)))
}

////////////////////////////
// dropColumn
////////////////////////////
Expand Down Expand Up @@ -1511,6 +1540,29 @@ class SchemaUtilsSuite extends QueryTest
}
}

test("dropColumn - top level array") {
val schema = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))

assert(SchemaUtils.dropColumn(schema, Seq(0, 0))._1 ===
ArrayType(new StructType().add("b", StringType)))
}

test("dropColumn - top level map") {
val schema = MapType(
keyType = new StructType().add("k", IntegerType).add("k2", StringType),
valueType = new StructType().add("v", StringType).add("v2", StringType))

assert(SchemaUtils.dropColumn(schema, Seq(0, 0))._1 ===
MapType(
keyType = new StructType().add("k2", StringType),
valueType = new StructType().add("v", StringType).add("v2", StringType)))

assert(SchemaUtils.dropColumn(schema, Seq(1, 0))._1 ===
MapType(
keyType = new StructType().add("k", IntegerType).add("k2", StringType),
valueType = new StructType().add("v2", StringType)))
}

/////////////////////////////////
// normalizeColumnNamesInDataType
/////////////////////////////////
Expand Down Expand Up @@ -2584,6 +2636,45 @@ class SchemaUtilsSuite extends QueryTest
assert(update === res3)
}

test("transform top level array type") {
val at = ArrayType(
new StructType()
.add("s1", IntegerType)
)

var visitedFields = 0
val updated = SchemaMergingUtils.transformColumns(at) {
case (_, field, _) =>
visitedFields += 1
field.copy(name = "s1_1", dataType = StringType)
}

assert(visitedFields === 1)
assert(updated === ArrayType(new StructType().add("s1_1", StringType)))
}

test("transform top level map type") {
val mt = MapType(
new StructType()
.add("k1", IntegerType),
new StructType()
.add("v1", IntegerType)
)

var visitedFields = 0
val updated = SchemaMergingUtils.transformColumns(mt) {
case (_, field, _) =>
visitedFields += 1
field.copy(name = field.name + "_1", dataType = StringType)
}

assert(visitedFields === 2)
assert(updated === MapType(
new StructType().add("k1_1", StringType),
new StructType().add("v1_1", StringType)
))
}

////////////////////////////
// pruneEmptyStructs
////////////////////////////
Expand Down

0 comments on commit 19374e2

Please sign in to comment.