Skip to content

Commit

Permalink
More efficient JSON codecs for enums (#767)
Browse files Browse the repository at this point in the history
* More efficient JSON encoders for enums

* More efficient JSON decoders for enums
  • Loading branch information
plokhotnyuk authored Jan 6, 2025
1 parent 00110cc commit 7629e75
Showing 1 changed file with 51 additions and 126 deletions.
177 changes: 51 additions & 126 deletions zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import java.util.concurrent.ConcurrentHashMap

import scala.annotation.{ switch, tailrec }
import scala.collection.immutable.ListMap
import scala.util.control.NonFatal

import zio.json.JsonCodec._
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
Expand Down Expand Up @@ -336,54 +337,17 @@ object JsonCodec {
case Schema.NonEmptyMap(ks: Schema[kt], vs: Schema[vt], _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap[kt, vt]](_.toMap.asInstanceOf[Map[kt, vt]]).asInstanceOf[ZJsonEncoder[A]]
case Schema.Set(s, _) =>
ZJsonEncoder.chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
case Schema.Tuple2(l, r, _) => ZJsonEncoder.tuple2(schemaEncoder(l, cfg, discriminatorTuple), schemaEncoder(r, cfg, discriminatorTuple))
case Schema.Optional(schema, _) => ZJsonEncoder.option(schemaEncoder(schema, cfg, discriminatorTuple))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case s @ Schema.GenericRecord(_, _, _) => recordEncoder(s, cfg)
case Schema.Either(left, right, _) => ZJsonEncoder.either(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
case l @ Schema.Lazy(_) => ZJsonEncoder.suspend(schemaEncoder(l.schema, cfg, discriminatorTuple))
case s: Schema.Record[A] => caseClassEncoder(s, cfg, discriminatorTuple)
case e @ Schema.Enum1(_, c, _) => enumEncoder(e, cfg, c)
case e @ Schema.Enum2(_, c1, c2, _) => enumEncoder(e, cfg, c1, c2)
case e @ Schema.Enum3(_, c1, c2, c3, _) => enumEncoder(e, cfg, c1, c2, c3)
case e @ Schema.Enum4(_, c1, c2, c3, c4, _) => enumEncoder(e, cfg, c1, c2, c3, c4)
case e @ Schema.Enum5(_, c1, c2, c3, c4, c5, _) => enumEncoder(e, cfg, c1, c2, c3, c4, c5)
case e @ Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6)
case e @ Schema.Enum7(_, c1, c2, c3, c4, c5, c6, c7, _) => enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7)
case e @ Schema.Enum8(_, c1, c2, c3, c4, c5, c6, c7, c8, _) => enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8)
case e @ Schema.Enum9(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, _) => enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9)
case e @ Schema.Enum10(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
case e @ Schema.Enum11(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
case e @ Schema.Enum12(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
case e @ Schema.Enum13(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
case e @ Schema.Enum14(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
case e @ Schema.Enum15(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
case e @ Schema.Enum16(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
case e @ Schema.Enum17(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
case e @ Schema.Enum18(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
case e @ Schema.Enum19(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
case e @ Schema
.Enum20(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
case e @ Schema
.Enum21(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21)
case e @ Schema.Enum22(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, c22, _) =>
enumEncoder(e, cfg, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, c22)
case e @ Schema.EnumN(_, cs, _) => enumEncoder(e, cfg, cs.toSeq: _*)
case d @ Schema.Dynamic(_) => dynamicEncoder(d, cfg)
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
case Schema.Tuple2(l, r, _) => ZJsonEncoder.tuple2(schemaEncoder(l, cfg, discriminatorTuple), schemaEncoder(r, cfg, discriminatorTuple))
case Schema.Optional(schema, _) => ZJsonEncoder.option(schemaEncoder(schema, cfg, discriminatorTuple))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case s @ Schema.GenericRecord(_, _, _) => recordEncoder(s, cfg)
case Schema.Either(left, right, _) => ZJsonEncoder.either(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left, cfg, discriminatorTuple), schemaEncoder(right, cfg, discriminatorTuple))
case l @ Schema.Lazy(_) => ZJsonEncoder.suspend(schemaEncoder(l.schema, cfg, discriminatorTuple))
case s: Schema.Record[A] => caseClassEncoder(s, cfg, discriminatorTuple)
case s: Schema.Enum[A] => enumEncoder(s, cfg)
case d @ Schema.Dynamic(_) => dynamicEncoder(d, cfg)
case null =>
throw new Exception(s"A captured schema is null, most likely due to wrong field initialization order")
}
Expand Down Expand Up @@ -523,7 +487,7 @@ object JsonCodec {
}
}

private def enumEncoder[Z](schema: Schema.Enum[Z], cfg: Config, cases: Schema.Case[Z, _]*): ZJsonEncoder[Z] =
private def enumEncoder[Z](schema: Schema.Enum[Z], cfg: Config): ZJsonEncoder[Z] =
// if all cases are CaseClass0, encode as a String
if (schema.annotations.exists(_.isInstanceOf[simpleEnum])) {
val caseMap: Map[Z, String] =
Expand All @@ -538,13 +502,9 @@ object JsonCodec {
} else { (value: Z, indent: Option[Int], out: Write) =>
{
val nonTransientCase =
try cases.collectFirst {
case c @ Schema.Case(_, _, _, _, _, annotations) if annotations.collectFirst {
case _: transientCase => ()
}.isEmpty && c.deconstructOption(value).isDefined =>
c
} catch {
case ex: Throwable => throw new RuntimeException(s"Failed to encode enum type $schema", ex)
try schema.nonTransientCases.find(_.deconstructOption(value).isDefined)
catch {
case ex if NonFatal(ex) => throw new RuntimeException(s"Failed to encode enum type $schema", ex)
}

nonTransientCase match {
Expand All @@ -554,13 +514,14 @@ object JsonCodec {
val discriminatorChunk = schema.annotations.collect {
case d: discriminatorName => (d, caseName)
}
val noDiscriminators = schema.noDiscriminator
val noDiscriminators = schema.noDiscriminator
val doJsonObjectWrapping = discriminatorChunk.isEmpty && !noDiscriminators

if (discriminatorChunk.isEmpty && !noDiscriminators) out.write('{')
if (doJsonObjectWrapping) out.write('{')
val indent_ = bump(indent)
pad(indent_, out)

if (discriminatorChunk.isEmpty && !noDiscriminators) {
if (doJsonObjectWrapping) {
string.encoder.unsafeEncode(caseName, indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
Expand All @@ -572,7 +533,7 @@ object JsonCodec {
discriminatorTuple = if (noDiscriminators) Chunk.empty else discriminatorChunk
).unsafeEncode(case_.deconstruct(value), indent, out)

if (discriminatorChunk.isEmpty && !noDiscriminators) out.write('}')
if (doJsonObjectWrapping) out.write('}')
case None =>
out.write("{}")
}
Expand Down Expand Up @@ -772,45 +733,9 @@ object JsonCodec {
caseClass21Decoder(discriminator, s)
case s @ Schema.CaseClass22(_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _) =>
caseClass22Decoder(discriminator, s)
case e @ Schema.Enum1(_, c, _) => enumDecoder(e, c)
case e @ Schema.Enum2(_, c1, c2, _) => enumDecoder(e, c1, c2)
case e @ Schema.Enum3(_, c1, c2, c3, _) => enumDecoder(e, c1, c2, c3)
case e @ Schema.Enum4(_, c1, c2, c3, c4, _) => enumDecoder(e, c1, c2, c3, c4)
case e @ Schema.Enum5(_, c1, c2, c3, c4, c5, _) => enumDecoder(e, c1, c2, c3, c4, c5)
case e @ Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => enumDecoder(e, c1, c2, c3, c4, c5, c6)
case e @ Schema.Enum7(_, c1, c2, c3, c4, c5, c6, c7, _) => enumDecoder(e, c1, c2, c3, c4, c5, c6, c7)
case e @ Schema.Enum8(_, c1, c2, c3, c4, c5, c6, c7, c8, _) => enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8)
case e @ Schema.Enum9(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, _) => enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9)
case e @ Schema.Enum10(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
case e @ Schema.Enum11(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
case e @ Schema.Enum12(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
case e @ Schema.Enum13(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
case e @ Schema.Enum14(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
case e @ Schema.Enum15(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
case e @ Schema.Enum16(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
case e @ Schema.Enum17(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
case e @ Schema.Enum18(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
case e @ Schema.Enum19(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
case e @ Schema
.Enum20(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
case e @ Schema.Enum21(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21)
case e @ Schema.Enum22(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, c22, _) =>
enumDecoder(e, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21, c22)
case e @ Schema.EnumN(_, cs, _) => enumDecoder(e, cs.toSeq: _*)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case _ => throw new Exception(s"Missing a handler for decoding of schema ${schema.toString()}.")
case s: Schema.Enum[A] => enumDecoder(s)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case _ => throw new Exception(s"Missing a handler for decoding of schema ${schema.toString()}.")
}
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down Expand Up @@ -867,11 +792,12 @@ object JsonCodec {
case Json.Null => DynamicValue.NoneValue
}

private def enumDecoder[Z](parentSchema: Schema.Enum[Z], cases: Schema.Case[Z, _]*): ZJsonDecoder[Z] = {
private def enumDecoder[Z](parentSchema: Schema.Enum[Z]): ZJsonDecoder[Z] = {
val cases = parentSchema.cases
val caseNameAliases = cases.flatMap {
case Schema.Case(name, _, _, _, _, annotations) =>
annotations.flatMap {
case a: caseNameAliases => a.aliases.toList.map(_ -> name)
case a: caseNameAliases => a.aliases.map(_ -> name)
case cn: caseName => List(cn.name -> name)
case _ => Nil
}
Expand All @@ -885,18 +811,15 @@ object JsonCodec {
s =>
caseMap.get(caseNameAliases.getOrElse(s, s)) match {
case Some(z) => Right(z)
case None => Left("unrecognized string")
case _ => Left("unrecognized string")
}
)
} else {
if (parentSchema.annotations.exists(_.isInstanceOf[noDiscriminator])) {
def error(trace: List[JsonError]): Nothing =
throw UnsafeJson(JsonError.Message("none of the subtypes could decode the data") :: trace)

val noDiscriminators = parentSchema.annotations.exists {
case noDiscriminator() => true
case _ => false
}

(trace: List[JsonError], in: RetractReader) => {
if (noDiscriminators) {
(trace: List[JsonError], in: RetractReader) => {
var rr = RecordingReader(in)
val it = cases.iterator
var result: Option[Z] = None
Expand All @@ -915,16 +838,15 @@ object JsonCodec {

result match {
case Some(value) => value
case None => throw UnsafeJson(JsonError.Message("none of the subtypes could decode the data") :: trace)
case _ => error(trace)
}

} else {
}
} else {
val maybeDiscriminatorName = parentSchema.annotations.collectFirst { case d: discriminatorName => d.tag }
(trace: List[JsonError], in: RetractReader) => {
Lexer.char(trace, in, '{')

if (Lexer.firstField(trace, in)) {
val maybeDiscriminatorName = parentSchema.annotations.collectFirst {
case d: discriminatorName => d.tag
}
maybeDiscriminatorName match {
case None =>
val subtype = deAliasCaseName(Lexer.string(trace, in).toString, caseNameAliases)
Expand All @@ -935,7 +857,7 @@ object JsonCodec {
val decoded = schemaDecoder(c.schema, -1).unsafeDecode(trace_, in).asInstanceOf[Z]
Lexer.nextField(trace_, in)
decoded
case None =>
case _ =>
throw UnsafeJson(JsonError.Message("unrecognized subtype") :: trace_)
}
case Some(discriminatorName) =>
Expand All @@ -948,12 +870,12 @@ object JsonCodec {
val discriminator = Lexer.string(trace, rr).toString
// Perform a second de-aliasing because the first one would resolve the discriminator key instead.
val innerSubtype = deAliasCaseName(discriminator, caseNameAliases)
if (index > 0) {
rr.rewind()
(innerSubtype, JsonError.ObjectAccess(innerSubtype) :: trace_, index)
} else {
(innerSubtype, JsonError.ObjectAccess(innerSubtype) :: trace_, -2)
}
(innerSubtype, JsonError.ObjectAccess(innerSubtype) :: trace_, {
if (index > 0) {
rr.rewind()
index
} else -2
})
} else {
Lexer.char(trace_, rr, ':')
Lexer.skipValue(trace_, rr)
Expand All @@ -966,11 +888,14 @@ object JsonCodec {
val (subtype, trace_, discriminator) = findDiscriminator(0, rr)

cases.find(_.id == subtype) match {
case Some(c) if discriminator >= 0 =>
schemaDecoder(c.schema, discriminator).unsafeDecode(trace_, rr).asInstanceOf[Z]
case Some(c) =>
schemaDecoder(c.schema, discriminator).unsafeDecode(trace_, in).asInstanceOf[Z]
case None =>
schemaDecoder(c.schema, discriminator)
.unsafeDecode(trace_, {
if (discriminator >= 0) rr
else in
})
.asInstanceOf[Z]
case _ =>
throw UnsafeJson(JsonError.Message("unrecognized subtype") :: trace_)
}
}
Expand Down

0 comments on commit 7629e75

Please sign in to comment.