Skip to content

Commit e647d77

Browse files
#patch Adding support for SdkBindingData[scala.Option[_]] (#308)
* Adding support for SdkBindingData[scala.Option[_]] Signed-off-by: Jonathan Schuchart <jschuchart@spotify.com> * Fixing handling of product element names in scala 2.12 Signed-off-by: Jonathan Schuchart <jschuchart@spotify.com> --------- Signed-off-by: Jonathan Schuchart <jschuchart@spotify.com>
1 parent 806d894 commit e647d77

File tree

4 files changed

+115
-29
lines changed

4 files changed

+115
-29
lines changed

flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ class SdkScalaTypeTest {
9595
datetime: SdkBindingData[Instant],
9696
duration: SdkBindingData[Duration],
9797
blob: SdkBindingData[Blob],
98-
generic: SdkBindingData[ScalarNested]
98+
generic: SdkBindingData[ScalarNested],
99+
none: SdkBindingData[Option[String]],
100+
some: SdkBindingData[Option[String]]
99101
)
100102

101103
case class CollectionInput(
@@ -105,7 +107,8 @@ class SdkScalaTypeTest {
105107
booleans: SdkBindingData[List[Boolean]],
106108
datetimes: SdkBindingData[List[Instant]],
107109
durations: SdkBindingData[List[Duration]],
108-
generics: SdkBindingData[List[ScalarNested]]
110+
generics: SdkBindingData[List[ScalarNested]],
111+
options: SdkBindingData[List[Option[String]]]
109112
)
110113

111114
case class MapInput(
@@ -115,7 +118,8 @@ class SdkScalaTypeTest {
115118
booleanMap: SdkBindingData[Map[String, Boolean]],
116119
datetimeMap: SdkBindingData[Map[String, Instant]],
117120
durationMap: SdkBindingData[Map[String, Duration]],
118-
genericMap: SdkBindingData[Map[String, ScalarNested]]
121+
genericMap: SdkBindingData[Map[String, ScalarNested]],
122+
optionMap: SdkBindingData[Map[String, Option[String]]]
119123
)
120124

121125
case class ComplexInput(
@@ -196,7 +200,9 @@ class SdkScalaTypeTest {
196200
.literalType(LiteralType.ofBlobType(BlobType.DEFAULT))
197201
.description("")
198202
.build(),
199-
"generic" -> createVar(SimpleType.STRUCT)
203+
"generic" -> createVar(SimpleType.STRUCT),
204+
"none" -> createVar(SimpleType.STRUCT),
205+
"some" -> createVar(SimpleType.STRUCT)
200206
).asJava
201207

202208
val output = SdkScalaType[ScalarInput].getVariableMap
@@ -274,6 +280,16 @@ class SdkScalaTypeTest {
274280
).asJava
275281
)
276282
)
283+
),
284+
"none" -> Literal.ofScalar(
285+
Scalar.ofGeneric(
286+
Struct.of(Map.empty[String, Struct.Value].asJava)
287+
)
288+
),
289+
"some" -> Literal.ofScalar(
290+
Scalar.ofGeneric(
291+
Struct.of(Map("value" -> Struct.Value.ofStringValue("hello")).asJava)
292+
)
277293
)
278294
).asJava
279295

@@ -295,6 +311,14 @@ class SdkScalaTypeTest {
295311
List(ScalarNestedNested("foo", Some("bar"))),
296312
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
297313
)
314+
),
315+
none = SdkBindingDataFactory.of(
316+
SdkLiteralTypes.generics[Option[String]](),
317+
Option(null)
318+
),
319+
some = SdkBindingDataFactory.of(
320+
SdkLiteralTypes.generics[Option[String]](),
321+
Option("hello")
298322
)
299323
)
300324

@@ -323,7 +347,11 @@ class SdkScalaTypeTest {
323347
List(ScalarNestedNested("foo", Some("bar"))),
324348
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
325349
)
326-
)
350+
),
351+
none =
352+
SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option(null)),
353+
some =
354+
SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option("hello"))
327355
)
328356

