From 6642199bcd7488d3963d3b9995af0739594a438f Mon Sep 17 00:00:00 2001 From: johannes karoff Date: Wed, 9 Aug 2023 16:00:39 +0200 Subject: [PATCH] fix tuple type in scala3 client macro (#214) --- build.sbt | 1 + project/Deps.scala | 2 ++ sloth/src/main/scala-3/internal/Macros.scala | 32 ++++++++++++++------ sloth/src/test/scala-3/Sloth3Spec.scala | 32 ++++++++++++++++++++ 4 files changed, 57 insertions(+), 10 deletions(-) create mode 100644 sloth/src/test/scala-3/Sloth3Spec.scala diff --git a/build.sbt b/build.sbt index 7b57a952..f8291c74 100644 --- a/build.sbt +++ b/build.sbt @@ -67,6 +67,7 @@ lazy val sloth = crossProject(JSPlatform, JVMPlatform) Deps.cats.value :: Deps.chameleon.value :: + Deps.zioJson.value % Test :: Deps.circe.core.value % Test :: Deps.circe.generic.value % Test :: Deps.circe.parser.value % Test :: diff --git a/project/Deps.scala b/project/Deps.scala index 123b07d8..f307b5e6 100644 --- a/project/Deps.scala +++ b/project/Deps.scala @@ -11,6 +11,8 @@ object Deps { val cats = dep("org.typelevel" %%% "cats-core" % "2.9.0") val chameleon = dep("com.github.cornerman" %%% "chameleon" % "0.3.7") + val zioJson = dep("dev.zio" %%% "zio-json" % "0.6.0") + val circe = new { private val version = "0.14.1" val core = dep("io.circe" %%% "circe-core" % version) diff --git a/sloth/src/main/scala-3/internal/Macros.scala b/sloth/src/main/scala-3/internal/Macros.scala index 57506564..08c60b18 100644 --- a/sloth/src/main/scala-3/internal/Macros.scala +++ b/sloth/src/main/scala-3/internal/Macros.scala @@ -60,7 +60,7 @@ private def getInnerTypeOutOfReturnType[Trait: Type, Result[_]: Type](using Quot tpe.typeArgs.indexWhere(_ =:= firstParamType) case _ => -1 } - + val methodReturnType = getMethodType[Trait](method) parameterTypeIndex match { case -1 => methodReturnType.typeArgs.last @@ -68,6 +68,15 @@ private def getInnerTypeOutOfReturnType[Trait: Type, Result[_]: Type](using Quot } } +def createTypeTreeTuple(using Quotes)(tupleTypesList: List[quotes.reflect.TypeRepr]): quotes.reflect.TypeRepr = { + import quotes.reflect.* + + tupleTypesList match { + case Nil => TypeRepr.of[EmptyTuple] + case head :: tail => TypeRepr.of[*:].appliedTo(List(head, createTypeTreeTuple(tail))) + } +} + private def checkMethodErrors[Trait: Type, Result[_]: Type](using q: Quotes)(methods: Seq[quotes.reflect.Symbol]): Unit = { import quotes.reflect.* @@ -167,11 +176,18 @@ object TraitMacro { case arg :: Nil => arg.asExpr case allArgs => Expr.ofTupleFromSeq(allArgs.map(_.asExpr)) } - + + val tupleTypesList = origMethod.paramSymss.flatten.map(_.tree.asInstanceOf[ValDef].tpt.tpe) + val tupleType = tupleTypesList match { + case Nil => TypeRepr.of[Unit] + case head :: Nil => head + case tupleTypesList => createTypeTreeTuple(tupleTypesList) + } + val returnType = getInnerTypeOutOfReturnType[Trait, Result](method) val clientImplType = TypeRepr.of[ClientImpl[PickleType, Result]].typeSymbol - val tupleTypeTree = TypeTree.of(using tupleExpr.asTerm.tpe.asType) + val tupleTypeTree = TypeTree.of(using tupleType.asType) val returnTypeTree = TypeTree.of(using returnType.asType) Apply( @@ -232,14 +248,10 @@ object RouterMacro { val path = traitPathPart :: methodPathPart :: Nil val pathExpr = Expr(path) - val tupleTypesList = method.paramSymss.flatten.map(_.tree.asInstanceOf[ValDef].tpt.tpe) - - def createTypeTreeTuple(tupleTypesList: List[TypeRepr]): TypeRepr = tupleTypesList match { - case Nil => TypeRepr.of[EmptyTuple] - case head :: tail => TypeRepr.of[*:].appliedTo(List(head, createTypeTreeTuple(tail))) - } - + val returnType = getInnerTypeOutOfReturnType[Trait, Result](method) + + val tupleTypesList = method.paramSymss.flatten.map(_.tree.asInstanceOf[ValDef].tpt.tpe) val tupleType = tupleTypesList match { case Nil => TypeRepr.of[Unit] case head :: Nil => head diff --git a/sloth/src/test/scala-3/Sloth3Spec.scala b/sloth/src/test/scala-3/Sloth3Spec.scala new file mode 100644 index 00000000..6b749bef --- /dev/null +++ b/sloth/src/test/scala-3/Sloth3Spec.scala @@ -0,0 +1,32 @@ +package test3 + +import chameleon.* +import chameleon.ext.zioJson.given +import sloth.* +import zio.json.* + +import org.scalatest.freespec.AsyncFreeSpec +import org.scalatest.matchers.must.Matchers + +import scala.concurrent.Future + +trait Api { + def single(page: Int): Future[String] + def list(page: Int, limit: Int): Future[String] +} + +class Sloth3Spec extends AsyncFreeSpec with Matchers { + + "compile client" in { + val client = Client[String, Future](null) + client.wire[Api] + + succeed + } + + "compile router" in { + val router = Router[String, Future].route[Api](null) + + succeed + } +}