@@ -42,7 +42,7 @@ object JsonCodec {
42
42
val default : Config = Config (ignoreEmptyCollections = false )
43
43
}
44
44
45
- type DiscriminatorTuple = Option [(discriminatorName , String )]
45
+ type DiscriminatorTuple = Option [(String , String )]
46
46
47
47
implicit def zioJsonBinaryCodec [A ](implicit jsonCodec : ZJsonCodec [A ]): BinaryCodec [A ] =
48
48
new BinaryCodec [A ] {
@@ -338,11 +338,11 @@ object JsonCodec {
338
338
case Schema .NonEmptyMap (ks : Schema [kt], vs : Schema [vt], _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap [kt, vt]](_.toMap.asInstanceOf [Map [kt, vt]]).asInstanceOf [ZJsonEncoder [A ]]
339
339
case Schema .Set (s, _) =>
340
340
ZJsonEncoder .chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk .fromIterable(m))
341
- case Schema .Transform (c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
341
+ case Schema .Transform (c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg, discriminatorTuple )
342
342
case Schema .Tuple2 (l, r, _) => ZJsonEncoder .tuple2(schemaEncoder(l, cfg, discriminatorTuple), schemaEncoder(r, cfg, discriminatorTuple))
343
343
case Schema .Optional (schema, _) => ZJsonEncoder .option(schemaEncoder(schema, cfg, discriminatorTuple))
344
344
case Schema .Fail (_, _) => unitEncoder.contramap(_ => ())
345
- case s @ Schema .GenericRecord (_, _, _) => recordEncoder(s, cfg)
345
+ case s : Schema .GenericRecord => recordEncoder(s, cfg, discriminatorTuple )
346
346
case Schema .Either (left, right, _) => ZJsonEncoder .either(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
347
347
case Schema .Fallback (left, right, _, _) => fallbackEncoder(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
348
348
case l @ Schema .Lazy (_) => ZJsonEncoder .suspend(schemaEncoder(l.schema, cfg, discriminatorTuple))
@@ -471,9 +471,14 @@ object JsonCodec {
471
471
}
472
472
}
473
473
474
- private def transformEncoder [A , B ](schema : Schema [A ], g : B => Either [String , A ], cfg : Config ): ZJsonEncoder [B ] =
474
+ private def transformEncoder [A , B ](
475
+ schema : Schema [A ],
476
+ g : B => Either [String , A ],
477
+ cfg : Config ,
478
+ discriminatorTuple : DiscriminatorTuple
479
+ ): ZJsonEncoder [B ] =
475
480
new ZJsonEncoder [B ] {
476
- private lazy val innerEncoder = schemaEncoder(schema, cfg)
481
+ private lazy val innerEncoder = schemaEncoder(schema, cfg, discriminatorTuple )
477
482
478
483
override def unsafeEncode (b : B , indent : Option [Int ], out : Write ): Unit =
479
484
g(b) match {
@@ -508,12 +513,13 @@ object JsonCodec {
508
513
val noDiscriminators = schema.noDiscriminator
509
514
val discriminatorTuple =
510
515
if (noDiscriminators) None
511
- else schema.annotations.collectFirst { case d : discriminatorName => (d, caseName) }
516
+ else schema.annotations.collectFirst { case d : discriminatorName => (d.tag , caseName) }
512
517
val doJsonObjectWrapping = discriminatorTuple.isEmpty && ! noDiscriminators
513
- if (doJsonObjectWrapping) out.write('{' )
514
- val indent_ = bump(indent)
515
- pad(indent_, out)
518
+ var indent_ = indent
516
519
if (doJsonObjectWrapping) {
520
+ out.write('{' )
521
+ indent_ = bump(indent)
522
+ pad(indent_, out)
517
523
string.encoder.unsafeEncode(caseName, indent_, out)
518
524
if (indent.isEmpty) out.write(':' )
519
525
else out.write(" : " )
@@ -524,8 +530,11 @@ object JsonCodec {
524
530
catch {
525
531
case ex if NonFatal (ex) => throw new RuntimeException (s " Failed to encode enum type $schema" , ex)
526
532
}
527
- }, indent, out)
528
- if (doJsonObjectWrapping) out.write('}' )
533
+ }, indent_, out)
534
+ if (doJsonObjectWrapping) {
535
+ pad(indent, out)
536
+ out.write('}' )
537
+ }
529
538
case _ =>
530
539
out.write(" {}" ) // for transient cases
531
540
}
@@ -551,7 +560,11 @@ object JsonCodec {
551
560
}
552
561
}
553
562
554
- private def recordEncoder (schema : Schema .GenericRecord , cfg : Config ): ZJsonEncoder [ListMap [String , _]] = {
563
+ private def recordEncoder (
564
+ schema : Schema .GenericRecord ,
565
+ cfg : Config ,
566
+ discriminatorTuple : DiscriminatorTuple
567
+ ): ZJsonEncoder [ListMap [String , _]] = {
555
568
val nonTransientFields = schema.nonTransientFields.toArray
556
569
val encoders = nonTransientFields.map(field => schemaEncoder(field.schema.asInstanceOf [Schema [Any ]], cfg))
557
570
if (nonTransientFields.isEmpty) { (_ : ListMap [String , _], _ : Option [Int ], out : Write ) =>
@@ -567,9 +580,17 @@ object JsonCodec {
567
580
}
568
581
val strEnc = string.encoder
569
582
var first = true
570
- var i = 0
571
- while (i < nonTransientFields.length) {
572
- val field = nonTransientFields(i)
583
+ if (discriminatorTuple ne None ) {
584
+ val tuple = discriminatorTuple.get
585
+ first = false
586
+ strEnc.unsafeEncode(tuple._1, indent_, out)
587
+ if (doPrettyPrint) out.write(" : " )
588
+ else out.write(':' )
589
+ strEnc.unsafeEncode(tuple._2, indent_, out)
590
+ }
591
+ var idx = 0
592
+ while (idx < nonTransientFields.length) {
593
+ val field = nonTransientFields(idx)
573
594
val fieldName = field.fieldName
574
595
val fieldValue = value(fieldName)
575
596
if (! isEmptyOptionalValue(field, fieldValue, cfg)) {
@@ -581,9 +602,9 @@ object JsonCodec {
581
602
strEnc.unsafeEncode(fieldName, indent_, out)
582
603
if (doPrettyPrint) out.write(" : " )
583
604
else out.write(':' )
584
- encoders(i ).unsafeEncode(fieldValue, indent_, out)
605
+ encoders(idx ).unsafeEncode(fieldValue, indent_, out)
585
606
}
586
- i += 1
607
+ idx += 1
587
608
}
588
609
if (doPrettyPrint) pad(indent, out)
589
610
out.write('}' )
@@ -683,7 +704,7 @@ object JsonCodec {
683
704
case Schema .NonEmptyMap (ks, vs, _) => mapDecoder(ks, vs).mapOrFail(m => NonEmptyMap .fromMapOption(m).toRight(" NonEmptyMap expected" ))
684
705
case Schema .Set (s, _) => ZJsonDecoder .chunk(schemaDecoder(s)).map(entries => entries.toSet)
685
706
case Schema .Fail (message, _) => failDecoder(message)
686
- case s @ Schema .GenericRecord (_, _, _) => recordDecoder(s)
707
+ case s : Schema .GenericRecord => recordDecoder(s, discriminator )
687
708
case Schema .Either (left, right, _) => ZJsonDecoder .either(schemaDecoder(left), schemaDecoder(right))
688
709
case s @ Schema .Fallback (_, _, _, _) => fallbackDecoder(s)
689
710
case l @ Schema .Lazy (_) => ZJsonDecoder .suspend(schemaDecoder(l.schema, discriminator))
@@ -949,7 +970,10 @@ object JsonCodec {
949
970
}
950
971
}
951
972
952
- private def recordDecoder (schema : GenericRecord ): ZJsonDecoder [ListMap [String , Any ]] = {
973
+ private def recordDecoder (
974
+ schema : GenericRecord ,
975
+ discriminator : Option [String ]
976
+ ): ZJsonDecoder [ListMap [String , Any ]] = {
953
977
val capacity = schema.fields.size * 2
954
978
val spansWithDecoders =
955
979
new util.HashMap [String , (JsonError .ObjectAccess , ZJsonDecoder [Any ])](capacity)
@@ -963,12 +987,15 @@ object JsonCodec {
963
987
defaults.put(fieldName, field.defaultValue.get)
964
988
}
965
989
}
966
- val rejectAdditionalFields = schema.annotations.exists(_.isInstanceOf [rejectExtraFields])
990
+ val skipExtraFields = ! schema.annotations.exists(_.isInstanceOf [rejectExtraFields])
967
991
(trace : List [JsonError ], in : RetractReader ) => {
968
- val lexer = Lexer
969
- lexer.char(trace, in, '{' )
970
- val map = new util.HashMap [String , Any ]
971
- var continue = lexer.firstField(trace, in)
992
+ val lexer = Lexer
993
+ var continue = true
994
+ if (discriminator eq None ) {
995
+ lexer.char(trace, in, '{' )
996
+ continue = lexer.firstField(trace, in)
997
+ }
998
+ val map = new util.HashMap [String , Any ]
972
999
while (continue) {
973
1000
val fieldNameOrAlias = lexer.string(trace, in).toString
974
1001
val spanWithDecoder = spansWithDecoders.get(fieldNameOrAlias)
@@ -982,11 +1009,11 @@ object JsonCodec {
982
1009
if (prev != null ) {
983
1010
throw UnsafeJson (JsonError .Message (" duplicate" ) :: trace_)
984
1011
}
985
- } else if (rejectAdditionalFields) {
986
- throw UnsafeJson (JsonError .Message (s " unexpected field: $fieldNameOrAlias" ) :: trace)
987
- } else {
1012
+ } else if (skipExtraFields || discriminator.contains(fieldNameOrAlias)) {
988
1013
lexer.char(trace, in, ':' )
989
1014
lexer.skipValue(trace, in)
1015
+ } else {
1016
+ throw UnsafeJson (JsonError .Message (s " unexpected field: $fieldNameOrAlias" ) :: trace)
990
1017
}
991
1018
continue = lexer.nextField(trace, in)
992
1019
}
@@ -1108,7 +1135,7 @@ object JsonCodec {
1108
1135
if (discriminatorTuple ne None ) {
1109
1136
val tuple = discriminatorTuple.get
1110
1137
first = false
1111
- strEnc.unsafeEncode(tuple._1.tag , indent_, out)
1138
+ strEnc.unsafeEncode(tuple._1, indent_, out)
1112
1139
if (doPrettyPrint) out.write(" : " )
1113
1140
else out.write(':' )
1114
1141
strEnc.unsafeEncode(tuple._2, indent_, out)
0 commit comments