329357
val expected = Map(
@@ -399,6 +427,23 @@ class SdkScalaTypeTest {
399427
).asJava
400428
)
401429
)
430+
),
431+
"none" -> Literal.ofScalar(
432+
Scalar.ofGeneric(
433+
Struct.of(
434+
Map(__TYPE -> Struct.Value.ofStringValue("scala.None$")).asJava
435+
)
436+
)
437+
),
438+
"some" -> Literal.ofScalar(
439+
Scalar.ofGeneric(
440+
Struct.of(
441+
Map(
442+
"value" -> Struct.Value.ofStringValue("hello"),
443+
__TYPE -> Struct.Value.ofStringValue("scala.Some")
444+
).asJava
445+
)
446+
)
402447
)
403448
).asJava
404449

@@ -416,7 +461,8 @@ class SdkScalaTypeTest {
416461
"booleans" -> createCollectionVar(SimpleType.BOOLEAN),
417462
"datetimes" -> createCollectionVar(SimpleType.DATETIME),
418463
"durations" -> createCollectionVar(SimpleType.DURATION),
419-
"generics" -> createCollectionVar(SimpleType.STRUCT)
464+
"generics" -> createCollectionVar(SimpleType.STRUCT),
465+
"options" -> createCollectionVar(SimpleType.STRUCT)
420466
).asJava
421467

422468
val output = SdkScalaType[CollectionInput].getVariableMap
@@ -443,6 +489,14 @@ class SdkScalaTypeTest {
443489
List(ScalarNestedNested("foo", Some("bar"))),
444490
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
445491
)
492+
),
493+
none = SdkBindingDataFactory.of(
494+
SdkLiteralTypes.generics[Option[String]](),
495+
Option(null)
496+
),
497+
some = SdkBindingDataFactory.of(
498+
SdkLiteralTypes.generics[Option[String]](),
499+
Option("hello")
446500
)
447501
)
448502

@@ -465,6 +519,14 @@ class SdkScalaTypeTest {
465519
List(ScalarNestedNested("foo", Some("bar"))),
466520
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
467521
)
522+
),
523+
"none" -> SdkBindingDataFactory.of(
524+
SdkLiteralTypes.generics[Option[String]](),
525+
Option(null)
526+
),
527+
"some" -> SdkBindingDataFactory.of(
528+
SdkLiteralTypes.generics[Option[String]](),
529+
Option("hello")
468530
)
469531
).asJava
470532

@@ -531,6 +593,10 @@ class SdkScalaTypeTest {
531593
Map("foo2" -> ScalarNestedNested("foo2", Some("bar2")))
532594
)
533595
)
596+
),
597+
options = SdkBindingDataFactory.of(
598+
SdkLiteralTypes.generics[Option[String]](),
599+
List(Option("hello"), Option(null))
534600
)
535601
)
536602

@@ -550,7 +616,8 @@ class SdkScalaTypeTest {
550616
"booleanMap" -> createMapVar(SimpleType.BOOLEAN),
551617
"datetimeMap" -> createMapVar(SimpleType.DATETIME),
552618
"durationMap" -> createMapVar(SimpleType.DURATION),
553-
"genericMap" -> createMapVar(SimpleType.STRUCT)
619+
"genericMap" -> createMapVar(SimpleType.STRUCT),
620+
"optionMap" -> createMapVar(SimpleType.STRUCT)
554621
).asJava
555622

