From 00110cca6300d506ee50b277c77cfa99d308a000 Mon Sep 17 00:00:00 2001 From: Andriy Plokhotnyuk Date: Mon, 6 Jan 2025 15:07:08 +0100 Subject: [PATCH] Fix #758 by proper decoding of backticked names of fields and classes (#769) --- .../scala-2/zio/schema/DeriveSchema.scala | 35 +++++++++------ .../scala/zio/schema/DeriveSchemaSpec.scala | 34 +++++++------- .../zio/schema/codec/JsonCodecSpec.scala | 44 ++++++++++++++----- 3 files changed, 73 insertions(+), 40 deletions(-) diff --git a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala index 6a9262e65..3dfb4f7f5 100644 --- a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala +++ b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala @@ -1,5 +1,6 @@ package zio.schema +import scala.reflect.NameTransformer import scala.reflect.macros.whitebox import zio.Chunk @@ -40,6 +41,11 @@ object DeriveSchema { def isMap(tpe: Type): Boolean = tpe.typeSymbol.fullName == "scala.collection.immutable.Map" + def decodeName(s: Symbol): String = NameTransformer.decode(s.name.toString) + + def decodeFieldName(s: Symbol): String = + NameTransformer.decode(s.name.toString.trim) // why is there a space at the end of field name?! + def collectTypeAnnotations(tpe: Type): List[Tree] = tpe.typeSymbol.annotations.collect { case annotation if !(annotation.tree.tpe <:< JavaAnnotationTpe) => @@ -235,7 +241,7 @@ object DeriveSchema { val genericAnnotations: List[Tree] = if (tpe.typeArgs.isEmpty) Nil else { - val typeMembers = tpe.typeSymbol.asClass.typeParams.map(_.name.toString) + val typeMembers = tpe.typeSymbol.asClass.typeParams.map(decodeName) val typeArgs = tpe.typeArgs .map(_.typeSymbol.fullName) .map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]") @@ -334,7 +340,7 @@ object DeriveSchema { concreteType(tpe, termSymbol.typeSignature), currentFrame +: stack ) - val fieldLabel = termSymbol.name.toString.trim + val fieldLabel = decodeFieldName(termSymbol) val getFunc = q" (z: _root_.scala.collection.immutable.ListMap[String, _]) => z.apply($fieldLabel).asInstanceOf[${termSymbol.typeSignature}]" @@ -353,12 +359,13 @@ object DeriveSchema { val fromMap = { val casts = fieldTypes.zip(fieldAnnotations).map { case (termSymbol, annotations) => - val newName = getFieldName(annotations).getOrElse(termSymbol.name.toString.trim) + val fieldLabel = decodeFieldName(termSymbol) + val newName = getFieldName(annotations).getOrElse(fieldLabel) q""" try m.apply(${newName}).asInstanceOf[${termSymbol.typeSignature}] catch { - case _: ClassCastException => throw new RuntimeException("Field " + ${termSymbol.name.toString.trim} + " has invalid type") - case _: Throwable => throw new RuntimeException("Field " + ${termSymbol.name.toString.trim} + " is missing") + case _: ClassCastException => throw new RuntimeException("Field " + $fieldLabel + " has invalid type") + case _: Throwable => throw new RuntimeException("Field " + $fieldLabel + " is missing") } """ } @@ -414,8 +421,8 @@ object DeriveSchema { currentFrame +: stack ) val fieldArg = if (fieldTypes.size > 1) TermName(s"field0${idx + 1}") else TermName("field0") - val fieldLabel = termSymbol.name.toString.trim - val getArg = TermName(fieldLabel) + val fieldLabel = decodeFieldName(termSymbol) + val getArg = TermName(termSymbol.name.toString.trim) val getFunc = q" (z: $tpe) => z.$getArg" val setFunc = q" (z: $tpe, v: $fieldType) => z.copy[..${tpe.typeArgs}]($getArg = v)" @@ -459,9 +466,9 @@ object DeriveSchema { val typeArgsWithFields = fieldTypes.zip(fieldAnnotations).map { case (termSymbol, annotations) if annotations.nonEmpty => - tq"${getFieldName(annotations).getOrElse(termSymbol.name.toString.trim)}.type" + tq"${getFieldName(annotations).getOrElse(decodeFieldName(termSymbol))}.type" case (termSymbol, _) => - tq"${termSymbol.name.toString.trim}.type" + tq"${decodeFieldName(termSymbol)}.type" } ++ typeArgs fieldTypes.size match { @@ -524,12 +531,12 @@ object DeriveSchema { val typeId = q"_root_.zio.schema.TypeId.parse(${tpe.toString()})" val appliedTypeArgs: Map[String, Type] = - tpe.typeConstructor.typeParams.map(_.name.toString).zip(tpe.typeArgs).toMap + tpe.typeConstructor.typeParams.map(decodeName).zip(tpe.typeArgs).toMap def appliedSubtype(subtype: Type): Type = if (subtype.typeArgs.size == 0) subtype else { - val appliedTypes = subtype.typeConstructor.typeParams.map(_.name.toString).map { typeParam => + val appliedTypes = subtype.typeConstructor.typeParams.map(decodeName).map { typeParam => appliedTypeArgs.get(typeParam) match { case None => c.abort( @@ -583,7 +590,7 @@ object DeriveSchema { val genericAnnotations: List[Tree] = if (tpe.typeArgs.isEmpty) Nil else { - val typeMembers = tpe.typeSymbol.asClass.typeParams.map(_.name.toString) + val typeMembers = tpe.typeSymbol.asClass.typeParams.map(decodeName) val typeArgs = tpe.typeArgs .map(_.typeSymbol.fullName) .map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]") @@ -641,7 +648,7 @@ object DeriveSchema { val genericAnnotations: List[Tree] = if (subtype.typeArgs.isEmpty) Nil else { - val typeMembers = subtype.typeSymbol.asClass.typeParams.map(_.name.toString) + val typeMembers = subtype.typeSymbol.asClass.typeParams.map(decodeName) val typeArgs = subtype.typeArgs .map(_.typeSymbol.fullName) .map(t => q"_root_.zio.schema.TypeId.parse(${t}).asInstanceOf[_root_.zio.schema.TypeId.Nominal]") @@ -668,7 +675,7 @@ object DeriveSchema { EmptyTree }.filter(_ != EmptyTree) ++ genericAnnotations - val caseLabel = subtype.typeSymbol.name.toString.trim + val caseLabel = decodeName(subtype.typeSymbol) val caseSchema = directInferSchema(tpe, concreteType(tpe, subtype), currentFrame +: stack) val deconstructFn = q"(z: $tpe) => z.asInstanceOf[$subtype]" val constructFn = q"(z: $subtype) => z.asInstanceOf[$tpe]" diff --git a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala index 04cae53f6..9ce237b50 100644 --- a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala +++ b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala @@ -111,7 +111,7 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS f21: Int = 21, f22: Int = 22, f23: Int = 23, - f24: Int = 24 + `f-24`: Int = 24 ) object Arity24 { @@ -140,7 +140,7 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS arity19: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User), arity20: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User), arity21: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User), - arity22: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User) + `arity-22`: (User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User, User) ) //scalafmt: { maxColumn = 120, optIn.configStyleArguments = true } @@ -258,11 +258,11 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS case class SimpleClass1() extends SimpleEnum1 sealed trait SimpleEnum2 - case class SimpleClass2() extends SimpleEnum2 + case class `Simple-Class-2`() extends SimpleEnum2 sealed abstract class AbstractBaseClass(val x: Int) - final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x) - final case class ConcreteClass2(override val x: Int, s: String) extends AbstractBaseClass(x) + final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x) + final case class `Concrete-Class-2`(override val x: Int, s: String) extends AbstractBaseClass(x) sealed abstract class AbstractBaseClass2(val x: Int) sealed abstract class MiddleClass(override val x: Int, val y: Int) extends AbstractBaseClass2(x) @@ -288,13 +288,17 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS assert(Schema[User].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) }, test("correctly derives case class with arity > 22") { - assert(Schema[Arity24].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) + assert(Schema[Arity24].toString)( + not(containsString("null")) && not(equalTo("$Lazy$") && containsString("f-24")) + ) }, test("correctly derives recursive data structure") { assert(Schema[Recursive].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) }, test("correctly derives tuple arities from 2 to 22") { - assert(Schema[TupleArities].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) + assert(Schema[TupleArities].toString)( + not(containsString("null")) && not(equalTo("$Lazy$") && containsString("arity-22")) + ) }, test("correctly derive mutually recursive data structure") { val c = Cyclic(1, CyclicChild1(2, CyclicChild2("3", None))) @@ -526,26 +530,26 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS (a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass1] ), Schema.Case( - "ConcreteClass2", + "Concrete-Class-2", Schema.CaseClass2( - TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass2"), - field01 = Schema.Field[ConcreteClass2, Int]( + TypeId.parse("zio.schema.DeriveSchemaSpec.Concrete-Class-2"), + field01 = Schema.Field[`Concrete-Class-2`, Int]( "x", Schema.Primitive(StandardType.IntType), get0 = _.x, set0 = (a, b: Int) => a.copy(x = b) ), - field02 = Schema.Field[ConcreteClass2, String]( + field02 = Schema.Field[`Concrete-Class-2`, String]( "s", Schema.Primitive(StandardType.StringType), get0 = _.s, set0 = (a, b: String) => a.copy(s = b) ), - ConcreteClass2.apply + `Concrete-Class-2`.apply ), - (a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass2], - (a: ConcreteClass2) => a.asInstanceOf[AbstractBaseClass], - (a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass2] + (a: AbstractBaseClass) => a.asInstanceOf[`Concrete-Class-2`], + (a: `Concrete-Class-2`) => a.asInstanceOf[AbstractBaseClass], + (a: AbstractBaseClass) => a.isInstanceOf[`Concrete-Class-2`] ), Chunk.empty ) diff --git a/zio-schema-json/shared/src/test/scala/zio/schema/codec/JsonCodecSpec.scala b/zio-schema-json/shared/src/test/scala/zio/schema/codec/JsonCodecSpec.scala index c20db4ba3..c39a20796 100644 --- a/zio-schema-json/shared/src/test/scala/zio/schema/codec/JsonCodecSpec.scala +++ b/zio-schema-json/shared/src/test/scala/zio/schema/codec/JsonCodecSpec.scala @@ -92,6 +92,15 @@ object JsonCodecSpec extends ZIOSpecDefault { ) } ), + suite("case class")( + test("backticked field name") { + assertEncodesJson( + Schema[BacktickedFieldName], + BacktickedFieldName("test"), + """{"x-api-key":"test"}""" + ) + } + ), suite("optional field annotation")( test("list empty") { assertEncodesJson( @@ -868,6 +877,13 @@ object JsonCodecSpec extends ZIOSpecDefault { charSequenceToByteChunk("""{"query":"test","pageNumber":0,"resultPerPage":10}""") ) }, + test("backticked field name") { + assertDecodes( + BacktickedFieldName.schema, + BacktickedFieldName("test"), + charSequenceToByteChunk("""{"x-api-key":"test"}""") + ) + }, test("field name with alias - id") { assertDecodes( Order.schema, @@ -1599,7 +1615,7 @@ object JsonCodecSpec extends ZIOSpecDefault { Enumeration2(StringValue2("foo")) ) &> assertEncodesThenDecodes( Schema[Enumeration2], - Enumeration2(StringValue2Multi("foo", "bar")) + Enumeration2(`StringValue2-Backticked`("foo", "bar")) ) &> assertEncodesThenDecodes(Schema[Enumeration2], Enumeration2(IntValue2(-1))) &> assertEncodesThenDecodes( Schema[Enumeration2], Enumeration2(BooleanValue2(false)) @@ -1611,7 +1627,7 @@ object JsonCodecSpec extends ZIOSpecDefault { Enumeration3(StringValue3("foo")) ) &> assertEncodesThenDecodes( Schema[Enumeration3], - Enumeration3(StringValue3Multi("foo", "bar")) + Enumeration3(`StringValue3-Backticked`("foo", "bar")) ) }, test("of case classes with discriminator") { @@ -2147,10 +2163,10 @@ object JsonCodecSpec extends ZIOSpecDefault { @discriminatorName("_type") sealed trait OneOf2 - case class StringValue2(value: String) extends OneOf2 - case class IntValue2(value: Int) extends OneOf2 - case class BooleanValue2(value: Boolean) extends OneOf2 - case class StringValue2Multi(value1: String, value2: String) extends OneOf2 + case class StringValue2(value: String) extends OneOf2 + case class IntValue2(value: Int) extends OneOf2 + case class BooleanValue2(value: Boolean) extends OneOf2 + case class `StringValue2-Backticked`(value1: String, value2: String) extends OneOf2 case class Enumeration2(oneOf: OneOf2) @@ -2160,11 +2176,11 @@ object JsonCodecSpec extends ZIOSpecDefault { @noDiscriminator sealed trait OneOf3 - case class StringValue3(value: String) extends OneOf3 - case class IntValue3(value: Int) extends OneOf3 - case class BooleanValue3(value: Boolean) extends OneOf3 - case class StringValue3Multi(value1: String, value2: String) extends OneOf3 - case class Nested(oneOf: OneOf3) extends OneOf3 + case class StringValue3(value: String) extends OneOf3 + case class IntValue3(value: Int) extends OneOf3 + case class BooleanValue3(value: Boolean) extends OneOf3 + case class `StringValue3-Backticked`(value1: String, value2: String) extends OneOf3 + case class Nested(oneOf: OneOf3) extends OneOf3 case class Enumeration3(oneOf: OneOf3) @@ -2449,4 +2465,10 @@ object JsonCodecSpec extends ZIOSpecDefault { object Recursive { implicit val schema: Schema[Recursive] = DeriveSchema.gen } + + case class BacktickedFieldName(`x-api-key`: String) + + object BacktickedFieldName { + implicit val schema: Schema[BacktickedFieldName] = DeriveSchema.gen + } }