diff --git a/http4sServer/src/main/scala/HttpRpcRoutes.scala b/http4sServer/src/main/scala/HttpRpcRoutes.scala index 1f361773..405bf657 100644 --- a/http4sServer/src/main/scala/HttpRpcRoutes.scala +++ b/http4sServer/src/main/scala/HttpRpcRoutes.scala @@ -13,6 +13,11 @@ object HttpRpcRoutes { def apply[PickleType: EntityDecoder[F, *]: EntityEncoder[F, *], F[_]: Concurrent]( router: Router[PickleType, F], onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty + ): HttpRoutes[F] = withRequest[PickleType, F](_ => router, onError) + + def withRequest[PickleType: EntityDecoder[F, *]: EntityEncoder[F, *], F[_]: Concurrent]( + router: Request[F] => Router[PickleType, F], + onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty ): HttpRoutes[F] = { val dsl = Http4sDsl[F] import dsl._ @@ -21,7 +26,7 @@ object HttpRpcRoutes { request.pathInfo.segments match { case Vector(apiName, methodName) => val path = List(apiName.decoded(), methodName.decoded()) - val result = router.getFunction(path).traverse { f => + val result = router(request).getFunction(path).traverse { f => request.as[PickleType].flatMap { payload => f(payload) match { case Left(error) => serverFailureToResponse[F](dsl, onError)(error) @@ -39,6 +44,11 @@ object HttpRpcRoutes { def eventStream[F[_]: Concurrent]( router: Router[String, Stream[F, *]], onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty + ): HttpRoutes[F] = eventStreamWithRequest[F](_ => router, onError) + + def eventStreamWithRequest[F[_]: Concurrent]( + router: Request[F] => Router[String, Stream[F, *]], + onError: PartialFunction[Throwable, F[Response[F]]] = PartialFunction.empty ): HttpRoutes[F] = { val dsl = Http4sDsl[F] import dsl._ @@ -47,7 +57,7 @@ object HttpRpcRoutes { request.pathInfo.segments match { case Vector(apiName, methodName) => val path = List(apiName.decoded(), methodName.decoded()) - val result = router.getFunction(path).traverse { f => + val result = router(request).getFunction(path).traverse { f => request.as[String].flatMap { payload => f(payload) match { case Left(error) => serverFailureToResponse[F](dsl, onError)(error)