Skip to content

Commit 8e1c5fa

Browse files
authored
Add missing support of discriminator fields for generic records (#772)
* Add missing support of discriminator fields for generic records * Fix pretty printing of non-simple enums
1 parent 5694f0d commit 8e1c5fa

File tree

2 files changed

+129
-28
lines changed

2 files changed

+129
-28
lines changed

zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ object JsonCodec {
4242
val default: Config = Config(ignoreEmptyCollections = false)
4343
}
4444

45-
type DiscriminatorTuple = Option[(discriminatorName, String)]
45+
type DiscriminatorTuple = Option[(String, String)]
4646

4747
implicit def zioJsonBinaryCodec[A](implicit jsonCodec: ZJsonCodec[A]): BinaryCodec[A] =
4848
new BinaryCodec[A] {
@@ -338,11 +338,11 @@ object JsonCodec {
338338
case Schema.NonEmptyMap(ks: Schema[kt], vs: Schema[vt], _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap[kt, vt]](_.toMap.asInstanceOf[Map[kt, vt]]).asInstanceOf[ZJsonEncoder[A]]
339339
case Schema.Set(s, _) =>
340340
ZJsonEncoder.chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk.fromIterable(m))
341-
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
341+
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg, discriminatorTuple)
342342
case Schema.Tuple2(l, r, _) => ZJsonEncoder.tuple2(schemaEncoder(l, cfg, discriminatorTuple), schemaEncoder(r, cfg, discriminatorTuple))
343343
case Schema.Optional(schema, _) => ZJsonEncoder.option(schemaEncoder(schema, cfg, discriminatorTuple))
344344
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
345-
case s @ Schema.GenericRecord(_, _, _) => recordEncoder(s, cfg)
345+
case s: Schema.GenericRecord => recordEncoder(s, cfg, discriminatorTuple)
346346
case Schema.Either(left, right, _) => ZJsonEncoder.either(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
347347
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
348348
case l @ Schema.Lazy(_) => ZJsonEncoder.suspend(schemaEncoder(l.schema, cfg, discriminatorTuple))
@@ -471,9 +471,14 @@ object JsonCodec {
471471
}
472472
}
473473

474-
private def transformEncoder[A, B](schema: Schema[A], g: B => Either[String, A], cfg: Config): ZJsonEncoder[B] =
474+
private def transformEncoder[A, B](
475+
schema: Schema[A],
476+
g: B => Either[String, A],
477+
cfg: Config,
478+
discriminatorTuple: DiscriminatorTuple
479+
): ZJsonEncoder[B] =
475480
new ZJsonEncoder[B] {
476-
private lazy val innerEncoder = schemaEncoder(schema, cfg)
481+
private lazy val innerEncoder = schemaEncoder(schema, cfg, discriminatorTuple)
477482

478483
override def unsafeEncode(b: B, indent: Option[Int], out: Write): Unit =
479484
g(b) match {
@@ -508,12 +513,13 @@ object JsonCodec {
508513
val noDiscriminators = schema.noDiscriminator
509514
val discriminatorTuple =
510515
if (noDiscriminators) None
511-
else schema.annotations.collectFirst { case d: discriminatorName => (d, caseName) }
516+
else schema.annotations.collectFirst { case d: discriminatorName => (d.tag, caseName) }
512517
val doJsonObjectWrapping = discriminatorTuple.isEmpty && !noDiscriminators
513-
if (doJsonObjectWrapping) out.write('{')
514-
val indent_ = bump(indent)
515-
pad(indent_, out)
518+
var indent_ = indent
516519
if (doJsonObjectWrapping) {
520+
out.write('{')
521+
indent_ = bump(indent)
522+
pad(indent_, out)
517523
string.encoder.unsafeEncode(caseName, indent_, out)
518524
if (indent.isEmpty) out.write(':')
519525
else out.write(" : ")
@@ -524,8 +530,11 @@ object JsonCodec {
524530
catch {
525531
case ex if NonFatal(ex) => throw new RuntimeException(s"Failed to encode enum type $schema", ex)
526532
}
527-
}, indent, out)
528-
if (doJsonObjectWrapping) out.write('}')
533+
}, indent_, out)
534+
if (doJsonObjectWrapping) {
535+
pad(indent, out)
536+
out.write('}')
537+
}
529538
case _ =>
530539
out.write("{}") // for transient cases
531540
}
@@ -551,7 +560,11 @@ object JsonCodec {
551560
}
552561
}
553562

