Skip to content

Commit

Permalink
Parse elements in array having attributes correctly
Browse files Browse the repository at this point in the history
#88

If elements in an array have attributes, the data is parsed and its schema inferred incorrectly. For example, the xml below:

```xml
<ROWSET>
    <ROW>
        <one attr="value">i'm one</one>
        <one attr="value">i'm one</one>
    </ROW>
</ROWSET>
```

produces the results below:

```
+----+
| one|
+----+
|null|
|null|
+----+
```

This was caused by mistakes in handling elements of `ArrayType`. This PR fixes this.

Also, I renamed some related functions so that their names reflect their purpose, and removed an unused test file.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #89 from HyukjinKwon/ISSUE-88-elements-attributes.
  • Loading branch information
HyukjinKwon committed Feb 16, 2016
1 parent e05cfc6 commit 8f423b6
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 98 deletions.
74 changes: 41 additions & 33 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ private[xml] object StaxXmlParser {
val parser = factory.createXMLEventReader(reader)
try {
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
val rootEvent = parser.nextEvent()
val rootAttributes =
rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
parser.nextEvent().asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
Some(convertObject(parser, schema, options, rootAttributes))
} catch {
case _: java.lang.NumberFormatException if !failFast =>
Expand Down Expand Up @@ -106,20 +105,20 @@ private[xml] object StaxXmlParser {

case (c: Characters, ArrayType(st, _)) =>
// For `ArrayType`, the value is converted to the element type here. The values are merged later.
convertStringTo(c.getData, st)
convertTo(c.getData, st)
case (c: Characters, st: StructType) =>
// This case can happen when the current data type is inferred as `StructType`
// due to `valueTag` for elements having attributes but no child.
val dt = st.filter(_.name == options.valueTag).head.dataType
convertStringTo(c.getData, dt)
convertTo(c.getData, dt)
case (c: Characters, dt: DataType) =>
convertStringTo(c.getData, dt)
convertTo(c.getData, dt)
case (e: XMLEvent, dt: DataType) =>
sys.error(s"Failed to parse a value for data type $dt with event ${e.toString}")
}
}

private def convertStringTo: (String, DataType) => Any = {
private def convertTo: (String, DataType) => Any = {
case (null, _) | (_, NullType) => null
case (v, LongType) => signSafeToLong(v)
case (v, DoubleType) => signSafeToDouble(v)
Expand Down Expand Up @@ -170,20 +169,46 @@ private[xml] object StaxXmlParser {
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(f).foreach {
case i =>
convertedValuesMap(f) = convertStringTo(v, schema(i).dataType)
convertedValuesMap(f) = convertTo(v, schema(i).dataType)
}
}
Map(convertedValuesMap.toSeq: _*)
}

/**
* Parse an object from the token stream into a new Row representing the schema.
* [[convertObject()]] calls this in order to convert the object to a row. [[convertObject()]]
* contains some logic to find out which events are the start and end of a row and this function
* converts the events to a row.
*/
private def convertRow(parser: XMLEventReader,
schema: StructType,
options: XmlOptions,
attributes: Array[Attribute] = Array.empty) = {
val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
val fields = convertField(parser, schema, options) match {
case row: Row =>
Map(schema.map(_.name).zip(row.toSeq): _*)
case v if schema.exists(_.name == options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row So, we first need to find the field name that has the real
// value and then push the value.
Map(options.valueTag -> v)
case _ => Map.empty
}
// The fields are sorted so `TreeMap` is used.
val convertedValuesMap = convertValues(valuesMap, schema)
val row = TreeMap((fields ++ convertedValuesMap).toSeq : _*).values.toSeq
Row.fromSeq(row)
}

/**
* Parse an object from the event stream into a new Row representing the schema.
* Fields in the xml that are not defined in the requested schema will be dropped.
*/
private def convertObject(parser: XMLEventReader,
schema: StructType,
options: XmlOptions,
rootAttributes: Array[Attribute] = Array()): Row = {
rootAttributes: Array[Attribute] = Array.empty): Row = {
val row = new Array[Any](schema.length)
var shouldStop = false
while (!shouldStop) {
Expand All @@ -192,47 +217,30 @@ private[xml] object StaxXmlParser {
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
// If there are attributes, then we process them first.
val rootValuesMap =
convertValues(StaxXmlParserUtils.toValuesMap(rootAttributes, options), schema)
rootValuesMap.toSeq.foreach {
StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
convertValues(rootValuesMap, schema).toSeq.foreach {
case (f, v) =>
nameToIndex.get(f).foreach(row.update(_, v))
}
val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
val valuesMap = StaxXmlParserUtils.toValuesMap(attributes, options)

// Set elements and other attributes to the row
val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
val field = e.asStartElement.getName.getLocalPart

nameToIndex.get(field).foreach {
case index =>
val dataType = schema(index).dataType
row(index) = dataType match {
case st: StructType =>
val fields = convertField(parser, st, options) match {
case row: Row =>
Map(st.map(_.name).zip(row.toSeq): _*)
case v if st.exists(_.name == options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row. So, we first need to find the field name that has the real
// value and then push the value.
Map(options.valueTag -> v)
case null => Map.empty
}
// The fields are sorted so `TreeMap` is used.
val convertedValuesMap = convertValues(valuesMap, st)
val row = TreeMap((fields ++ convertedValuesMap).toSeq : _*).values.toSeq
Row.fromSeq(row)
convertRow(parser, st, options, attributes)

case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
.map(_.asInstanceOf[ArrayBuffer[Any]])
.getOrElse(ArrayBuffer.empty[Any])
val newValue = {
dt match {
case st: StructType if valuesMap.nonEmpty =>
// If the given type is array but the element type is StructType,
// we should push and write current attributes as fields in elements
// in this array.
convertObject(parser, st, options, attributes)
case st: StructType =>
convertRow(parser, st, options, attributes)
case _ =>
convertField(parser, dataType, options)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ private[xml] object StaxXmlParserUtils {
/**
* Produces values map from given attributes.
*/
def toValuesMap(attributes: Array[Attribute], options: XmlOptions): Map[String, String] = {
def convertAttributesToValuesMap(attributes: Array[Attribute],
options: XmlOptions): Map[String, String] = {
if (options.excludeAttributeFlag) {
Map.empty[String, String]
} else {
Expand Down
28 changes: 14 additions & 14 deletions src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ private[xml] object InferSchema {
val parser = factory.createXMLEventReader(reader)
try {
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
val rootEvent = parser.nextEvent()
val rootAttributes =
rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
parser.nextEvent().asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray
Some(inferObject(parser, options, rootAttributes))
} catch {
case _: XMLStreamException if !failFast =>
Expand All @@ -110,7 +109,7 @@ private[xml] object InferSchema {
}
}

private def inferTypeFromString: String => DataType = {
private def inferFrom: String => DataType = {
case null => NullType
case v if v.isEmpty => NullType
case v if isLong(v) => LongType
Expand Down Expand Up @@ -138,7 +137,7 @@ private[xml] object InferSchema {
}
case c: Characters if !c.isWhiteSpace =>
// This means data exists
inferTypeFromString(c.getData)
inferFrom(c.getData)
case e: XMLEvent =>
sys.error(s"Failed to parse data with unexpected event ${e.toString}")
}
Expand All @@ -149,30 +148,31 @@ private[xml] object InferSchema {
*/
private def inferObject(parser: XMLEventReader,
options: XmlOptions,
rootAttributes: Array[Attribute] = Array()): DataType = {
rootAttributes: Array[Attribute] = Array.empty): DataType = {
val builder = Seq.newBuilder[StructField]
val nameToDataTypes = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement =>
// If there are attributes, then we should process them first.
val rootValuesMap = StaxXmlParserUtils.toValuesMap(rootAttributes, options)
val rootValuesMap =
StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
rootValuesMap.foreach {
case (f, v) =>
nameToDataTypes += (f -> ArrayBuffer(inferTypeFromString(v)))
nameToDataType += (f -> ArrayBuffer(inferFrom(v)))
}

val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
val valuesMap = StaxXmlParserUtils.toValuesMap(attributes, options)
val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
val inferredType = inferField(parser, options) match {
case st: StructType if valuesMap.nonEmpty =>
// Merge attributes to the field
val nestedBuilder = Seq.newBuilder[StructField]
nestedBuilder ++= st.fields
valuesMap.foreach {
case (f, v) =>
nestedBuilder += StructField(f, inferTypeFromString(v), nullable = true)
nestedBuilder += StructField(f, inferFrom(v), nullable = true)
}
StructType(nestedBuilder.result().sortBy(_.name))

Expand All @@ -182,17 +182,17 @@ private[xml] object InferSchema {
nestedBuilder += StructField(options.valueTag, dt, nullable = true)
valuesMap.foreach {
case (f, v) =>
nestedBuilder += StructField(f, inferTypeFromString(v), nullable = true)
nestedBuilder += StructField(f, inferFrom(v), nullable = true)
}
StructType(nestedBuilder.result().sortBy(_.name))

case dt: DataType => dt
}
// Add the field and datatypes so that we can check if this is ArrayType.
val field = e.asStartElement.getName.getLocalPart
val dataTypes = nameToDataTypes.getOrElse(field, ArrayBuffer.empty[DataType])
val dataTypes = nameToDataType.getOrElse(field, ArrayBuffer.empty[DataType])
dataTypes += inferredType
nameToDataTypes += (field -> dataTypes)
nameToDataType += (field -> dataTypes)

case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser, options)
Expand All @@ -203,7 +203,7 @@ private[xml] object InferSchema {
}
// We need to manually merge the fields having the same names so that
// this can be inferred as ArrayType.
nameToDataTypes.foreach{
nameToDataType.foreach{
case (field, dataTypes) if dataTypes.length > 1 =>
val elementType = dataTypes.reduceLeft(InferSchema.compatibleType(options))
builder += StructField(field, ArrayType(elementType), nullable = true)
Expand Down
27 changes: 0 additions & 27 deletions src/test/resources/ages-attribute.xml

This file was deleted.

36 changes: 14 additions & 22 deletions src/test/resources/ages.xml
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
<?xml version="1.0"?>
<ROWSET>
<ROW><id>1</id>
<name>Johnson, Smith, and Jones Co.</name><age>45</age>
<Remark>Pays on time</Remark>
</ROW>
<ROW>
<id>2</id>
<name>Sam Mad Dog Smith</name><amount>93</amount>
</ROW>
<ROW>
<id>3</id><name>Barney Company</name>
<age>0</age>
<Remark>Great to work with
and always pays with cash.</Remark>
</ROW>
<ROW>
<id>4</id>
<name>Johnsons Automotive</name><age>2344</age>
<Remark>Pays on time</Remark>
</ROW>
</ROWSET>
<people>
<person>
<age born="1990-02-24">25</age>
<name>Hyukjin</name>
</person>
<person>
<age born="1985-01-01">30</age>
<name>Lars</name>
</person>
<person>
<age born="1980-01-01">30</age>
<name>Lion</name>
</person>
</people>
12 changes: 11 additions & 1 deletion src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import com.databricks.spark.xml.XmlOptions._
class XmlSuite extends FunSuite with BeforeAndAfterAll {
val tempEmptyDir = "target/test/empty/"
val agesFile = "src/test/resources/ages.xml"
val agesAttributeFile = "src/test/resources/ages-attribute.xml"
val booksFile = "src/test/resources/books.xml"
val booksNestedObjectFile = "src/test/resources/books-nested-object.xml"
val booksNestedArrayFile = "src/test/resources/books-nested-array.xml"
Expand All @@ -52,7 +51,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val booksTag = "book"
val booksRootTag = "books"
val topicsTag = "Topic"
val agesTag = "person"

val numAges = 3
val numCars = 3
val numBooks = 12
val numBooksComplicated = 3
Expand Down Expand Up @@ -103,6 +104,15 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
assert(results.size === numCars)
}

test("DSL test with elements in array having attributes") {
val results = sqlContext.xmlFile(agesFile, rowTag = agesTag).collect()
val attrValOne = results(0).get(0).asInstanceOf[Row](1)
val attrValTwo = results(1).get(0).asInstanceOf[Row](1)
assert(attrValOne == "1990-02-24")
assert(attrValTwo == "1985-01-01")
assert(results.size === numAges)
}

test("DSL test for iso-8859-1 encoded file") {
val dataFrame = new XmlReader()
.withCharset("iso-8859-1")
Expand Down

0 comments on commit 8f423b6

Please sign in to comment.