Skip to content

Commit

Permalink
Produces partial results and also fills corrupt record (#370)
Browse files Browse the repository at this point in the history
Followup for #368.

This PR proposes to produce partial results and also fill the corrupt record column. Some code changes were borrowed from a recent change in Apache Spark (apache/spark@4e1d859). That fix allows partial fields in JSON.

This PR fixes:
- Don't allow partial results within arrays and maps (to match the Apache Spark side)
- Partially parse and convert each value in each row. If it fails to parse or convert, it becomes `null`.
  - Partial results are only allowed in the top-level row. For instance, a nested partial result such as `Row(1, Row(1, 2, null))` is disallowed - it becomes `Row(1, null)`.
- If any exception is detected, the whole XML text is placed in the `columnNameOfCorruptRecord` column
  • Loading branch information
HyukjinKwon authored Dec 29, 2018
1 parent 98f9be7 commit 1a712ea
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 76 deletions.
138 changes: 71 additions & 67 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import javax.xml.stream._
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import scala.util.Try

import org.slf4j.LoggerFactory

Expand All @@ -44,7 +45,24 @@ private[xml] object StaxXmlParser extends Serializable {
xml: RDD[String],
schema: StructType,
options: XmlOptions): RDD[Row] = {
def failedRecord(record: String, cause: Throwable = null): Option[Row] = {

// The logic below is borrowed from Apache Spark's FailureSafeParser.
// Position of the corrupt-record column in the requested schema, if the schema declares one.
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
// The requested schema with the corrupt-record column removed, i.e. the data columns only.
val actualSchema = StructType(schema.filterNot(_.name == options.columnNameOfCorruptRecord))
// Positions of the data columns within the full schema, computed once instead of per record.
val dataFieldIndices: Array[Int] =
  actualSchema.map(field => schema.fieldIndex(field.name)).toArray

/**
 * Builds the output row for a malformed record.
 *
 * Every data column is taken from the partial row when one is given and is `null` otherwise;
 * the corrupt-record column, when the schema has one, receives the raw record text.
 *
 * The first argument is the partially parsed row, if any. It is expected to be laid out
 * according to the full `schema` (it is produced by parsing with that schema), so values are
 * read from and written to the same position. The second argument is the raw malformed record.
 */
val toResultRow: (Option[Row], String) => Row = (row, badRecord) => {
  // Allocate a fresh array per record: Row.fromSeq is not guaranteed to copy its input, so a
  // buffer shared across records could end up aliased by every row produced from it.
  val resultRow = new Array[Any](schema.length)
  dataFieldIndices.foreach { index =>
    resultRow(index) = row.map(_.get(index)).orNull
  }
  corruptFieldIndex.foreach(index => resultRow(index) = badRecord)
  Row.fromSeq(resultRow)
}

def failedRecord(
record: String, cause: Throwable = null, partialResult: Option[Row] = None): Option[Row] = {
// create a row even if no corrupt record column is present
options.parseMode match {
case FailFastMode =>
Expand All @@ -55,13 +73,7 @@ private[xml] object StaxXmlParser extends Serializable {
logger.warn(s"Dropping malformed line: ${record.replaceAll("\n", "")}. $reason")
None
case PermissiveMode =>
val row = new Array[Any](schema.length)
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(options.columnNameOfCorruptRecord).foreach { corruptIndex =>
require(schema(corruptIndex).dataType == StringType)
row.update(corruptIndex, record)
}
Some(Row.fromSeq(row))
Some(toResultRow(partialResult, record))
}
}

Expand All @@ -88,6 +100,8 @@ private[xml] object StaxXmlParser extends Serializable {
Some(convertObject(parser, schema, options, rootAttributes))
.orElse(failedRecord(xml))
} catch {
case e: PartialResultException =>
failedRecord(xml, e.cause, Some(e.partialResult))
case NonFatal(e) =>
failedRecord(xml, e)
}
Expand Down Expand Up @@ -159,12 +173,8 @@ private[xml] object StaxXmlParser extends Serializable {
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement =>
try {
keys += e.getName.getLocalPart
values += convertField(parser, valueType, options)
} catch {
case NonFatal(_) if options.parseMode == PermissiveMode => // do nothing
}
keys += e.getName.getLocalPart
values += convertField(parser, valueType, options)
case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser)
case _ => // do nothing
Expand All @@ -183,13 +193,9 @@ private[xml] object StaxXmlParser extends Serializable {
val convertedValuesMap = collection.mutable.Map.empty[String, Any]
val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
valuesMap.foreach { case (f, v) =>
try {
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(f).foreach { i =>
convertedValuesMap(f) = convertTo(v, schema(i).dataType, options)
}
} catch {
case NonFatal(_) if options.parseMode == PermissiveMode => // do nothing
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(f).foreach { i =>
convertedValuesMap(f) = convertTo(v, schema(i).dataType, options)
}
}
convertedValuesMap.toMap
Expand All @@ -212,23 +218,17 @@ private[xml] object StaxXmlParser extends Serializable {
val attributesMap = convertAttributes(attributes, schema, options)

// Then, we read elements here.
val fieldsMap =
try {
convertField(parser, schema, options) match {
case row: Row =>
Map(schema.map(_.name).zip(row.toSeq): _*)
case v if schema.fieldNames.contains(options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row So, we first need to find the field name that has the real
// value and then push the value.
val valuesMap = schema.fieldNames.map((_, null)).toMap
valuesMap + (options.valueTag -> v)
case _ => Map.empty
}
} catch {
case NonFatal(_) if options.parseMode == PermissiveMode =>
Map.empty
}
val fieldsMap = convertField(parser, schema, options) match {
case row: Row =>
Map(schema.map(_.name).zip(row.toSeq): _*)
case v if schema.fieldNames.contains(options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row So, we first need to find the field name that has the real
// value and then push the value.
val valuesMap = schema.fieldNames.map((_, null)).toMap
valuesMap + (options.valueTag -> v)
case _ => Map.empty
}

// Here we merge both to a row.
val valuesMap = fieldsMap ++ attributesMap
Expand Down Expand Up @@ -260,51 +260,55 @@ private[xml] object StaxXmlParser extends Serializable {
convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
nameToIndex.get(f).foreach { row(_) = v }
}
var badRecordException: Option[Throwable] = None

var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement =>
try {
val attributes = e.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
val field = e.asStartElement.getName.getLocalPart

nameToIndex.get(field) match {
case Some(index) =>
schema(index).dataType match {
case st: StructType =>
row(index) = convertObjectWithAttributes(parser, st, options, attributes)
case e: StartElement => try {
val attributes = e.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
val field = e.asStartElement.getName.getLocalPart

case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
.map(_.asInstanceOf[ArrayBuffer[Any]])
.getOrElse(ArrayBuffer.empty[Any])
val newValue = {
dt match {
case st: StructType =>
convertObjectWithAttributes(parser, st, options, attributes)
case dt: DataType =>
convertField(parser, dt, options)
}
}
row(index) = values :+ newValue
nameToIndex.get(field) match {
case Some(index) => schema(index).dataType match {
case st: StructType =>
row(index) = convertObjectWithAttributes(parser, st, options, attributes)

case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
.map(_.asInstanceOf[ArrayBuffer[Any]])
.getOrElse(ArrayBuffer.empty[Any])
val newValue = dt match {
case st: StructType =>
convertObjectWithAttributes(parser, st, options, attributes)
case dt: DataType =>
row(index) = convertField(parser, dt, options)
convertField(parser, dt, options)
}
row(index) = values :+ newValue

case None =>
StaxXmlParserUtils.skipChildren(parser)
case dt: DataType =>
row(index) = convertField(parser, dt, options)
}
} catch {
case NonFatal(_) if options.parseMode == PermissiveMode => // do nothing

case None =>
StaxXmlParserUtils.skipChildren(parser)
}
} catch {
case NonFatal(exception) if options.parseMode == PermissiveMode =>
badRecordException = badRecordException.orElse(Some(exception))
}

case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser)

case _ => // do nothing
}
}
Row.fromSeq(row)

if (badRecordException.isEmpty) {
Row.fromSeq(row)
} else {
throw PartialResultException(Row.fromSeq(row), badRecordException.get)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml.util

import org.apache.spark.sql.Row

/**
 * Signals that a record could only be parsed in part.
 *
 * It bundles the values recovered from the bad record with the error that interrupted
 * parsing, so that a handler has access to both.
 *
 * @param partialResult the row holding whatever could be parsed from the bad record.
 * @param cause the underlying error explaining why a complete result could not be produced;
 *              it is also handed to `Exception` as this exception's cause.
 */
case class PartialResultException(partialResult: Row, cause: Throwable) extends Exception(cause)
24 changes: 15 additions & 9 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -980,27 +980,33 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
field("_int", IntegerType)),
struct("long_value",
field("_VALUE", LongType),
field("_int", IntegerType)),
field("_int", StringType)),
field("float_value", FloatType),
field("double_value", DoubleType),
field("boolean_value", BooleanType),
field("string_value"),
array("integer_array", IntegerType),
StructField("integer_map", MapType(StringType, IntegerType)))
field("string_value"), array("integer_array", IntegerType),
field("integer_map", MapType(StringType, IntegerType)),
field("_malformed_records", StringType))
val results = spark.read
.option("mode", "PERMISSIVE")
.option("columnNameOfCorruptRecord", "_malformed_records")
.schema(schema)
.xml(dataTypesValidAndInvalid)

assert(results.schema === schema)

val Array(valid, invalid) = results.take(2)
assert(valid.toSeq.toArray ===
Array(Row(10, 10), Row(10L, null), 10.0, 10.0, true,

assert(valid.toSeq.toArray.take(schema.length - 1) ===
Array(Row(10, 10), Row(10, "Ten"), 10.0, 10.0, true,
"Ten", Array(1, 2), Map("a" -> 123, "b" -> 345)))
assert(invalid.toSeq.toArray ===
Array(Row(null, null), Row(null, 10), null, null, null,
"Ten", Array(2), Map("b" -> 345)))
assert(invalid.toSeq.toArray.take(schema.length - 1) ===
Array(null, null, null, null, null,
"Ten", Array(2), null))

assert(valid.toSeq.toArray.last === null)
assert(invalid.toSeq.toArray.last.toString.contains(
<integer_value int="Ten">Ten</integer_value>.toString))
}

}

0 comments on commit 1a712ea

Please sign in to comment.