554-
private def recordEncoder(schema: Schema.GenericRecord, cfg: Config): ZJsonEncoder[ListMap[String, _]] = {
563+
private def recordEncoder(
564+
schema: Schema.GenericRecord,
565+
cfg: Config,
566+
discriminatorTuple: DiscriminatorTuple
567+
): ZJsonEncoder[ListMap[String, _]] = {
555568
val nonTransientFields = schema.nonTransientFields.toArray
556569
val encoders = nonTransientFields.map(field => schemaEncoder(field.schema.asInstanceOf[Schema[Any]], cfg))
557570
if (nonTransientFields.isEmpty) { (_: ListMap[String, _], _: Option[Int], out: Write) =>
@@ -567,9 +580,17 @@ object JsonCodec {
567580
}
568581
val strEnc = string.encoder
569582
var first = true
570-
var i = 0
571-
while (i < nonTransientFields.length) {
572-
val field = nonTransientFields(i)
583+
if (discriminatorTuple ne None) {
584+
val tuple = discriminatorTuple.get
585+
first = false
586+
strEnc.unsafeEncode(tuple._1, indent_, out)
587+
if (doPrettyPrint) out.write(" : ")
588+
else out.write(':')
589+
strEnc.unsafeEncode(tuple._2, indent_, out)
590+
}
591+
var idx = 0
592+
while (idx < nonTransientFields.length) {
593+
val field = nonTransientFields(idx)
573594
val fieldName = field.fieldName
574595
val fieldValue = value(fieldName)
575596
if (!isEmptyOptionalValue(field, fieldValue, cfg)) {
@@ -581,9 +602,9 @@ object JsonCodec {
581602
strEnc.unsafeEncode(fieldName, indent_, out)
582603
if (doPrettyPrint) out.write(" : ")
583604
else out.write(':')
584-
encoders(i).unsafeEncode(fieldValue, indent_, out)
605+
encoders(idx).unsafeEncode(fieldValue, indent_, out)
585606
}
586-
i += 1
607+
idx += 1
587608
}
588609
if (doPrettyPrint) pad(indent, out)
589610
out.write('}')
@@ -683,7 +704,7 @@ object JsonCodec {
683704
case Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).mapOrFail(m => NonEmptyMap.fromMapOption(m).toRight("NonEmptyMap expected"))
684705
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s)).map(entries => entries.toSet)
685706
case Schema.Fail(message, _) => failDecoder(message)
686-
case s @ Schema.GenericRecord(_, _, _) => recordDecoder(s)
707+
case s: Schema.GenericRecord => recordDecoder(s, discriminator)
687708
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left), schemaDecoder(right))
688709
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
689710
case l @ Schema.Lazy(_) => ZJsonDecoder.suspend(schemaDecoder(l.schema, discriminator))
@@ -949,7 +970,10 @@ object JsonCodec {
949970
}
950971
}
951972

