
Commit b8a5787

feat: native support for intervals <-> Neo4j duration
This commit introduces a new way to write Spark SQL interval types to the Neo4j duration type. The change is additive: the previous method of writing Neo4j durations via a custom struct still works, so it should be backwards compatible. Fixes CONN-341
1 parent: 07f37d0
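
As a minimal sketch of what this enables, assuming a SparkSession named `spark` and a placeholder Bolt URL (the format class and the `url`/`labels` options mirror the test suite below), a Spark interval column can now be written directly:

    import org.apache.spark.sql.SaveMode

    // A DataFrame with a DayTimeIntervalType column, built from an interval literal.
    val df = spark.sql(
      "SELECT 1 AS id, INTERVAL '10 05:30:15.123' DAY TO SECOND AS duration"
    )

    df.write
      .format("org.neo4j.spark.DataSource")
      .mode(SaveMode.Append)
      .option("url", "bolt://localhost:7687") // placeholder URL
      .option("labels", "Dur")
      .save()

    // The `duration` property lands as a native Neo4j DURATION value;
    // the custom struct representation is no longer required (but still works).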

3 files changed (+212, -5 lines)


common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala

Lines changed: 25 additions & 0 deletions
@@ -55,11 +55,36 @@ trait DataConverter[T] {
 
 object SparkToNeo4jDataConverter {
 
   def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
+
+  def dayTimeIntervalToNeo4j(micros: Long): Value = {
+    val oneSecondInMicros = 1000000L
+    val oneDayInMicros = 24 * 3600 * oneSecondInMicros
+
+    val numberDays = Math.floorDiv(micros, oneDayInMicros)
+    val remainderMicros = Math.floorMod(micros, oneDayInMicros)
+    val numberSeconds = Math.floorDiv(remainderMicros, oneSecondInMicros)
+    val numberNanos = Math.floorMod(remainderMicros, oneSecondInMicros) * 1000
+
+    Values.isoDuration(0L, numberDays, numberSeconds, numberNanos.toInt)
+  }
+
+  // while Neo4j supports years, this driver version's API does not expose it.
+  def yearMonthIntervalToNeo4j(months: Int): Value = {
+    Values.isoDuration(months.toLong, 0L, 0L, 0)
+  }
 }
 
 class SparkToNeo4jDataConverter extends DataConverter[Value] {
 
   override def convert(value: Any, dataType: DataType): Value = {
+    dataType match {
+      case _: DayTimeIntervalType if value != null =>
+        return SparkToNeo4jDataConverter.dayTimeIntervalToNeo4j(value.asInstanceOf[Long])
+      case _: YearMonthIntervalType if value != null =>
+        return SparkToNeo4jDataConverter.yearMonthIntervalToNeo4j(value.asInstanceOf[Int])
+      case _ => // do nothing
+    }
+
     value match {
       case date: java.sql.Date => convert(date.toLocalDate, dataType)
       case timestamp: java.sql.Timestamp => convert(timestamp.toLocalDateTime, dataType)
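
Spark carries a DayTimeIntervalType value as a single microsecond count, so the conversion above is plain integer arithmetic. A worked example, using the interval exercised in the tests below:

    // INTERVAL '10 05:30:15.123' DAY TO SECOND, as Spark's internal microseconds:
    val micros = (10L * 86400 + 5 * 3600 + 30 * 60 + 15) * 1000000L + 123000L

    val oneDayInMicros = 86400L * 1000000L
    Math.floorDiv(micros, oneDayInMicros)       // 10 days
    val rem = Math.floorMod(micros, oneDayInMicros)
    Math.floorDiv(rem, 1000000L)                // 19815 seconds (5h 30m 15s)
    Math.floorMod(rem, 1000000L) * 1000         // 123000000 nanoseconds

Using floorDiv/floorMod rather than / and % keeps the seconds and nanoseconds components non-negative for negative intervals, which then surface as a negative day count.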

common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,8 @@ package org.neo4j.spark.converter
 
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.neo4j.driver.types.Entity
 import org.neo4j.spark.converter.CypherToSparkTypeConverter.cleanTerms
 import org.neo4j.spark.converter.CypherToSparkTypeConverter.durationType

@@ -129,6 +131,8 @@ object SparkToCypherTypeConverter {
     DataTypes.DoubleType -> "FLOAT",
     DataTypes.DateType -> "DATE",
     DataTypes.TimestampType -> "LOCAL DATETIME",
+    DayTimeIntervalType() -> "DURATION",
+    YearMonthIntervalType() -> "DURATION",
     durationType -> "DURATION",
     pointType -> "POINT",
     // Cypher graph entities do not allow null values in arrays

@@ -141,6 +145,10 @@ object SparkToCypherTypeConverter {
     DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
     DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
     DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>",
+    DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(DayTimeIntervalType(), true) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(YearMonthIntervalType(), false) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(YearMonthIntervalType(), true) -> "LIST<DURATION NOT NULL>",
     DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
     DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
   )
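
Both interval flavors map to Neo4j's single DURATION type, which holds months, days, seconds, and nanoseconds together. A note on the map keys: the no-argument factories denote Spark's widest interval variants, which is what the interval literals and timestamp arithmetic in the tests below produce. A small sketch of the equality the lookup relies on (my reading of the Spark API, not part of the commit):

    import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

    // The no-arg factories are shorthand for the widest field ranges:
    assert(DayTimeIntervalType() ==
      DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND))
    assert(YearMonthIntervalType() ==
      YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH))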

spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterTSE.scala

Lines changed: 179 additions & 5 deletions
@@ -21,6 +21,10 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.junit
 import org.junit.Assert._
 import org.junit.Ignore
@@ -30,6 +34,7 @@ import org.neo4j.driver.Transaction
 import org.neo4j.driver.TransactionWork
 import org.neo4j.driver.Value
 import org.neo4j.driver.exceptions.ClientException
+import org.neo4j.driver.exceptions.value.Uncoercible
 import org.neo4j.driver.internal.InternalPoint2D
 import org.neo4j.driver.internal.InternalPoint3D
 import org.neo4j.driver.internal.types.InternalTypeSystem
@@ -74,7 +79,7 @@ case class SimplePerson(name: String, surname: String)
 case class EmptyRow[T](data: T)
 
 class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
-  val sparkSession = SparkSession.builder().getOrCreate()
+  val sparkSession: SparkSession = SparkSession.builder().getOrCreate()
 
   import sparkSession.implicits._
@@ -414,7 +419,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }
 
   @Test
-  def `should write nodes with duration values into Neo4j`(): Unit = {
+  def `should write nodes with duration values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
@@ -430,10 +435,10 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
 
     val records = SparkConnectorScalaSuiteIT.session().run(
       """MATCH (p:BeanWithDuration)
-        |RETURN p.data AS data
+        |RETURN p.data AS duration
         |""".stripMargin
     ).list().asScala
-      .map(r => r.get("data").asIsoDuration())
+      .map(r => r.get("duration").asIsoDuration())
       .map(data => (data.months, data.days, data.seconds, data.nanoseconds))
      .toSet
@@ -445,7 +450,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }
 
   @Test
-  def `should write nodes with duration array values into Neo4j`(): Unit = {
+  def `should write nodes with duration array values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
@@ -484,6 +489,175 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
     assertEquals(expected, records)
   }
 
+  private def writeAndGetInterval(expectedDt: Class[_ <: DataType], sql: String): Value = {
+    val id = java.util.UUID.randomUUID().toString
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sql AS duration")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "Dur")
+      .save()
+
+    val want = expectedDt.getSimpleName
+    val got = df.schema("duration").dataType
+    assertTrue(s"expected Spark to pick $want but it was $got", expectedDt.isInstance(got))
+
+    SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:Dur {id: '$id'})
+         |RETURN d.duration AS duration
+         |""".stripMargin
+    ).single().get("duration")
+  }
+
+  @Test
+  def `should write nodes with native Neo4j durations when passed DayTimeIntervalType`(): Unit = {
+    val got = writeAndGetInterval(classOf[DayTimeIntervalType], "INTERVAL '10 05:30:15.123' DAY TO SECOND")
+
+    assertTrue(
+      "can convert SQL day to second interval",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case e: Uncoercible => false
+        case e => throw e // pass through other exceptions
+      }
+    )
+
+    assertEquals(0L, got.asIsoDuration().months)
+    assertEquals(10L, got.asIsoDuration().days)
+    assertEquals(19815L, got.asIsoDuration().seconds)
+    assertEquals(123000000, got.asIsoDuration().nanoseconds)
+  }
+
+  @Test
+  def `should write nodes with native Neo4j durations when passed YearMonthIntervalType`(): Unit = {
+    val got = writeAndGetInterval(classOf[YearMonthIntervalType], "INTERVAL '4-5' YEAR TO MONTH")
+
+    assertTrue(
+      "can convert SQL year to month interval",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case e: Uncoercible => false
+        case e => throw e // pass through other exceptions
+      }
+    )
+
+    assertEquals(53L, got.asIsoDuration().months)
+    assertEquals(0L, got.asIsoDuration().days)
+    assertEquals(0L, got.asIsoDuration().seconds)
+    assertEquals(0, got.asIsoDuration().nanoseconds)
+  }
+
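Spark stores a year-month interval as a single Int of months, which is why the assertion above expects a bare month count:

    // INTERVAL '4-5' YEAR TO MONTH as Spark's internal month count:
    val months = 4 * 12 + 5 // 53, matching the asserted value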
+  @Test
+  def `should write nodes with native Neo4j durations when passed timestamp arithmetics`(): Unit = {
+    val intervalQuery = "timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00')"
+    val got = writeAndGetInterval(classOf[DayTimeIntervalType], intervalQuery)
+
+    assertTrue(
+      "can convert SQL day to second interval arithmetic",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case e: Uncoercible => false
+        case e => throw e // pass through other exceptions
+      }
+    )
+
+    assertEquals(0L, got.asIsoDuration().months) // DayTimeIntervalType never returns months
+    assertEquals(367L, got.asIsoDuration().days) // can it capture the leap day?!
+    assertEquals(66600L, got.asIsoDuration().seconds)
+    assertEquals(454000000, got.asIsoDuration().nanoseconds)
+  }
+
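The expected values can be cross-checked independently with java.time (a sketch, not part of the commit): 2024 is a leap year, so the span covers 366 + 1 whole days plus an 18:30:00.454 remainder.

    import java.time.{Duration, LocalDateTime}

    val start = LocalDateTime.parse("2024-01-01T00:00:00")
    val end = LocalDateTime.parse("2025-01-02T18:30:00.454")

    val d = Duration.between(start, end)
    d.toDays                    // 367
    d.getSeconds - 367L * 86400 // 66600 = 18 * 3600 + 30 * 60
    d.getNano                   // 454000000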
+  private def writeAndQueryIfIsDurationArray(
+    expectedInnerDt: Class[_ <: DataType],
+    elemsSql: Seq[String]
+  ): Boolean = {
+    val id = java.util.UUID.randomUUID().toString
+    val sqlArray = elemsSql.mkString("array(", ", ", ")")
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS durations")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "DurArr")
+      .save()
+
+    val got = df.schema("durations").dataType
+    val isTypeOk = got match {
+      case ArrayType(et, _) if expectedInnerDt.isInstance(et) => true
+      case _ => false
+    }
+    assertTrue(
+      s"expected Spark to infer ArrayType(${expectedInnerDt.getSimpleName}) but it was $got",
+      isTypeOk
+    )
+
+    val result = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:DurArr {id: '$id'})
+         |RETURN d.durations AS durations
+         |""".stripMargin
+    ).single().get("durations")
+
+    try {
+      val _ = result.asList((v: Value) => v.asIsoDuration())
+    } catch {
+      case _: Uncoercible => return false
+      case e => throw e
+    }
+
+    true
+  }
+
+  @Test
+  def `should write interval day second arrays as native neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array<DAY TO SECOND> to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[DayTimeIntervalType],
+        Seq(
+          "INTERVAL '10 05:30:15.123' DAY TO SECOND",
+          "INTERVAL '0 00:00:01.000' DAY TO SECOND"
+        )
+      )
+    )
+  }
+
+  @Test
+  def `should write interval year month arrays as native neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array<YEAR TO MONTH> to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[YearMonthIntervalType],
+        Seq(
+          "INTERVAL '1-02' YEAR TO MONTH",
+          "INTERVAL '0-11' YEAR TO MONTH"
+        )
+      )
+    )
+  }
+
+  @Test
+  def `should write interval arithmetic arrays as native neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array of interval arithmetic to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[DayTimeIntervalType],
+        Seq(
+          "timestamp('2024-01-02 00:00:00') - timestamp('2024-01-01 00:00:00')",
+          "current_timestamp() - timestamp('2024-01-01 00:00:00')"
+        )
+      )
+    )
+  }
+
   @Test
   def `should write nodes into Neo4j with points`(): Unit = {
     val total = 10