@@ -21,6 +21,10 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.junit
 import org.junit.Assert._
 import org.junit.Ignore
@@ -30,6 +34,7 @@ import org.neo4j.driver.Transaction
 import org.neo4j.driver.TransactionWork
 import org.neo4j.driver.Value
 import org.neo4j.driver.exceptions.ClientException
+import org.neo4j.driver.exceptions.value.Uncoercible
 import org.neo4j.driver.internal.InternalPoint2D
 import org.neo4j.driver.internal.InternalPoint3D
 import org.neo4j.driver.internal.types.InternalTypeSystem
@@ -74,7 +79,7 @@ case class SimplePerson(name: String, surname: String)
 case class EmptyRow[T](data: T)
 
 class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
-  val sparkSession = SparkSession.builder().getOrCreate()
+  val sparkSession: SparkSession = SparkSession.builder().getOrCreate()
 
   import sparkSession.implicits._
 
@@ -414,7 +419,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }
 
   @Test
-  def `should write nodes with duration values into Neo4j`(): Unit = {
+  def `should write nodes with duration values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
@@ -430,10 +435,10 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
 
     val records = SparkConnectorScalaSuiteIT.session().run(
       """MATCH (p:BeanWithDuration)
-        |RETURN p.data AS data
+        |RETURN p.data AS duration
         |""".stripMargin
     ).list().asScala
-      .map(r => r.get("data").asIsoDuration())
+      .map(r => r.get("duration").asIsoDuration())
       .map(data => (data.months, data.days, data.seconds, data.nanoseconds))
       .toSet
 
@@ -445,7 +450,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }
 
   @Test
-  def `should write nodes with duration array values into Neo4j`(): Unit = {
+  def `should write nodes with duration array values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
@@ -484,6 +489,183 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
     assertEquals(expected, records)
   }
 
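+  // Writes a one-row DataFrame whose `duration` column comes from the given SQL
+  // expression, asserts that Spark inferred the expected Catalyst type, and
+  // returns the value that ended up stored on the node in Neo4j.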
+  private def writeAndGetInterval(expectedDt: Class[_ <: DataType], sql: String): Value = {
+    val id = java.util.UUID.randomUUID().toString
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sql AS duration")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "Dur")
+      .save()
+
+    val want = expectedDt.getSimpleName
+    val got = df.schema("duration").dataType
+    assertTrue(s"expected Spark to pick $want but it was $got", expectedDt.isInstance(got))
+
+    SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:Dur {id: '$id'})
+         |RETURN d.duration AS duration
+         |""".stripMargin
+    ).single().get("duration")
+  }
+
+  @Test
+  def `should write nodes with native Neo4j durations when passed DayTimeIntervalType`(): Unit = {
+    val got = writeAndGetInterval(classOf[DayTimeIntervalType], "INTERVAL '10 05:30:15.123' DAY TO SECOND")
+
+    assertTrue(
+      "can convert SQL day to second interval",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case _: Uncoercible => false
+        case e: Throwable => throw e // rethrow unrelated failures
+      }
+    )
+
+    assertEquals(0L, got.asIsoDuration().months)
+    assertEquals(10L, got.asIsoDuration().days)
+    assertEquals(19815L, got.asIsoDuration().seconds)
+    assertEquals(123000000, got.asIsoDuration().nanoseconds)
+  }
+
+  @Test
+  def `should write nodes with native Neo4j durations when passed YearMonthIntervalType`(): Unit = {
+    val got = writeAndGetInterval(classOf[YearMonthIntervalType], "INTERVAL '4-5' YEAR TO MONTH")
+
+    assertTrue(
+      "can convert SQL year to month interval",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case _: Uncoercible => false
+        case e: Throwable => throw e // rethrow unrelated failures
+      }
+    )
+
+    assertEquals(53L, got.asIsoDuration().months)
+    assertEquals(0L, got.asIsoDuration().days)
+    assertEquals(0L, got.asIsoDuration().seconds)
+    assertEquals(0, got.asIsoDuration().nanoseconds)
+  }
+
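+  // Subtracting two timestamps in Spark SQL yields a DayTimeIntervalType,
+  // so the difference should also round-trip as a native Neo4j duration.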
+  @Test
+  def `should write nodes with native Neo4j durations when passed timestamp arithmetic`(): Unit = {
+    val intervalQuery = "timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00')"
+    val got = writeAndGetInterval(classOf[DayTimeIntervalType], intervalQuery)
+
+    assertTrue(
+      "can convert SQL day to second interval arithmetic",
+      try {
+        val _ = got.asIsoDuration()
+        true
+      } catch {
+        case _: Uncoercible => false
+        case e: Throwable => throw e // rethrow unrelated failures
+      }
+    )
+
+    assertEquals(0L, got.asIsoDuration().months) // DayTimeIntervalType never carries months
+    assertEquals(367L, got.asIsoDuration().days) // leap year 2024 (366 days) plus one day
+    assertEquals(66600L, got.asIsoDuration().seconds) // 18:30:00
+    assertEquals(454000000, got.asIsoDuration().nanoseconds)
+  }
+
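+  // Array variant of the helper above: builds an `array(...)` literal from the
+  // given SQL element expressions, checks the inferred ArrayType, and reports
+  // whether the stored property coerces back to a list of IsoDuration values.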
+  private def writeAndQueryIfIsDurationArray(
+    expectedInnerDt: Class[_ <: DataType],
+    elemsSql: Seq[String]
+  ): Boolean = {
+    val id = java.util.UUID.randomUUID().toString
+    val sqlArray = elemsSql.mkString("array(", ", ", ")")
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS durations")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "DurArr")
+      .save()
+
+    val got = df.schema("durations").dataType
+    val isTypeOk = got match {
+      case ArrayType(et, _) if expectedInnerDt.isInstance(et) => true
+      case _ => false
+    }
+    assertTrue(
+      s"expected Spark to infer ArrayType(${expectedInnerDt.getSimpleName}) but it was $got",
+      isTypeOk
+    )
+
+    val result = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:DurArr {id: '$id'})
+         |RETURN d.durations AS durations
+         |""".stripMargin
+    ).single().get("durations")
+
+    try {
+      val _ = result.asList((v: Value) => v.asIsoDuration())
+    } catch {
+      case _: Uncoercible => return false
+      case e: Throwable => throw e
+    }
+
+    true
+  }
+
+  @Test
+  def `should write interval day second arrays as native Neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array<DAY TO SECOND> to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[DayTimeIntervalType],
+        Seq(
+          "INTERVAL '10 05:30:15.123' DAY TO SECOND",
+          "INTERVAL '0 00:00:01.000' DAY TO SECOND"
+        )
+      )
+    )
+  }
+
+  @Test
+  def `should write interval year month arrays as native Neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array<YEAR TO MONTH> to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[YearMonthIntervalType],
+        Seq(
+          "INTERVAL '1-02' YEAR TO MONTH",
+          "INTERVAL '0-11' YEAR TO MONTH"
+        )
+      )
+    )
+  }
+
+  @Test
+  def `should write interval arithmetic arrays as native Neo4j durations`(): Unit = {
+    assertTrue(
+      "can convert array of interval arithmetic to IsoDuration[]",
+      writeAndQueryIfIsDurationArray(
+        classOf[DayTimeIntervalType],
+        Seq(
+          "timestamp('2024-01-02 00:00:00') - timestamp('2024-01-01 00:00:00')",
+          "current_timestamp() - timestamp('2024-01-01 00:00:00')"
+        )
+      )
+    )
+  }
+
   @Test
   def `should write nodes into Neo4j with points`(): Unit = {
     val total = 10