Skip to content

Commit

Permalink
Reduce redundant allocations and CPU overhead when decoding and encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
plokhotnyuk committed Jan 10, 2025
1 parent afa0271 commit b95d200
Showing 1 changed file with 79 additions and 73 deletions.
152 changes: 79 additions & 73 deletions zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import scala.collection.immutable.ListMap
import scala.collection.mutable
import scala.util.control.NonFatal

import zio.json.JsonCodec._
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
import zio.json.ast.Json
import zio.json.internal.{ Lexer, RecordingReader, RetractReader, StringMatrix, WithRecordingReader, Write }
Expand All @@ -24,7 +23,7 @@ import zio.json.{
import zio.prelude.NonEmptyMap
import zio.schema.Schema.GenericRecord
import zio.schema._
import zio.schema.annotation.{ rejectExtraFields, _ }
import zio.schema.annotation.{ discriminatorName, rejectExtraFields, _ }
import zio.schema.codec.DecodeError.ReadError
import zio.schema.codec.JsonCodec.JsonDecoder.schemaDecoder
import zio.stream.{ ZChannel, ZPipeline }
Expand Down Expand Up @@ -420,7 +419,7 @@ object JsonCodec {
out.write(',')
if (indent.isDefined) pad(indent_, out)
}
string.encoder.unsafeEncode(key, indent_, out)
ZJsonEncoder.string.unsafeEncode(key, indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
directEncoder.unsafeEncode(value, indent_, out)
Expand Down Expand Up @@ -505,38 +504,44 @@ object JsonCodec {
)
.toMap
ZJsonEncoder.string.contramap(caseMap(_))
} else { (value: Z, indent: Option[Int], out: Write) =>
{
schema.nonTransientCases.find(_.isCase(value)) match {
case Some(case_) =>
val caseName = case_.caseName
val noDiscriminators = schema.noDiscriminator
val discriminatorTuple =
if (noDiscriminators) None
else schema.annotations.collectFirst { case d: discriminatorName => (d.tag, caseName) }
val doJsonObjectWrapping = discriminatorTuple.isEmpty && !noDiscriminators
var indent_ = indent
if (doJsonObjectWrapping) {
out.write('{')
indent_ = bump(indent)
pad(indent_, out)
string.encoder.unsafeEncode(caseName, indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
}
schemaEncoder(case_.schema.asInstanceOf[Schema[Any]], cfg, discriminatorTuple)
.unsafeEncode({
try case_.deconstruct(value)
catch {
case ex if NonFatal(ex) => throw new RuntimeException(s"Failed to encode enum type $schema", ex)
}
}, indent_, out)
if (doJsonObjectWrapping) {
pad(indent, out)
out.write('}')
} else {
new ZJsonEncoder[Z] {
private[this] val discriminatorName =
if (schema.noDiscriminator) None
else schema.annotations.collectFirst { case d: discriminatorName => d.tag }
private[this] val cases = schema.nonTransientCases.toArray
private[this] val decoders = cases.map { case_ =>
val discriminatorTuple =
if (discriminatorName eq None) None
else Some((discriminatorName.get, case_.caseName))
schemaEncoder(case_.schema.asInstanceOf[Schema[Any]], cfg, discriminatorTuple)
}
private[this] val doJsonObjectWrapping = discriminatorName.isEmpty && !schema.noDiscriminator

override def unsafeEncode(a: Z, indent: Option[Int], out: Write): Unit = {
var idx = 0
while (idx < cases.length) {
val case_ = cases(idx)
if (case_.isCase(a)) {
var indent_ = indent
if (doJsonObjectWrapping) {
out.write('{')
indent_ = bump(indent)
pad(indent_, out)
ZJsonEncoder.string.unsafeEncode(case_.caseName, indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
}
decoders(idx).unsafeEncode(case_.deconstruct(a), indent_, out)
if (doJsonObjectWrapping) {
pad(indent, out)
out.write('}')
}
return
}
case _ =>
out.write("{}") // for transient cases
idx += 1
}
out.write("{}") // for transient cases
}
}
}
Expand All @@ -550,12 +555,17 @@ object JsonCodec {
case Fallback.Right(b) => right.unsafeEncode(b, indent, out)
case Fallback.Both(a, b) =>
out.write('[')
if (indent.isDefined) pad(bump(indent), out)
val doPrettyPrint = indent ne None
var indent_ = indent
if (doPrettyPrint) {
indent_ = bump(indent)
pad(indent_, out)
}
left.unsafeEncode(a, indent, out)
out.write(',')
if (indent.isDefined) pad(bump(indent), out)
if (doPrettyPrint) pad(indent_, out)
right.unsafeEncode(b, indent, out)
if (indent.isDefined) pad(indent, out)
if (doPrettyPrint) pad(indent, out)
out.write(']')
}
}
Expand All @@ -578,7 +588,7 @@ object JsonCodec {
indent_ = bump(indent)
pad(indent_, out)
}
val strEnc = string.encoder
val strEnc = ZJsonEncoder.string
var first = true
if (discriminatorTuple ne None) {
val tuple = discriminatorTuple.get
Expand Down Expand Up @@ -759,11 +769,9 @@ object JsonCodec {
val valueDecoder = JsonDecoder.schemaDecoder(vs)
jsonFieldDecoder(ks) match {
case Some(jsonFieldDecoder) =>
ZJsonDecoder.keyValueChunk(jsonFieldDecoder, valueDecoder).map(a => Chunk.fromIterable(a).toMap)
ZJsonDecoder.keyValueChunk(jsonFieldDecoder, valueDecoder).map(_.toMap)
case None =>
ZJsonDecoder
.chunk(schemaDecoder(ks).zip(schemaDecoder(vs)))
.map(_.toList.toMap)
ZJsonDecoder.chunk(schemaDecoder(ks).zip(schemaDecoder(vs))).map(_.toMap)
}
}

Expand All @@ -778,18 +786,12 @@ object JsonCodec {
case _ => None
}

private def dynamicDecoder(schema: Schema.Dynamic): ZJsonDecoder[DynamicValue] = {
val directMapping = schema.annotations.exists {
case directDynamicMapping() => true
case _ => false
}

if (directMapping) {
private def dynamicDecoder(schema: Schema.Dynamic): ZJsonDecoder[DynamicValue] =
if (schema.annotations.exists(_.isInstanceOf[directDynamicMapping])) {
Json.decoder.map(jsonToDynamicValue)
} else {
schemaDecoder(DynamicValue.schema)
}
}

private def jsonToDynamicValue(json: Json): DynamicValue =
json match {
Expand Down Expand Up @@ -832,7 +834,7 @@ object JsonCodec {
}
} else {
new ZJsonDecoder[Z] {
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size << 1)

caseNameAliases.foreach {
case (name, case_) =>
Expand Down Expand Up @@ -888,7 +890,7 @@ object JsonCodec {
}
} else {
val cases =
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size << 1)
caseNameAliases.foreach {
case (name, case_) =>
cases.put(name, (JsonError.ObjectAccess(case_.caseName), schemaDecoder(case_.schema)))
Expand Down Expand Up @@ -936,7 +938,7 @@ object JsonCodec {
} else {
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
val cases =
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size << 1)
caseNameAliases.foreach {
case (name, case_) =>
cases.put(name, (JsonError.ObjectAccess(case_.caseName), schemaDecoder(case_.schema, discriminator)))
Expand Down Expand Up @@ -969,23 +971,26 @@ object JsonCodec {
private def recordDecoder(
schema: GenericRecord,
discriminator: Option[String]
): ZJsonDecoder[ListMap[String, Any]] = {
val spansWithDecoders =
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](schema.fields.size * 2)
schema.fields.foreach { field =>
val spanWithDecoder =
(JsonError.ObjectAccess(field.fieldName), schemaDecoder(field.schema).asInstanceOf[ZJsonDecoder[Any]])
field.nameAndAliases.foreach(spansWithDecoders.put(_, spanWithDecoder))
}
val skipExtraFields = !schema.annotations.exists(_.isInstanceOf[rejectExtraFields])
(trace: List[JsonError], in: RetractReader) => {
): ZJsonDecoder[ListMap[String, Any]] = new ZJsonDecoder[ListMap[String, Any]] {
private[this] val fields = schema.fields.toArray
private[this] val spansWithDecoders =
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](fields.length << 1) {
fields.foreach { field =>
val spanWithDecoder =
(JsonError.ObjectAccess(field.fieldName), schemaDecoder(field.schema).asInstanceOf[ZJsonDecoder[Any]])
field.nameAndAliases.foreach(put(_, spanWithDecoder))
}
}
private[this] val skipExtraFields = !schema.annotations.exists(_.isInstanceOf[rejectExtraFields])

override def unsafeDecode(trace: List[JsonError], in: RetractReader): ListMap[String, Any] = {
val lexer = Lexer
var continue = true
if (discriminator eq None) {
lexer.char(trace, in, '{')
continue = lexer.firstField(trace, in)
}
val map = new util.HashMap[String, Any]
val map = new util.HashMap[String, Any](fields.length << 1)
while (continue) {
val fieldNameOrAlias = lexer.string(trace, in).toString
val spanWithDecoder = spansWithDecoders.get(fieldNameOrAlias)
Expand All @@ -1003,7 +1008,10 @@ object JsonCodec {
} else error("extra field", trace)
continue = lexer.nextField(trace, in)
}
schema.fields.foreach { field =>
var idx = 0
while (idx < fields.length) {
val field = fields(idx)
idx += 1
val fieldName = field.fieldName // reuse strings with calculated hashCode
if (map.get(fieldName) == null) {
map.put( // mitigation of a linking error for `map.computeIfAbsent` in Scala.js
Expand Down Expand Up @@ -1039,12 +1047,11 @@ object JsonCodec {
private[this] val leftDecoder = schemaDecoder(schema.left)
private[this] val rightDecoder = schemaDecoder(schema.right)

case class BadEnd() extends Throwable

def unsafeDecode(trace: List[JsonError], in: RetractReader): Fallback[A, B] = {
var left: Option[A] = None
var right: Option[B] = None

case class BadEnd() extends Throwable

try {
// If this doesn't throw exception, it is an array, so it encodes a `Fallback.Both`
val lexer = Lexer
Expand Down Expand Up @@ -1086,17 +1093,16 @@ object JsonCodec {
} catch {
// It's not an array, so it is of type A or B
case BadEnd() => ()
case _: UnsafeJson => {
case _: UnsafeJson =>
in.retract()
val in2 = new zio.json.internal.WithRecordingReader(in, 64)
try {
left = Some(schemaDecoder(schema.left).unsafeDecode(trace, in2))
left = Some(leftDecoder.unsafeDecode(trace, in2))
} catch {
case UnsafeJson(_) =>
in2.rewind()
right = Some(schemaDecoder(schema.right).unsafeDecode(trace, in2))
right = Some(rightDecoder.unsafeDecode(trace, in2))
}
}
}

(left, right) match {
Expand Down Expand Up @@ -1131,7 +1137,7 @@ object JsonCodec {
indent_ = bump(indent)
pad(indent_, out)
}
val strEnc = string.encoder
val strEnc = ZJsonEncoder.string
var first = true
if (discriminatorTuple ne None) {
val tuple = discriminatorTuple.get
Expand Down

0 comments on commit b95d200

Please sign in to comment.