556623
val output = SdkScalaType[MapInput].getVariableMap
@@ -598,6 +665,10 @@ class SdkScalaTypeTest {
598665
Map("foo2" -> ScalarNestedNested("foo2", Some("bar2")))
599666
)
600667
)
668+
),
669+
optionMap = SdkBindingDataFactory.of(
670+
SdkLiteralTypes.generics[Option[String]](),
671+
Map("none" -> Option(null), "some" -> Option("hello"))
601672
)
602673
)
603674

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe}
2828
import scala.reflect.runtime.universe
2929
import scala.reflect.{ClassTag, classTag}
3030
import scala.reflect.runtime.universe.{
31+
ClassSymbol,
3132
NoPrefix,
3233
Symbol,
3334
Type,
@@ -72,7 +73,7 @@ object SdkLiteralTypes {
7273
blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]]
7374
case t if t =:= typeOf[Binary] =>
7475
binary().asInstanceOf[SdkLiteralType[T]]
75-
case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) =>
76+
case t if t <:< typeOf[Product] =>
7677
generics().asInstanceOf[SdkLiteralType[T]]
7778

7879
case t if t =:= typeOf[List[Long]] =>
@@ -391,24 +392,37 @@ object SdkLiteralTypes {
391392
)
392393
}
393394

394-
val clazz = typeOf[S].typeSymbol.asClass
395-
val classMirror = mirror.reflectClass(clazz)
396-
val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
397-
val constructorMirror = classMirror.reflectConstructor(constructor)
398-
399-
val constructorArgs =
400-
constructor.paramLists.flatten.map((param: Symbol) => {
401-
val paramName = param.name.toString
402-
val value = map.getOrElse(
403-
paramName,
404-
throw new IllegalArgumentException(
405-
s"Map is missing required parameter named $paramName"
395+
def instantiateViaConstructor(cls: ClassSymbol): S = {
396+
val classMirror = mirror.reflectClass(cls)
397+
val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
398+
val constructorMirror = classMirror.reflectConstructor(constructor)
399+
400+
val constructorArgs =
401+
constructor.paramLists.flatten.map((param: Symbol) => {
402+
val paramName = param.name.toString
403+
val value = map.getOrElse(
404+
paramName,
405+
throw new IllegalArgumentException(
406+
s"Map is missing required parameter named $paramName"
407+
)
406408
)
407-
)
408-
valueToParamValue(value, param.typeSignature.dealias)
409-
})
409+
valueToParamValue(value, param.typeSignature.dealias)
410+
})
411+
412+
constructorMirror(constructorArgs: _*).asInstanceOf[S]
413+
}
414+
415+
val clazz = typeOf[S].typeSymbol.asClass
416+
// special handling of scala.Option as it is a Product, but can't be instantiated like common
417+
// case classes
418+
if (clazz.name.toString == "Option")
419+
map
420+
.get("value")
421+
.map(valueToParamValue(_, typeOf[S].typeArgs.head))
422+
.asInstanceOf[S]
423+
else
424+
instantiateViaConstructor(clazz)
410425

411-
constructorMirror(constructorArgs: _*).asInstanceOf[S]
412426
}
413427

414428
def structValueToAny(value: Struct.Value): Any = {

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,8 @@ object SdkScalaType {
232232
implicit def durationLiteralType: SdkScalaLiteralType[Duration] =
233233
DelegateLiteralType(SdkLiteralTypes.durations())
234234

235-
// more specific matching to fail the usage of SdkBindingData[Option[_]]
236-
implicit def optionLiteralType: SdkScalaLiteralType[Option[_]] = ???
237-
238235
// fixme: using Product is just an approximation for case class because Product
239-
// is also super class of, for example, Option and Tuple
236+
// is also super class of, for example, Either or Try
240237
implicit def productLiteralType[T <: Product: TypeTag: ClassTag]
241238
: SdkScalaLiteralType[T] =
242239
DelegateLiteralType(SdkLiteralTypes.generics())

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/package.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ package object flytekitscala {
3030
} catch {
3131
case _: Throwable =>
3232
// fall back to java's way, less reliable and with limitations
33-
product.getClass.getDeclaredFields.map(_.getName).toList
33+
val methodNames = product.getClass.getDeclaredMethods.map(_.getName)
34+
product.getClass.getDeclaredFields
35+
.map(_.getName)
36+
.filter(methodNames.contains)
37+
.toList
3438
}
3539
}
3640
}

0 commit comments

Comments
 (0)