Skip to content

Commit

Permalink
generate String comparison instead of HashMap for Router and use Meth…
Browse files Browse the repository at this point in the history
…od data structure (#246)

This makes the Router#route macro generate more efficient code.

We now generate a direct string comparison for mapping the request to the right implementation, instead of a HashMap. An example macro expansion looks like this:
```scala
val value = ApiImplFunResponse;
val implRouter = sloth.Router.apply[test.PickleType, test.TypeHelper.ApiResultFun](TypeHelper.functor(SlothSpec.this.executionContext));
val impl = new _root_.sloth.internal.RouterImpl[test.PickleType, test.TypeHelper.ApiResultFun](implRouter);
implRouter.orElse(((endpoint) =>
    if (endpoint.traitName.$eq$eq("Api"))
        endpoint.methodName match {
            case "fun" => Some(((payload: test.PickleType) => impl.execute[scala.Tuple2[Int, String], Int](endpoint, payload)(((args) => value.fun(args._1, args._2)))))
            case "multi" => Some(((payload: test.PickleType) => impl.execute[scala.Tuple2[Int, Int], Int](endpoint, payload)(((args) => value.multi(args._1)(args._2)))))
            case _ => None
        }
    else None
))
```

We are also changing the reoccuring type `path: List[String]` to `Method`, which consists of the traitName and the methodName. We have added deprecated aliases for easier migration, it allows to work `List[String]` (now `Method`), `PathName` (now `Name`), `PathNotFound` (now `MethodNotFound`). These will be removed in the upcoming 1.0 release.
  • Loading branch information
cornerman authored Jun 25, 2024
1 parent 4a03077 commit 005e26e
Show file tree
Hide file tree
Showing 15 changed files with 156 additions and 113 deletions.
35 changes: 20 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ val router = Router[ByteBuffer, Future].route[Api](ApiImpl)

Use it to route requests to your Api implementation:
```scala
val result = router(Request[ByteBuffer]("Api" :: "fun" :: Nil, bytes))
val result = router(Request[ByteBuffer](Method(traitName = "Api", methodName = "fun"), bytes))
// Now result contains the serialized Int result returned by the method ApiImpl.fun
```

Expand Down Expand Up @@ -187,7 +187,7 @@ For logging, you can define a `LogHandler`, which can log each request including
Define it when creating the `Client`:
```scala
object MyLogHandler extends LogHandler[ClientResult[_]] {
def logRequest[T](path: List[String], argumentObject: Any, result: ClientResult[T]): ClientResult[T] = ???
def logRequest[T](method: Method, argumentObject: Any, result: ClientResult[T]): ClientResult[T] = ???
}

val client = Client[PickleType, ClientResult](Transport, MyLogHandler)
Expand All @@ -196,19 +196,19 @@ val client = Client[PickleType, ClientResult](Transport, MyLogHandler)
Define it when creating the `Router`:
```scala
object MyLogHandler extends LogHandler[ServerResult[_]] {
def logRequest[T](path: List[String], argumentObject: Any, result: ServerResult[T]): ServerResult[T] = ???
def logRequest[T](method: Method, argumentObject: Any, result: ServerResult[T]): ServerResult[T] = ???
}

val router = Router[PickleType, ServerResult](MyLogHandler)
```

### Method overloading

When overloading methods with different parameter lists, sloth does not have a unique path (because it is derived from the trait name and the method name). Here you will need to provide your own path name:
When overloading methods with different parameter lists, sloth cannot uniquely identify the method (because it is referenced with the trait name and the method name). Here you will need to provide a custom name:
```scala
trait Api {
def fun(i: Int): F[Int]
@PathName("funWithString")
@Name("funWithString")
def fun(i: Int, s: String): F[Int]
}
```
Expand All @@ -223,27 +223,32 @@ In the above examples, we used the type `ByteBuffer` to select the serialization

Sloth derives all information about an API from a scala trait. For example:
```scala
// @PathName("apiName")
// @Name("traitName")
trait Api {
// @PathName("funName")
// @Name("funName")
def fun(a: Int, b: String)(c: Double): F[Int]
}
```

For each declared method in this trait (in this case `fun`):
* Calculate method path: `List("Api", "fun")` (`PathName` annotations on the trait or method are taken into account).
* Calculate method name: `Method("Api", "fun")` (`Name` annotations on the trait or method are taken into account).
* Serialize the method parameters as a tuple: `(a, b, c)`.

### Server

When calling `router.route[Api](impl)`, a macro generates a function that maps a method path and the pickled arguments to a pickled result. This basically boils down to:
When calling `router.route[Api](impl)`, a macro generates a function that maps a method (trait-name + method-name) and the pickled arguments to a pickled result. This basically boils down to:

```scala
HashMap("Api" -> HashMap("fun" -> { payload =>
// deserialize payload
// call Api implementation impl with arguments
// return serialized response
}))
{ (method: sloth.Method) =>
if (method.traitName = "Api") method.methodName match {
case "fun" => Some({ payload =>
// deserialize payload
// call Api implementation impl with arguments
// return serialized response
})
case _ => None
} else None
}
```

### Client
Expand All @@ -254,7 +259,7 @@ When calling `client.wire[Api]`, a macro generates an instance of `Api` by imple
new Api {
def fun(a: Int, b: String)(c: Double): F[Int] = {
// serialize arguments
// call RequestTransport transport with method path and arguments
// call RequestTransport transport with method and arguments
// return deserialized response
}
}
Expand Down
19 changes: 9 additions & 10 deletions http4sClient/src/main/scala/HttpRpcTransport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ import cats.implicits._
import org.http4s.client.Client
import org.http4s.{EntityBody, EntityDecoder, EntityEncoder, Headers, HttpVersion, Method, Request, ServerSentEvent, Uri}
import fs2.Stream
import sloth.RequestTransport

case class HttpRequestConfig(
baseUri: Uri = Uri(path = Uri.Path.Root),
headers: Headers = Headers.empty,
httpVersion: HttpVersion = HttpVersion.`HTTP/1.1`,
) {
def toRequest[F[_]](requestPath: List[String], entityBody: EntityBody[F]): Request[F] = Request[F](
def toRequest[F[_]](method: sloth.Method, entityBody: EntityBody[F]): Request[F] = Request[F](
method = Method.POST,
uri = requestPath.foldLeft(baseUri)(_ / _),
uri = baseUri / method.traitName / method.methodName,
httpVersion = httpVersion,
headers = headers,
body = entityBody,
Expand All @@ -27,34 +26,34 @@ object HttpRpcTransport {
)(implicit
encoder: EntityEncoder[F, PickleType],
decoder: EntityDecoder[F, PickleType]
): RequestTransport[PickleType, F] = apply(client, HttpRequestConfig().pure[F])
): sloth.RequestTransport[PickleType, F] = apply(client, HttpRequestConfig().pure[F])

def apply[PickleType, F[_]: Concurrent](
client: Client[F],
config: F[HttpRequestConfig]
)(implicit
encoder: EntityEncoder[F, PickleType],
decoder: EntityDecoder[F, PickleType]
): RequestTransport[PickleType, F] = new sloth.RequestTransport[PickleType, F] {
): sloth.RequestTransport[PickleType, F] = new sloth.RequestTransport[PickleType, F] {
override def apply(request: sloth.Request[PickleType]): F[PickleType] = for {
config <- config
responseBody <- client.expect[PickleType](config.toRequest(request.path, encoder.toEntity(request.payload).body))
responseBody <- client.expect[PickleType](config.toRequest(request.method, encoder.toEntity(request.payload).body))
} yield responseBody
}

def eventStream[F[_]: Concurrent](
client: Client[F],
): RequestTransport[String, Stream[F, *]] = eventStream(client, HttpRequestConfig().pure[F])
): sloth.RequestTransport[String, Stream[F, *]] = eventStream(client, HttpRequestConfig().pure[F])

def eventStream[F[_]: Concurrent](
client: Client[F],
config: F[HttpRequestConfig]
): RequestTransport[String, Stream[F, *]] = new sloth.RequestTransport[String, Stream[F, *]] {
): sloth.RequestTransport[String, Stream[F, *]] = new sloth.RequestTransport[String, Stream[F, *]] {
override def apply(request: sloth.Request[String]): Stream[F, String] = for {
config <- Stream.eval(config)
response <- Stream.resource(client.run(config.toRequest(request.path, EntityEncoder[F, String].toEntity(request.payload).body)))
response <- Stream.resource(client.run(config.toRequest(request.method, EntityEncoder[F, String].toEntity(request.payload).body)))
event <- response.body.through(ServerSentEvent.decoder[F])
data <- Stream.fromOption(event.data)
} yield data
}
}
}
29 changes: 14 additions & 15 deletions http4sServer/src/main/scala/HttpRpcRoutes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,26 @@ import cats.effect.Concurrent
import org.http4s._
import org.http4s.dsl.Http4sDsl
import fs2.Stream
import sloth.{Router, ServerFailure}

object HttpRpcRoutes {

def apply[PickleType: EntityDecoder[F, *]: EntityEncoder[F, *], F[_]: Concurrent](
router: Router[PickleType, F],
router: sloth.Router[PickleType, F],
onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty
): HttpRoutes[F] = withRequest[PickleType, F](_ => router, onError)

def withRequest[PickleType: EntityDecoder[F, *]: EntityEncoder[F, *], F[_]: Concurrent](
router: Request[F] => Router[PickleType, F],
router: Request[F] => sloth.Router[PickleType, F],
onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty
): HttpRoutes[F] = {
val dsl = Http4sDsl[F]
import dsl._

HttpRoutes[F] { request =>
request.pathInfo.segments match {
case Vector(apiName, methodName) =>
val path = List(apiName.decoded(), methodName.decoded())
val result = router(request).getFunction(path).traverse { f =>
case Vector(traitName, methodName) =>
val method = sloth.Method(traitName.decoded(), methodName.decoded())
val result = router(request).getMethod(method).traverse { f =>
request.as[PickleType].flatMap { payload =>
f(payload) match {
case Left(error) => serverFailureToResponse[F](dsl, onError)(error)
Expand All @@ -42,22 +41,22 @@ object HttpRpcRoutes {
}

def eventStream[F[_]: Concurrent](
router: Router[String, Stream[F, *]],
router: sloth.Router[String, Stream[F, *]],
onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty
): HttpRoutes[F] = eventStreamWithRequest[F](_ => router, onError)

def eventStreamWithRequest[F[_]: Concurrent](
router: Request[F] => Router[String, Stream[F, *]],
router: Request[F] => sloth.Router[String, Stream[F, *]],
onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty
): HttpRoutes[F] = {
val dsl = Http4sDsl[F]
import dsl._

HttpRoutes[F] { request =>
request.pathInfo.segments match {
case Vector(apiName, methodName) =>
val path = List(apiName.decoded(), methodName.decoded())
val result = router(request).getFunction(path).traverse { f =>
case Vector(traitName, methodName) =>
val method = sloth.Method(traitName.decoded(), methodName.decoded())
val result = router(request).getMethod(method).traverse { f =>
request.as[String].flatMap { payload =>
f(payload) match {
case Left(error) => serverFailureToResponse[F](dsl, onError)(error)
Expand All @@ -72,12 +71,12 @@ object HttpRpcRoutes {
}
}

private def serverFailureToResponse[F[_]: Concurrent](dsl: Http4sDsl[F], onError: PartialFunction[Throwable, F[Response[F]]])(failure: ServerFailure): F[Response[F]] = {
private def serverFailureToResponse[F[_]: Concurrent](dsl: Http4sDsl[F], onError: PartialFunction[Throwable, F[Response[F]]])(failure: sloth.ServerFailure): F[Response[F]] = {
import dsl._
failure match {
case ServerFailure.PathNotFound(_) => NotFound()
case ServerFailure.HandlerError(err) => onError.lift(err).getOrElse(InternalServerError(err.getMessage))
case ServerFailure.DeserializerError(err) => onError.lift(err).getOrElse(BadRequest(err.getMessage))
case sloth.ServerFailure.MethodNotFound(_) => NotFound()
case sloth.ServerFailure.HandlerError(err) => onError.lift(err).getOrElse(InternalServerError(err.getMessage))
case sloth.ServerFailure.DeserializerError(err) => onError.lift(err).getOrElse(BadRequest(err.getMessage))
}
}
}
2 changes: 1 addition & 1 deletion jsdomClient/src/main/scala/HttpRpcTransport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object HttpRpcTransport {
def apply[F[_]: Async](config: F[HttpRequestConfig]): RequestTransport[String, F] = new RequestTransport[String, F] {
override def apply(request: Request[String]): F[String] = for {
config <- config
url = s"${config.baseUri}${request.path.mkString("/")}"
url = s"${config.baseUri}/${request.method.traitName}/${request.method.methodName}"
requestArgs = new dom.RequestInit { headers = config.headers.toJSDictionary; method = dom.HttpMethod.POST; body = request.payload }
result <- Async[F].fromThenable(Async[F].delay[js.Thenable[String]](dom.fetch(url, requestArgs).`then`[String](_.text())))
} yield result
Expand Down
36 changes: 23 additions & 13 deletions sloth/src/main/scala-2/internal/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class Translator[C <: Context](val c: C) {
import c.universe._
import Validator._

val slothPkg = q"_root_.sloth"
val internalPkg = q"_root_.sloth.internal"
object implicits {
implicit val liftMethod: Liftable[sloth.Method] =
Liftable[sloth.Method]{ r => q"new _root_.sloth.Method(${r.traitName}, ${r.methodName})" }
}

def abort(msg: String) = c.abort(c.enclosingPosition, msg)

Expand All @@ -36,11 +38,11 @@ class Translator[C <: Context](val c: C) {
private def validateAllMethods(methods: List[(MethodSymbol, Type)]): List[Either[String, (MethodSymbol, Type)]] =
methods.groupBy(m => methodPathPart(m._1)).map {
case (_, x :: Nil) => Right(x)
case (k, _) => Left(s"""method $k is overloaded (rename the method or add a @PathName("other-name"))""")
case (k, _) => Left(s"""Method $k is overloaded, please rename one of the methods or use the @Name("other-name") annotation to disambiguate""")
}.toList

private def findPathName(annotations: Seq[Annotation]) = annotations.reverse.map(_.tree).collectFirst {
case Apply(Select(New(annotation), _), Literal(Constant(name)) :: Nil) if annotation.tpe =:= typeOf[sloth.PathName] => name.toString
private def findCustomName(annotations: Seq[Annotation]) = annotations.reverse.map(_.tree).collectFirst {
case Apply(Select(New(annotation), _), Literal(Constant(name)) :: Nil) if annotation.tpe =:= typeOf[sloth.Name] => name.toString
}

private def eitherSeq[A, B](list: List[Either[A, B]]): Either[List[A], List[B]] = list.partition(_.isLeft) match {
Expand Down Expand Up @@ -72,10 +74,10 @@ class Translator[C <: Context](val c: C) {

//TODO what about fqn for trait to not have overlaps?
def traitPathPart(tpe: Type): String =
findPathName(tpe.typeSymbol.annotations).getOrElse(tpe.typeSymbol.name.toString)
findCustomName(tpe.typeSymbol.annotations).getOrElse(tpe.typeSymbol.name.toString)

def methodPathPart(m: MethodSymbol): String =
findPathName(m.annotations).getOrElse(m.name.toString)
findCustomName(m.annotations).getOrElse(m.name.toString)

def paramAsValDef(p: Symbol): ValDef = q"val ${p.name.toTermName}: ${p.typeSignature}"
def paramsAsValDefs(m: Type): List[List[ValDef]] = m.paramLists.map(_.map(paramAsValDef))
Expand Down Expand Up @@ -116,6 +118,7 @@ class Translator[C <: Context](val c: C) {
object Translator {
def apply[T](c: Context)(f: Translator[c.type] => c.Tree): c.Expr[T] = {
val tree = f(new Translator(c))
// println(tree)
c.Expr(tree)
}
}
Expand All @@ -125,14 +128,15 @@ object TraitMacro {
(c: Context)
(impl: c.Tree)
(implicit traitTag: c.WeakTypeTag[Trait], resultTag: c.WeakTypeTag[Result[_]]): c.Expr[Trait] = Translator(c) { t =>
import t.implicits._
import c.universe._

val validMethods = t.supportedMethodsInType(traitTag.tpe, resultTag.tpe)

val traitPathPart = t.traitPathPart(traitTag.tpe)
val methodImplList = validMethods.collect { case (symbol, method) =>
val methodPathPart = t.methodPathPart(symbol)
val path = traitPathPart :: methodPathPart :: Nil
val path = sloth.Method(traitPathPart, methodPathPart)
val parameters = t.paramsAsValDefs(method)
val paramsType = t.paramsType(method)
val paramListValue = t.wrapAsParamsType(method)
Expand Down Expand Up @@ -184,26 +188,32 @@ object RouterMacro {
val validMethods = t.supportedMethodsInType(traitTag.tpe, resultTag.tpe)

val traitPathPart = t.traitPathPart(traitTag.tpe)
val methodTuples = validMethods.map { case (symbol, method) =>
val methodCases = validMethods.map { case (symbol, method) =>
val methodPathPart = t.methodPathPart(symbol)
val path = traitPathPart :: methodPathPart :: Nil
val paramsType = t.paramsType(method)
val argParams = t.objectToParams(method, TermName("args"))
val innerReturnType = t.getInnerTypeOutOfReturnType(resultTag.tpe, method.finalResultType)
val payloadFunction =
q"""(payload: ${pickleTypeTag.tpe}) => impl.execute[${paramsType}, $innerReturnType]($path, payload) { args =>
q"""(payload: ${pickleTypeTag.tpe}) => impl.execute[${paramsType}, $innerReturnType](method, payload) { args =>
value.${symbol.name.toTermName}(...$argParams)
}"""

q"($methodPathPart, $payloadFunction)"
cq"$methodPathPart => Some($payloadFunction)"
}

q"""
val value = $value
val implRouter = ${c.prefix}
val impl = $impl

implRouter.orElse($traitPathPart, scala.collection.immutable.Map(..$methodTuples))
implRouter.orElse { method =>
if (method.traitName == $traitPathPart) {
method.methodName match {
case ..$methodCases
case _ => None
}
} else None
}
"""
}

Expand Down
Loading

0 comments on commit 005e26e

Please sign in to comment.