Skip to content

Commit

Permalink
Fix #758 by proper decoding of backticked names of fields and classes (
Browse files Browse the repository at this point in the history
  • Loading branch information
plokhotnyuk authored Jan 6, 2025
1 parent 2902d21 commit 00110cc
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package zio.schema

import scala.reflect.NameTransformer
import scala.reflect.macros.whitebox

import zio.Chunk
Expand Down Expand Up @@ -40,6 +41,11 @@ object DeriveSchema {

def isMap(tpe: Type): Boolean = tpe.typeSymbol.fullName == "scala.collection.immutable.Map"

def decodeName(s: Symbol): String = NameTransformer.decode(s.name.toString)

def decodeFieldName(s: Symbol): String =
NameTransformer.decode(s.name.toString.trim) // why is there a space at the end of field name?!

def collectTypeAnnotations(tpe: Type): List[Tree] =
tpe.typeSymbol.annotations.collect {
case annotation if !(annotation.tree.tpe <:< JavaAnnotationTpe) =>
Expand Down Expand Up @@ -235,7 +241,7 @@ object DeriveSchema {
val genericAnnotations: List[Tree] =
if (tpe.typeArgs.isEmpty) Nil
else {
val typeMembers = tpe.typeSymbol.asClass.typeParams.map(_.name.toString)
val typeMembers = tpe.typeSymbol.asClass.typeParams.map(decodeName)
val typeArgs = tpe.typeArgs
.map(_.typeSymbol.fullName)
.map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]")
Expand Down Expand Up @@ -334,7 +340,7 @@ object DeriveSchema {
concreteType(tpe, termSymbol.typeSignature),
currentFrame +: stack
)
val fieldLabel = termSymbol.name.toString.trim
val fieldLabel = decodeFieldName(termSymbol)
val getFunc =
q" (z: _root_.scala.collection.immutable.ListMap[String, _]) => z.apply($fieldLabel).asInstanceOf[${termSymbol.typeSignature}]"

Expand All @@ -353,12 +359,13 @@ object DeriveSchema {
val fromMap = {
val casts = fieldTypes.zip(fieldAnnotations).map {
case (termSymbol, annotations) =>
val newName = getFieldName(annotations).getOrElse(termSymbol.name.toString.trim)
val fieldLabel = decodeFieldName(termSymbol)
val newName = getFieldName(annotations).getOrElse(fieldLabel)
q"""
try m.apply(${newName}).asInstanceOf[${termSymbol.typeSignature}]
catch {
case _: ClassCastException => throw new RuntimeException("Field " + ${termSymbol.name.toString.trim} + " has invalid type")
case _: Throwable => throw new RuntimeException("Field " + ${termSymbol.name.toString.trim} + " is missing")
case _: ClassCastException => throw new RuntimeException("Field " + $fieldLabel + " has invalid type")
case _: Throwable => throw new RuntimeException("Field " + $fieldLabel + " is missing")
}
"""
}
Expand Down Expand Up @@ -414,8 +421,8 @@ object DeriveSchema {
currentFrame +: stack
)
val fieldArg = if (fieldTypes.size > 1) TermName(s"field0${idx + 1}") else TermName("field0")
val fieldLabel = termSymbol.name.toString.trim
val getArg = TermName(fieldLabel)
val fieldLabel = decodeFieldName(termSymbol)
val getArg = TermName(termSymbol.name.toString.trim)

val getFunc = q" (z: $tpe) => z.$getArg"
val setFunc = q" (z: $tpe, v: $fieldType) => z.copy[..${tpe.typeArgs}]($getArg = v)"
Expand Down Expand Up @@ -459,9 +466,9 @@ object DeriveSchema {

val typeArgsWithFields = fieldTypes.zip(fieldAnnotations).map {
case (termSymbol, annotations) if annotations.nonEmpty =>
tq"${getFieldName(annotations).getOrElse(termSymbol.name.toString.trim)}.type"
tq"${getFieldName(annotations).getOrElse(decodeFieldName(termSymbol))}.type"
case (termSymbol, _) =>
tq"${termSymbol.name.toString.trim}.type"
tq"${decodeFieldName(termSymbol)}.type"
} ++ typeArgs

fieldTypes.size match {
Expand Down Expand Up @@ -524,12 +531,12 @@ object DeriveSchema {
val typeId = q"_root_.zio.schema.TypeId.parse(${tpe.toString()})"

val appliedTypeArgs: Map[String, Type] =
tpe.typeConstructor.typeParams.map(_.name.toString).zip(tpe.typeArgs).toMap
tpe.typeConstructor.typeParams.map(decodeName).zip(tpe.typeArgs).toMap

def appliedSubtype(subtype: Type): Type =
if (subtype.typeArgs.size == 0) subtype
else {
val appliedTypes = subtype.typeConstructor.typeParams.map(_.name.toString).map { typeParam =>
val appliedTypes = subtype.typeConstructor.typeParams.map(decodeName).map { typeParam =>
appliedTypeArgs.get(typeParam) match {
case None =>
c.abort(
Expand Down Expand Up @@ -583,7 +590,7 @@ object DeriveSchema {
val genericAnnotations: List[Tree] =
if (tpe.typeArgs.isEmpty) Nil
else {
val typeMembers = tpe.typeSymbol.asClass.typeParams.map(_.name.toString)
val typeMembers = tpe.typeSymbol.asClass.typeParams.map(decodeName)
val typeArgs = tpe.typeArgs
.map(_.typeSymbol.fullName)
.map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]")
Expand Down Expand Up @@ -641,7 +648,7 @@ object DeriveSchema {
val genericAnnotations: List[Tree] =
if (subtype.typeArgs.isEmpty) Nil
else {
val typeMembers = subtype.typeSymbol.asClass.typeParams.map(_.name.toString)
val typeMembers = subtype.typeSymbol.asClass.typeParams.map(decodeName)
val typeArgs = subtype.typeArgs
.map(_.typeSymbol.fullName)
.map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]")
Expand All @@ -668,7 +675,7 @@ object DeriveSchema {
EmptyTree
}.filter(_ != EmptyTree) ++ genericAnnotations

val caseLabel = subtype.typeSymbol.name.toString.trim
val caseLabel = decodeName(subtype.typeSymbol)
val caseSchema = directInferSchema(tpe, concreteType(tpe, subtype), currentFrame +: stack)
val deconstructFn = q"(z: $tpe) => z.asInstanceOf[$subtype]"
val constructFn = q"(z: $subtype) => z.asInstanceOf[$tpe]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
f21: Int = 21,
f22: Int = 22,
f23: Int = 23,
f24: Int = 24
`f-24`: Int = 24
)

object Arity24 {
Expand Down Expand Up @@ -140,7 +140,7 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
arity19: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User),
arity20: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User),
arity21: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User),
arity22: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User)
`arity-22`: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User)
)
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down Expand Up @@ -258,11 +258,11 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
case class SimpleClass1() extends SimpleEnum1

sealed trait SimpleEnum2
case class SimpleClass2() extends SimpleEnum2
case class `Simple-Class-2`() extends SimpleEnum2

sealed abstract class AbstractBaseClass(val x: Int)
final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x)
final case class ConcreteClass2(override val x: Int, s: String) extends AbstractBaseClass(x)
final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x)
final case class `Concrete-Class-2`(override val x: Int, s: String) extends AbstractBaseClass(x)

sealed abstract class AbstractBaseClass2(val x: Int)
sealed abstract class MiddleClass(override val x: Int, val y: Int) extends AbstractBaseClass2(x)
Expand All @@ -288,13 +288,17 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
assert(Schema[User].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
},
test("correctly derives case class with arity > 22") {
assert(Schema[Arity24].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
assert(Schema[Arity24].toString)(
not(containsString("null")) && not(equalTo("$Lazy$") && containsString("f-24"))
)
},
test("correctly derives recursive data structure") {
assert(Schema[Recursive].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
},
test("correctly derives tuple arities from 2 to 22") {
assert(Schema[TupleArities].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
assert(Schema[TupleArities].toString)(
not(containsString("null")) && not(equalTo("$Lazy$") && containsString("arity-22"))
)
},
test("correctly derive mutually recursive data structure") {
val c = Cyclic(1, CyclicChild1(2, CyclicChild2("3", None)))
Expand Down Expand Up @@ -526,26 +530,26 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
(a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass1]
),
Schema.Case(
"ConcreteClass2",
"Concrete-Class-2",
Schema.CaseClass2(
TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass2"),
field01 = Schema.Field[ConcreteClass2, Int](
TypeId.parse("zio.schema.DeriveSchemaSpec.Concrete-Class-2"),
field01 = Schema.Field[`Concrete-Class-2`, Int](
"x",
Schema.Primitive(StandardType.IntType),
get0 = _.x,
set0 = (a, b: Int) => a.copy(x = b)
),
field02 = Schema.Field[ConcreteClass2, String](
field02 = Schema.Field[`Concrete-Class-2`, String](
"s",
Schema.Primitive(StandardType.StringType),
get0 = _.s,
set0 = (a, b: String) => a.copy(s = b)
),
ConcreteClass2.apply
`Concrete-Class-2`.apply
),
(a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass2],
(a: ConcreteClass2) => a.asInstanceOf[AbstractBaseClass],
(a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass2]
(a: AbstractBaseClass) => a.asInstanceOf[`Concrete-Class-2`],
(a: `Concrete-Class-2`) => a.asInstanceOf[AbstractBaseClass],
(a: AbstractBaseClass) => a.isInstanceOf[`Concrete-Class-2`]
),
Chunk.empty
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ object JsonCodecSpec extends ZIOSpecDefault {
)
}
),
suite("case class")(
test("backticked field name") {
assertEncodesJson(
Schema[BacktickedFieldName],
BacktickedFieldName("test"),
"""{"x-api-key":"test"}"""
)
}
),
suite("optional field annotation")(
test("list empty") {
assertEncodesJson(
Expand Down Expand Up @@ -868,6 +877,13 @@ object JsonCodecSpec extends ZIOSpecDefault {
charSequenceToByteChunk("""{"query":"test","pageNumber":0,"resultPerPage":10}""")
)
},
test("backticked field name") {
assertDecodes(
BacktickedFieldName.schema,
BacktickedFieldName("test"),
charSequenceToByteChunk("""{"x-api-key":"test"}""")
)
},
test("field name with alias - id") {
assertDecodes(
Order.schema,
Expand Down Expand Up @@ -1599,7 +1615,7 @@ object JsonCodecSpec extends ZIOSpecDefault {
Enumeration2(StringValue2("foo"))
) &> assertEncodesThenDecodes(
Schema[Enumeration2],
Enumeration2(StringValue2Multi("foo", "bar"))
Enumeration2(`StringValue2-Backticked`("foo", "bar"))
) &> assertEncodesThenDecodes(Schema[Enumeration2], Enumeration2(IntValue2(-1))) &> assertEncodesThenDecodes(
Schema[Enumeration2],
Enumeration2(BooleanValue2(false))
Expand All @@ -1611,7 +1627,7 @@ object JsonCodecSpec extends ZIOSpecDefault {
Enumeration3(StringValue3("foo"))
) &> assertEncodesThenDecodes(
Schema[Enumeration3],
Enumeration3(StringValue3Multi("foo", "bar"))
Enumeration3(`StringValue3-Backticked`("foo", "bar"))
)
},
test("of case classes with discriminator") {
Expand Down Expand Up @@ -2147,10 +2163,10 @@ object JsonCodecSpec extends ZIOSpecDefault {

@discriminatorName("_type")
sealed trait OneOf2
case class StringValue2(value: String) extends OneOf2
case class IntValue2(value: Int) extends OneOf2
case class BooleanValue2(value: Boolean) extends OneOf2
case class StringValue2Multi(value1: String, value2: String) extends OneOf2
case class StringValue2(value: String) extends OneOf2
case class IntValue2(value: Int) extends OneOf2
case class BooleanValue2(value: Boolean) extends OneOf2
case class `StringValue2-Backticked`(value1: String, value2: String) extends OneOf2

case class Enumeration2(oneOf: OneOf2)

Expand All @@ -2160,11 +2176,11 @@ object JsonCodecSpec extends ZIOSpecDefault {

@noDiscriminator
sealed trait OneOf3
case class StringValue3(value: String) extends OneOf3
case class IntValue3(value: Int) extends OneOf3
case class BooleanValue3(value: Boolean) extends OneOf3
case class StringValue3Multi(value1: String, value2: String) extends OneOf3
case class Nested(oneOf: OneOf3) extends OneOf3
case class StringValue3(value: String) extends OneOf3
case class IntValue3(value: Int) extends OneOf3
case class BooleanValue3(value: Boolean) extends OneOf3
case class `StringValue3-Backticked`(value1: String, value2: String) extends OneOf3
case class Nested(oneOf: OneOf3) extends OneOf3

case class Enumeration3(oneOf: OneOf3)

Expand Down Expand Up @@ -2449,4 +2465,10 @@ object JsonCodecSpec extends ZIOSpecDefault {
object Recursive {
implicit val schema: Schema[Recursive] = DeriveSchema.gen
}

case class BacktickedFieldName(`x-api-key`: String)

object BacktickedFieldName {
implicit val schema: Schema[BacktickedFieldName] = DeriveSchema.gen
}
}

0 comments on commit 00110cc

Please sign in to comment.