From 771a8cc2555de0036d7bacc6bd38ebe4ad3a640c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 15 Feb 2016 09:53:05 +0900 Subject: [PATCH] XML factory per partition instead of each row and reduce peek() call https://github.com/databricks/spark-xml/issues/84 Firstly, the XML factory was being created for each row. Although this factory cannot be reused across tasks because it is not serializable, it can still be created once per partition instead of once per row. Secondly, it looks like `peek()` is called multiple times for reading complete data within an element. However, the `IS_COALESCING` option for the XML parser allows this to be read only once. Also, while looking through this library, I found some stylistic issues and useless logic. I removed them. 1. I used `isIgnorableWhiteSpace()` so that it can differentiate some spaces between elements, but it looks like this function does not detect anything because the DTD is not accessible while parsing each row. 2. Corrected the indentation of parameters in functions spanning multiple lines. 3. Removed duplicated logic. 4. Corrected `StaxXmlGenerator` to always write attributes first for safety (to prevent trying to write some data within an element first and then writing attributes, which throws an exception) Author: hyukjinkwon Closes #85 from HyukjinKwon/ISSUE-84-performance. 
--- .../spark/xml/parsers/StaxXmlGenerator.scala | 21 +++-- .../spark/xml/parsers/StaxXmlParser.scala | 88 +++++++++---------- .../xml/parsers/StaxXmlParserUtils.scala | 38 ++------ .../spark/xml/util/InferSchema.scala | 68 +++++++------- 4 files changed, 100 insertions(+), 115 deletions(-) diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala index 5645a367..6ecbab3d 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala @@ -34,8 +34,8 @@ private[xml] object StaxXmlGenerator { * @param row The row to convert */ def apply(schema: StructType, - writer: IndentingXMLStreamWriter, - options: XmlOptions)(row: Row): Unit = { + writer: IndentingXMLStreamWriter, + options: XmlOptions)(row: Row): Unit = { def writeChildElement: (String, DataType, Any) => Unit = { // If this is meant to be value but in no child, write only a value case (_, _, null) |(_, NullType, _) if options.nullValue == null => @@ -59,9 +59,12 @@ private[xml] object StaxXmlGenerator { } case _ if name.startsWith(options.attributePrefix) => writer.writeAttribute(name.substring(options.attributePrefix.length), v.toString) + // For ArrayType, we just need to write each as XML element. case (ArrayType(ty, _), v: Seq[_]) => - v.foreach(e => writeChildElement(name, ty, e)) + v.foreach { e => + writeChildElement(name, ty, e) + } // For other datatypes, we just write normal elements. case _ => writeChildElement(name, dt, v) @@ -94,13 +97,21 @@ private[xml] object StaxXmlGenerator { } case (MapType(kv, vt, _), mv: Map[_, _]) => - mv.foreach { + val (attributes, elements) = mv.toSeq.partition { + case (f, _) => f.toString.startsWith(options.attributePrefix) + } + // We need to write attributes first before the value. 
+ (attributes ++ elements).foreach { case (k, v) => writeChild(k.toString, vt, v) } case (StructType(ty), r: Row) => - ty.zip(r.toSeq).foreach { + val (attributes, elements) = ty.zip(r.toSeq).partition { + case (f, _) => f.name.startsWith(options.attributePrefix) + } + // We need to write attributes first before the value. + (attributes ++ elements).foreach { case (field, v) => writeChild(field.name, field.dataType, v) } diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala index 74240ac3..6602271c 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala @@ -39,23 +39,23 @@ private[xml] object StaxXmlParser { private val logger = LoggerFactory.getLogger(StaxXmlParser.getClass) def parse(xml: RDD[String], - schema: StructType, - options: XmlOptions): RDD[Row] = { + schema: StructType, + options: XmlOptions): RDD[Row] = { val failFast = options.failFastFlag xml.mapPartitions { iter => + val factory = XMLInputFactory.newInstance() + factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false) + factory.setProperty(XMLInputFactory.IS_COALESCING, true) iter.flatMap { xml => // It does not have to skip for white space, since `XmlInputFormat` // always finds the root tag without a heading space. 
- val factory = XMLInputFactory.newInstance() - factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false) val reader = new ByteArrayInputStream(xml.getBytes) val parser = factory.createXMLEventReader(reader) try { - val rootAttributes = { - val rootEvent = StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) - rootEvent.asStartElement.getAttributes - .map(_.asInstanceOf[Attribute]).toArray - } + StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) + val rootEvent = parser.nextEvent() + val rootAttributes = + rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray Some(convertObject(parser, schema, options, rootAttributes)) } catch { case _: java.lang.NumberFormatException if !failFast => @@ -80,8 +80,8 @@ private[xml] object StaxXmlParser { * Parse the current token (and related children) according to a desired schema */ private[xml] def convertField(parser: XMLEventReader, - dataType: DataType, - options: XmlOptions): Any = { + dataType: DataType, + options: XmlOptions): Any = { def convertComplicatedType: DataType => Any = { case dt: StructType => convertObject(parser, dt, options) case MapType(StringType, vt, _) => convertMap(parser, vt, options) @@ -89,38 +89,31 @@ private[xml] object StaxXmlParser { case udt: UserDefinedType[_] => convertField(parser, udt.sqlType, options) } - val current = parser.peek - (current, dataType) match { + (parser.peek, dataType) match { case (_: StartElement, dt: DataType) => convertComplicatedType(dt) case (_: EndElement, _: DataType) => null - case (c: Characters, dt: DataType) if !c.isIgnorableWhiteSpace && c.isWhiteSpace => + case (c: Characters, dt: DataType) if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. 
- val next = { - parser.nextEvent - parser.peek - } - val data = c.asCharacters().getData - (next, dataType) match { - case (_: EndElement, _) => if (options.treatEmptyValuesAsNulls) null else data + val data = c.getData + parser.nextEvent() + (parser.peek, dataType) match { case (_: StartElement, dt: DataType) => convertComplicatedType(dt) - case (_: Characters, st: StructType) => - // This case can be happen when current data type is inferred as `StructType` - // due to `valueTag` for elements having attributes but no child. - val dt = st.filter(_.name == options.valueTag).head.dataType - convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt) - case (_: Characters, dt: DataType) => - convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt) + case (_: EndElement, _) if data.isEmpty => null + case (_: EndElement, _) if options.treatEmptyValuesAsNulls => null + case (_: EndElement, _: DataType) => data } - case (c: Characters, ArrayType(st, _)) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace => - convertStringTo(StaxXmlParserUtils.readDataFully(parser), st) - case (c: Characters, st: StructType) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace => + + case (c: Characters, ArrayType(st, _)) => + // For `ArrayType`, it needs to return the type of element. The values are merged later. + convertStringTo(c.getData, st) + case (c: Characters, st: StructType) => // This case can be happen when current data type is inferred as `StructType` // due to `valueTag` for elements having attributes but no child. 
val dt = st.filter(_.name == options.valueTag).head.dataType - convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt) - case (c: Characters, dt: DataType) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace => - convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt) + convertStringTo(c.getData, dt) + case (c: Characters, dt: DataType) => + convertStringTo(c.getData, dt) case (e: XMLEvent, dt: DataType) => sys.error(s"Failed to parse a value for data type $dt with event ${e.toString}") } @@ -147,8 +140,8 @@ private[xml] object StaxXmlParser { * Parse an object as map. */ private def convertMap(parser: XMLEventReader, - valueType: DataType, - options: XmlOptions): Map[String, Any] = { + valueType: DataType, + options: XmlOptions): Map[String, Any] = { val keys = ArrayBuffer.empty[String] val values = ArrayBuffer.empty[Any] var shouldStop = false @@ -170,7 +163,7 @@ private[xml] object StaxXmlParser { * Convert string values to required data type. */ private def convertValues(valuesMap: Map[String, String], - schema: StructType): Map[String, Any] = { + schema: StructType): Map[String, Any] = { val convertedValuesMap = collection.mutable.Map.empty[String, Any] valuesMap.foreach { case (f, v) => @@ -188,13 +181,13 @@ private[xml] object StaxXmlParser { * Fields in the xml that are not defined in the requested schema will be dropped. */ private def convertObject(parser: XMLEventReader, - schema: StructType, - options: XmlOptions, - rootAttributes: Array[Attribute] = Array()): Row = { + schema: StructType, + options: XmlOptions, + rootAttributes: Array[Attribute] = Array()): Row = { val row = new Array[Any](schema.length) var shouldStop = false while (!shouldStop) { - parser.nextEvent match { + parser.nextEvent match { case e: StartElement => val nameToIndex = schema.map(_.name).zipWithIndex.toMap // If there are attributes, then we process them first. 
@@ -214,19 +207,21 @@ private[xml] object StaxXmlParser { val dataType = schema(index).dataType row(index) = dataType match { case st: StructType => - // The fields are sorted so `TreeMap` is used. val fields = convertField(parser, st, options) match { case row: Row => - TreeMap(st.map(_.name).zip(row.toSeq): _*) + Map(st.map(_.name).zip(row.toSeq): _*) case v if st.exists(_.name == options.valueTag) => // If this is the element having no children, then it wraps attributes // with a row So, we first need to find the field name that has the real // value and then push the value. - TreeMap(options.valueTag -> v) + Map(options.valueTag -> v) + case null => Map.empty } + // The fields are sorted so `TreeMap` is used. val convertedValuesMap = convertValues(valuesMap, st) - val row = (fields ++ convertedValuesMap).values.toSeq + val row = TreeMap((fields ++ convertedValuesMap).toSeq : _*).values.toSeq Row.fromSeq(row) + case ArrayType(dt: DataType, _) => val values = Option(row(index)) .map(_.asInstanceOf[ArrayBuffer[Any]]) @@ -243,12 +238,15 @@ private[xml] object StaxXmlParser { } } values :+ newValue + case _ => convertField(parser, dataType, options) } } + case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser, options) + case _ => shouldStop = shouldStop && parser.hasNext } diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala index 06713058..b82a17cf 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala @@ -10,47 +10,27 @@ private[xml] object StaxXmlParserUtils { * Skips elements until this meets the given type of a element */ def skipUntil(parser: XMLEventReader, eventType: Int): XMLEvent = { - var event = parser.nextEvent - while(parser.hasNext && event.getEventType != eventType) { - event = parser.nextEvent - } - event - } - - 
/** - * Reads the data for all continuous character events within an element. - */ - def readDataFully(parser: XMLEventReader): String = { var event = parser.peek - var data: String = if (event.isCharacters) "" else null - while(event.isCharacters) { - data += event.asCharacters.getData + while(parser.hasNext && event.getEventType != eventType) { parser.nextEvent event = parser.peek } - data + event } /** * Checks if current event points the EndElement. */ def checkEndElement(parser: XMLEventReader, options: XmlOptions): Boolean = { - val current = parser.peek - current match { + parser.peek match { case _: EndElement => true case _: StartElement => false - case _: Characters => - // When `Characters` is found here, we need to look further to decide - // if this is really `EndElement` because this can be whitespace between - // `EndElement` and `StartElement`. - val next = { - parser.nextEvent - parser.peek - } - next match { - case _: EndElement => true - case _: XMLEvent => false - } + case _: XMLEvent => + // When other events are found here rather than `EndElement` or `StartElement` + // , we need to look further to decide if this is the end because this can be + // whitespace between `EndElement` and `StartElement`. + parser.nextEvent + checkEndElement(parser, options) } } diff --git a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala index fbcbe9b7..3e0c35d3 100644 --- a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala +++ b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala @@ -66,8 +66,7 @@ private[xml] object InferSchema { * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def infer(xml: RDD[String], - options: XmlOptions): StructType = { + def infer(xml: RDD[String], options: XmlOptions): StructType = { require(options.samplingRatio > 0, s"samplingRatio ($options.samplingRatio) should be greater than 0") val schemaData = if (options.samplingRatio > 0.99) { @@ -78,19 +77,19 @@ private[xml] object InferSchema { val failFast = options.failFastFlag // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => + val factory = XMLInputFactory.newInstance() + factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false) + factory.setProperty(XMLInputFactory.IS_COALESCING, true) iter.flatMap { xml => // It does not have to skip for white space, since [[XmlInputFormat]] // always finds the root tag without a heading space. - val factory = XMLInputFactory.newInstance() - factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false) val reader = new ByteArrayInputStream(xml.getBytes) val parser = factory.createXMLEventReader(reader) try { - val rootAttributes = { - val rootEvent = StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) - rootEvent.asStartElement.getAttributes - .map(_.asInstanceOf[Attribute]).toArray - } + StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) + val rootEvent = parser.nextEvent() + val rootAttributes = + rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray Some(inferObject(parser, options, rootAttributes)) } catch { case _: XMLStreamException if !failFast => @@ -111,41 +110,35 @@ private[xml] object InferSchema { } } - private def inferTypeFromString(value: String): DataType = { - Option(value) match { - case Some(v) if v.isEmpty => NullType - case Some(v) if isLong(v) => LongType - case Some(v) if isInteger(v) => IntegerType - case Some(v) if isDouble(v) => DoubleType - case Some(v) if isBoolean(v) => BooleanType - case Some(v) if 
isTimestamp(v) => TimestampType - case Some(v) => StringType - case None => NullType - } + private def inferTypeFromString: String => DataType = { + case null => NullType + case v if v.isEmpty => NullType + case v if isLong(v) => LongType + case v if isInteger(v) => IntegerType + case v if isDouble(v) => DoubleType + case v if isBoolean(v) => BooleanType + case v if isTimestamp(v) => TimestampType + case v => StringType } private def inferField(parser: XMLEventReader, options: XmlOptions): DataType = { - val current = parser.peek - current match { + parser.peek match { case _: EndElement => NullType case _: StartElement => inferObject(parser, options) - case c: Characters if !c.isIgnorableWhiteSpace && c.isWhiteSpace => + case c: Characters if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. - val next = { - parser.nextEvent - parser.peek - } - next match { + val data = c.getData + parser.nextEvent() + parser.peek match { + case _: StartElement => inferObject(parser, options) + case _: EndElement if data.isEmpty => NullType case _: EndElement if options.treatEmptyValuesAsNulls => NullType case _: EndElement => StringType - case _: StartElement => inferObject(parser, options) - case _: Characters => inferTypeFromString(StaxXmlParserUtils.readDataFully(parser)) } - case c: Characters if !c.isIgnorableWhiteSpace && !c.isWhiteSpace => + case c: Characters if !c.isWhiteSpace => // This means data exists - inferTypeFromString(StaxXmlParserUtils.readDataFully(parser)) - + inferTypeFromString(c.getData) case e: XMLEvent => sys.error(s"Failed to parse data with unexpected event ${e.toString}") } @@ -155,8 +148,8 @@ private[xml] object InferSchema { * Infer the type of a xml document from the parser's token stream */ private def inferObject(parser: XMLEventReader, - options: XmlOptions, - rootAttributes: Array[Attribute] = Array()): DataType = { + options: XmlOptions, + 
rootAttributes: Array[Attribute] = Array()): DataType = { val builder = Seq.newBuilder[StructField] val nameToDataTypes = collection.mutable.Map.empty[String, ArrayBuffer[DataType]] var shouldStop = false @@ -172,7 +165,6 @@ private[xml] object InferSchema { val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray val valuesMap = StaxXmlParserUtils.toValuesMap(attributes, options) - val inferredType = inferField(parser, options) match { case st: StructType if valuesMap.nonEmpty => // Merge attributes to the field @@ -183,6 +175,7 @@ private[xml] object InferSchema { nestedBuilder += StructField(f, inferTypeFromString(v), nullable = true) } StructType(nestedBuilder.result().sortBy(_.name)) + case dt: DataType if valuesMap.nonEmpty => // We need to manually add the field for value. val nestedBuilder = Seq.newBuilder[StructField] @@ -192,6 +185,7 @@ private[xml] object InferSchema { nestedBuilder += StructField(f, inferTypeFromString(v), nullable = true) } StructType(nestedBuilder.result().sortBy(_.name)) + case dt: DataType => dt } // Add the field and datatypes so that we can check if this is ArrayType. @@ -199,8 +193,10 @@ private[xml] object InferSchema { val dataTypes = nameToDataTypes.getOrElse(field, ArrayBuffer.empty[DataType]) dataTypes += inferredType nameToDataTypes += (field -> dataTypes) + case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser, options) + case _ => shouldStop = shouldStop && parser.hasNext }