952-
private def recordDecoder(schema: GenericRecord): ZJsonDecoder[ListMap[String, Any]] = {
973+
private def recordDecoder(
974+
schema: GenericRecord,
975+
discriminator: Option[String]
976+
): ZJsonDecoder[ListMap[String, Any]] = {
953977
val capacity = schema.fields.size * 2
954978
val spansWithDecoders =
955979
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](capacity)
@@ -963,12 +987,15 @@ object JsonCodec {
963987
defaults.put(fieldName, field.defaultValue.get)
964988
}
965989
}
966-
val rejectAdditionalFields = schema.annotations.exists(_.isInstanceOf[rejectExtraFields])
990+
val skipExtraFields = !schema.annotations.exists(_.isInstanceOf[rejectExtraFields])
967991
(trace: List[JsonError], in: RetractReader) => {
968-
val lexer = Lexer
969-
lexer.char(trace, in, '{')
970-
val map = new util.HashMap[String, Any]
971-
var continue = lexer.firstField(trace, in)
992+
val lexer = Lexer
993+
var continue = true
994+
if (discriminator eq None) {
995+
lexer.char(trace, in, '{')
996+
continue = lexer.firstField(trace, in)
997+
}
998+
val map = new util.HashMap[String, Any]
972999
while (continue) {
9731000
val fieldNameOrAlias = lexer.string(trace, in).toString
9741001
val spanWithDecoder = spansWithDecoders.get(fieldNameOrAlias)
@@ -982,11 +1009,11 @@ object JsonCodec {
9821009
if (prev != null) {
9831010
throw UnsafeJson(JsonError.Message("duplicate") :: trace_)
9841011
}
985-
} else if (rejectAdditionalFields) {
986-
throw UnsafeJson(JsonError.Message(s"unexpected field: $fieldNameOrAlias") :: trace)
987-
} else {
1012+
} else if (skipExtraFields || discriminator.contains(fieldNameOrAlias)) {
9881013
lexer.char(trace, in, ':')
9891014
lexer.skipValue(trace, in)
1015+
} else {
1016+
throw UnsafeJson(JsonError.Message(s"unexpected field: $fieldNameOrAlias") :: trace)
9901017
}
9911018
continue = lexer.nextField(trace, in)
9921019
}
@@ -1108,7 +1135,7 @@ object JsonCodec {
11081135
if (discriminatorTuple ne None) {
11091136
val tuple = discriminatorTuple.get
11101137
first = false
1111-
strEnc.unsafeEncode(tuple._1.tag, indent_, out)
1138+
strEnc.unsafeEncode(tuple._1, indent_, out)
11121139
if (doPrettyPrint) out.write(" : ")
11131140
else out.write(':')
11141141
strEnc.unsafeEncode(tuple._2, indent_, out)

zio-schema-json/shared/src/test/scala/zio/schema/codec/JsonCodecSpec.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,41 @@ object JsonCodecSpec extends ZIOSpecDefault {
347347
charSequenceToByteChunk("""{"type":"unlimited"}""")
348348
)
349349
},
350+
test("pretty printing with discriminator field") {
351+
val json = JsonCodec
352+
.jsonCodec(Schema[OneOf4])
353+
.encoder
354+
.encodeJson(
355+
RecordExampleWithDiscriminator(f1 = Some("test"), f2 = None),
356+
Some(0)
357+
)
358+
assert(json)(
359+
equalTo(
360+
"""{
361+
| "type" : "RecordExampleWithDiscriminator",
362+
| "$f1" : "test"
363+
|}""".stripMargin
364+
)
365+
)
366+
},
367+
test("pretty printing with discriminator key") {
368+
val json = JsonCodec
369+
.jsonCodec(Schema[OneOf])
370+
.encoder
371+
.encodeJson(
372+
StringValue("test"),
373+
Some(0)
374+
)
375+
assert(json)(
376+
equalTo(
377+
"""{
378+
| "StringValue" : {
379+
| "value" : "test"
380+
| }
381+
|}""".stripMargin
382+
)
383+
)
384+
},
350385
suite("with no discriminator")(
351386
test("example 1") {
352387
assertEncodes(
@@ -1637,6 +1672,12 @@ object JsonCodecSpec extends ZIOSpecDefault {
16371672
Enumeration3(`StringValue3-Backticked`("foo", "bar"))
16381673
)
16391674
},
1675+
test("ADT with generic records and discriminator field") {
1676+
assertEncodesThenDecodes(
1677+
Schema[OneOf4],
1678+
RecordExampleWithDiscriminator(f1 = Some("test"), f2 = None)
1679+
)
1680+
},
16401681
test("of case classes with discriminator") {
16411682
assertEncodesThenDecodes(Schema[Command], Command.Cash) &>
16421683
assertEncodesThenDecodes(Schema[Command], Command.Buy(100))
@@ -2354,6 +2395,39 @@ object JsonCodecSpec extends ZIOSpecDefault {
23542395
implicit lazy val schema: Schema[AllOptionalFields] = DeriveSchema.gen[AllOptionalFields]
23552396
}
23562397

2398+
@discriminatorName("type")
2399+
sealed trait OneOf4
2400+
2401+
object OneOf4 {
2402+
implicit lazy val schema: Schema[OneOf4] = DeriveSchema.gen
2403+
}
2404+
2405+
@rejectExtraFields case class RecordExampleWithDiscriminator(
2406+
@fieldName("$f1") f1: Option[String], // the only field that does not have a default value
2407+
@fieldNameAliases("field2") f2: Option[String] = None,
2408+
@transientField f3: Option[String] = None,
2409+
f4: Option[String] = None,
2410+
f5: Option[String] = None,
2411+
f6: Option[String] = None,
2412+
f7: Option[String] = None,
2413+
f8: Option[String] = None,
2414+
f9: Option[String] = None,
2415+
f10: Option[String] = None,
2416+
f11: Option[String] = None,
2417+
f12: Option[String] = None,
2418+
f13: Option[String] = None,
2419+
f14: Option[String] = None,
2420+
f15: Option[String] = None,
2421+
f16: Option[String] = None,
2422+
f17: Option[String] = None,
2423+
f18: Option[String] = None,
2424+
f19: Option[String] = None,
2425+
f20: Option[String] = None,
2426+
f21: Option[String] = None,
2427+
f22: Option[String] = None,
2428+
@fieldName("$f23") f23: Option[String] = None
2429+
) extends OneOf4
2430+
23572431
case class RecordExample(
23582432
@fieldName("$f1") f1: Option[String], // the only field that does not have a default value
23592433
@fieldNameAliases("field2") f2: Option[String] = None,

0 commit comments

Comments
 (0)