Skip to content

Commit

Permalink
XML factory per partition instead of each row and reduce peek() call
Browse files Browse the repository at this point in the history
#84

Firstly, the XML factory was being created for each row. Although this factory cannot be reused across the tasks because it is not serializable, it still can create the factory for each partition.

Secondly, It looks `peek()` is called multiple times for reading a complete data within an element. However, `IS_COALESCING` option for XML parser allows this to read only once.

Also, While looking through this library, I found some stylistic corrections and useless logics. I removed them.
1. I used `isIgnorableWhiteSpace()` so that it can differentiate some spaces between elements but it looks this function does not detect anything due to DTD is not accessible during parsing each row.
2. Corrected the indentations of parameters in function across multiple lines.
3. Removed duplicated logics.
4. Corrected `StaxXmlGenerator` to write attributes first always for safety (to prevent to try to write some data within an element first and then write attributes, which emits an exception)

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #85 from HyukjinKwon/ISSUE-84-performance.
  • Loading branch information
HyukjinKwon committed Feb 15, 2016
1 parent 8422ab5 commit 771a8cc
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ private[xml] object StaxXmlGenerator {
* @param row The row to convert
*/
def apply(schema: StructType,
writer: IndentingXMLStreamWriter,
options: XmlOptions)(row: Row): Unit = {
writer: IndentingXMLStreamWriter,
options: XmlOptions)(row: Row): Unit = {
def writeChildElement: (String, DataType, Any) => Unit = {
// If this is meant to be value but in no child, write only a value
case (_, _, null) |(_, NullType, _) if options.nullValue == null =>
Expand All @@ -59,9 +59,12 @@ private[xml] object StaxXmlGenerator {
}
case _ if name.startsWith(options.attributePrefix) =>
writer.writeAttribute(name.substring(options.attributePrefix.length), v.toString)

// For ArrayType, we just need to write each as XML element.
case (ArrayType(ty, _), v: Seq[_]) =>
v.foreach(e => writeChildElement(name, ty, e))
v.foreach { e =>
writeChildElement(name, ty, e)
}
// For other datatypes, we just write normal elements.
case _ =>
writeChildElement(name, dt, v)
Expand Down Expand Up @@ -94,13 +97,21 @@ private[xml] object StaxXmlGenerator {
}

case (MapType(kv, vt, _), mv: Map[_, _]) =>
mv.foreach {
val (attributes, elements) = mv.toSeq.partition {
case (f, _) => f.toString.startsWith(options.attributePrefix)
}
// We need to write attributes first before the value.
(attributes ++ elements).foreach {
case (k, v) =>
writeChild(k.toString, vt, v)
}

case (StructType(ty), r: Row) =>
ty.zip(r.toSeq).foreach {
val (attributes, elements) = ty.zip(r.toSeq).partition {
case (f, _) => f.name.startsWith(options.attributePrefix)
}
// We need to write attributes first before the value.
(attributes ++ elements).foreach {
case (field, v) =>
writeChild(field.name, field.dataType, v)
}
Expand Down
88 changes: 43 additions & 45 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ private[xml] object StaxXmlParser {
private val logger = LoggerFactory.getLogger(StaxXmlParser.getClass)

def parse(xml: RDD[String],
schema: StructType,
options: XmlOptions): RDD[Row] = {
schema: StructType,
options: XmlOptions): RDD[Row] = {
val failFast = options.failFastFlag
xml.mapPartitions { iter =>
val factory = XMLInputFactory.newInstance()
factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
factory.setProperty(XMLInputFactory.IS_COALESCING, true)
iter.flatMap { xml =>
// It does not have to skip for white space, since `XmlInputFormat`
// always finds the root tag without a heading space.
val factory = XMLInputFactory.newInstance()
factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
val reader = new ByteArrayInputStream(xml.getBytes)
val parser = factory.createXMLEventReader(reader)
try {
val rootAttributes = {
val rootEvent = StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
rootEvent.asStartElement.getAttributes
.map(_.asInstanceOf[Attribute]).toArray
}
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
val rootEvent = parser.nextEvent()
val rootAttributes =
rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
Some(convertObject(parser, schema, options, rootAttributes))
} catch {
case _: java.lang.NumberFormatException if !failFast =>
Expand All @@ -80,47 +80,40 @@ private[xml] object StaxXmlParser {
* Parse the current token (and related children) according to a desired schema
*/
private[xml] def convertField(parser: XMLEventReader,
dataType: DataType,
options: XmlOptions): Any = {
dataType: DataType,
options: XmlOptions): Any = {
def convertComplicatedType: DataType => Any = {
case dt: StructType => convertObject(parser, dt, options)
case MapType(StringType, vt, _) => convertMap(parser, vt, options)
case ArrayType(st, _) => convertField(parser, st, options)
case udt: UserDefinedType[_] => convertField(parser, udt.sqlType, options)
}

val current = parser.peek
(current, dataType) match {
(parser.peek, dataType) match {
case (_: StartElement, dt: DataType) => convertComplicatedType(dt)
case (_: EndElement, _: DataType) => null
case (c: Characters, dt: DataType) if !c.isIgnorableWhiteSpace && c.isWhiteSpace =>
case (c: Characters, dt: DataType) if c.isWhiteSpace =>
// When `Characters` is found, we need to look further to decide
// if this is really data or space between other elements.
val next = {
parser.nextEvent
parser.peek
}
val data = c.asCharacters().getData
(next, dataType) match {
case (_: EndElement, _) => if (options.treatEmptyValuesAsNulls) null else data
val data = c.getData
parser.nextEvent()
(parser.peek, dataType) match {
case (_: StartElement, dt: DataType) => convertComplicatedType(dt)
case (_: Characters, st: StructType) =>
// This case can be happen when current data type is inferred as `StructType`
// due to `valueTag` for elements having attributes but no child.
val dt = st.filter(_.name == options.valueTag).head.dataType
convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt)
case (_: Characters, dt: DataType) =>
convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt)
case (_: EndElement, _) if data.isEmpty => null
case (_: EndElement, _) if options.treatEmptyValuesAsNulls => null
case (_: EndElement, _: DataType) => data
}
case (c: Characters, ArrayType(st, _)) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace =>
convertStringTo(StaxXmlParserUtils.readDataFully(parser), st)
case (c: Characters, st: StructType) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace =>

case (c: Characters, ArrayType(st, _)) =>
// For `ArrayType`, it needs to return the type of element. The values are merged later.
convertStringTo(c.getData, st)
case (c: Characters, st: StructType) =>
// This case can be happen when current data type is inferred as `StructType`
// due to `valueTag` for elements having attributes but no child.
val dt = st.filter(_.name == options.valueTag).head.dataType
convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt)
case (c: Characters, dt: DataType) if !c.isIgnorableWhiteSpace && !c.isWhiteSpace =>
convertStringTo(StaxXmlParserUtils.readDataFully(parser), dt)
convertStringTo(c.getData, dt)
case (c: Characters, dt: DataType) =>
convertStringTo(c.getData, dt)
case (e: XMLEvent, dt: DataType) =>
sys.error(s"Failed to parse a value for data type $dt with event ${e.toString}")
}
Expand All @@ -147,8 +140,8 @@ private[xml] object StaxXmlParser {
* Parse an object as map.
*/
private def convertMap(parser: XMLEventReader,
valueType: DataType,
options: XmlOptions): Map[String, Any] = {
valueType: DataType,
options: XmlOptions): Map[String, Any] = {
val keys = ArrayBuffer.empty[String]
val values = ArrayBuffer.empty[Any]
var shouldStop = false
Expand All @@ -170,7 +163,7 @@ private[xml] object StaxXmlParser {
* Convert string values to required data type.
*/
private def convertValues(valuesMap: Map[String, String],
schema: StructType): Map[String, Any] = {
schema: StructType): Map[String, Any] = {
val convertedValuesMap = collection.mutable.Map.empty[String, Any]
valuesMap.foreach {
case (f, v) =>
Expand All @@ -188,13 +181,13 @@ private[xml] object StaxXmlParser {
* Fields in the xml that are not defined in the requested schema will be dropped.
*/
private def convertObject(parser: XMLEventReader,
schema: StructType,
options: XmlOptions,
rootAttributes: Array[Attribute] = Array()): Row = {
schema: StructType,
options: XmlOptions,
rootAttributes: Array[Attribute] = Array()): Row = {
val row = new Array[Any](schema.length)
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
parser.nextEvent match {
case e: StartElement =>
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
// If there are attributes, then we process them first.
Expand All @@ -214,19 +207,21 @@ private[xml] object StaxXmlParser {
val dataType = schema(index).dataType
row(index) = dataType match {
case st: StructType =>
// The fields are sorted so `TreeMap` is used.
val fields = convertField(parser, st, options) match {
case row: Row =>
TreeMap(st.map(_.name).zip(row.toSeq): _*)
Map(st.map(_.name).zip(row.toSeq): _*)
case v if st.exists(_.name == options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row So, we first need to find the field name that has the real
// value and then push the value.
TreeMap(options.valueTag -> v)
Map(options.valueTag -> v)
case null => Map.empty
}
// The fields are sorted so `TreeMap` is used.
val convertedValuesMap = convertValues(valuesMap, st)
val row = (fields ++ convertedValuesMap).values.toSeq
val row = TreeMap((fields ++ convertedValuesMap).toSeq : _*).values.toSeq
Row.fromSeq(row)

case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
.map(_.asInstanceOf[ArrayBuffer[Any]])
Expand All @@ -243,12 +238,15 @@ private[xml] object StaxXmlParser {
}
}
values :+ newValue

case _ =>
convertField(parser, dataType, options)
}
}

case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser, options)

case _ =>
shouldStop = shouldStop && parser.hasNext
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,27 @@ private[xml] object StaxXmlParserUtils {
* Skips elements until this meets the given type of a element
*/
def skipUntil(parser: XMLEventReader, eventType: Int): XMLEvent = {
var event = parser.nextEvent
while(parser.hasNext && event.getEventType != eventType) {
event = parser.nextEvent
}
event
}

/**
* Reads the data for all continuous character events within an element.
*/
def readDataFully(parser: XMLEventReader): String = {
var event = parser.peek
var data: String = if (event.isCharacters) "" else null
while(event.isCharacters) {
data += event.asCharacters.getData
while(parser.hasNext && event.getEventType != eventType) {
parser.nextEvent
event = parser.peek
}
data
event
}

/**
* Checks if current event points the EndElement.
*/
def checkEndElement(parser: XMLEventReader, options: XmlOptions): Boolean = {
val current = parser.peek
current match {
parser.peek match {
case _: EndElement => true
case _: StartElement => false
case _: Characters =>
// When `Characters` is found here, we need to look further to decide
// if this is really `EndElement` because this can be whitespace between
// `EndElement` and `StartElement`.
val next = {
parser.nextEvent
parser.peek
}
next match {
case _: EndElement => true
case _: XMLEvent => false
}
case _: XMLEvent =>
// When other events are found here rather than `EndElement` or `StartElement`
// , we need to look further to decide if this is the end because this can be
// whitespace between `EndElement` and `StartElement`.
parser.nextEvent
checkEndElement(parser, options)
}
}

Expand Down
Loading

0 comments on commit 771a8cc

Please sign in to comment.