Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shortcut common type inference cases to fail fast, speed up inference #660

Merged
merged 3 commits into from
Sep 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions src/main/scala/com/databricks/spark/xml/util/TypeCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.databricks.spark.xml.util

import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.{NumberFormat, ParsePosition}
import java.text.NumberFormat
import java.time.{Instant, LocalDate, ZoneId}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
import java.util.Locale
Expand All @@ -26,8 +26,6 @@ import scala.util.control.Exception._
import org.apache.spark.sql.types._
import com.databricks.spark.xml.XmlOptions

import java.time.temporal.TemporalQueries

/**
* Utility functions for type casting
*/
Expand Down Expand Up @@ -63,8 +61,14 @@ private[xml] object TypeCast {
case _: BooleanType => parseXmlBoolean(datum)
case dt: DecimalType =>
Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale)
case _: TimestampType => parseXmlTimestamp(datum, options)
case _: DateType => parseXmlDate(datum, options)
case _: TimestampType =>
parseXmlTimestamp(datum, options).getOrElse {
throw new IllegalArgumentException(s"cannot convert value $datum to Timestamp")
}
case _: DateType =>
parseXmlDate(datum, options).getOrElse {
throw new IllegalArgumentException(s"cannot convert value $datum to Date")
}
case _: StringType => datum
case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}")
}
Expand All @@ -85,17 +89,26 @@ private[xml] object TypeCast {
DateTimeFormatter.ISO_DATE
)

private def parseXmlDate(value: String, options: XmlOptions): Date = {
val formatters = options.dateFormat.map(DateTimeFormatter.ofPattern).
map(supportedXmlDateFormatters :+ _).getOrElse(supportedXmlDateFormatters)
formatters.foreach { format =>
private def parseXmlDate(value: String, options: XmlOptions): Option[Date] = {
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a date. All built-in formats will start with a digit.
if (value.nonEmpty && Character.isDigit(value.head)) {
supportedXmlDateFormatters.foreach { format =>
try {
return Some(Date.valueOf(LocalDate.parse(value, format)))
} catch {
case _: Exception => // continue
}
}
}
options.dateFormat.map(DateTimeFormatter.ofPattern).foreach { format =>
try {
return Date.valueOf(LocalDate.parse(value, format))
return Some(Date.valueOf(LocalDate.parse(value, format)))
} catch {
case _: Exception => // continue
}
}
throw new IllegalArgumentException(s"cannot convert value $value to Date")
None
}

private val supportedXmlTimestampFormatters = Seq(
Expand All @@ -115,12 +128,16 @@ private[xml] object TypeCast {
DateTimeFormatter.ISO_INSTANT
)

private def parseXmlTimestamp(value: String, options: XmlOptions): Timestamp = {
supportedXmlTimestampFormatters.foreach { format =>
try {
return Timestamp.from(Instant.from(format.parse(value)))
} catch {
case _: Exception => // continue
private def parseXmlTimestamp(value: String, options: XmlOptions): Option[Timestamp] = {
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a timestamp. All built-in formats will start with a digit.
if (value.nonEmpty && Character.isDigit(value.head)) {
supportedXmlTimestampFormatters.foreach { format =>
try {
return Some(Timestamp.from(Instant.from(format.parse(value))))
} catch {
case _: Exception => // continue
}
}
}
options.timestampFormat.foreach { formatString =>
Expand All @@ -138,12 +155,12 @@ private[xml] object TypeCast {
DateTimeFormatter.ofPattern(formatString).withZone(options.timezone.map(ZoneId.of).orNull)
}
try {
return Timestamp.from(Instant.from(format.parse(value)))
return Some(Timestamp.from(Instant.from(format.parse(value))))
} catch {
case _: Exception => // continue
}
}
throw new IllegalArgumentException(s"cannot convert value $value to Timestamp")
None
}


Expand Down Expand Up @@ -196,6 +213,12 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a double. All built-in formats will start with a digit or period.
if (signSafeValue.isEmpty ||
!(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
return false
}
// Rule out strings ending in D or F, as they will parse as double but should be disallowed
if (value.nonEmpty && (value.last match {
case 'd' | 'D' | 'f' | 'F' => true
Expand All @@ -212,6 +235,11 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a number. All built-in formats will start with a digit.
if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
return false
}
(allCatch opt signSafeValue.toInt).isDefined
}

Expand All @@ -221,25 +249,20 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a number. All built-in formats will start with a digit.
if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
return false
}
(allCatch opt signSafeValue.toLong).isDefined
}

private[xml] def isTimestamp(value: String, options: XmlOptions): Boolean = {
try {
parseXmlTimestamp(value, options)
true
} catch {
case _: IllegalArgumentException => false
}
parseXmlTimestamp(value, options).nonEmpty
}

private[xml] def isDate(value: String, options: XmlOptions): Boolean = {
try {
parseXmlDate(value, options)
true
} catch {
case _: IllegalArgumentException => false
}
parseXmlDate(value, options).nonEmpty
}

private[xml] def signSafeToLong(value: String, options: XmlOptions): Long = {
Expand Down
Loading