Skip to content

Commit

Permalink
Fix cleanup in Netty handler after a request timeout (#4156)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Nov 12, 2024
1 parent d7f84bd commit 99b8861
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 36 deletions.
3 changes: 2 additions & 1 deletion .scalafix.conf
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
OrganizeImports.groupedImports = AggressiveMerge
OrganizeImports.targetDialect = Scala3
OrganizeImports.targetDialect = Scala3
OrganizeImports.removeUnused = false
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
case _ => Seq("-Xmax-inlines", "64")
}
},
Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.Assertion:s",
Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.compatible.Assertion:s",
evictionErrorLevel := Level.Info
)

Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Versions {
val helidon = "4.0.10"
val sttp = "3.10.1"
val sttpModel = "1.7.11"
val sttpShared = "1.3.22"
val sttpShared = "1.4.0"
val sttpApispec = "0.11.3"
val akkaHttp = "10.2.10"
val akkaStreams = "2.6.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeSuccess(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onDecodeSuccess(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -64,7 +64,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onSecurityFailure(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onSecurityFailure(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -83,7 +83,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeFailure(ctx)
else {
val responseWithMetrics: F[Option[ServerResponse[B]]] = for {
def responseWithMetrics: F[Option[ServerResponse[B]]] = for {
response <- endpointHandler.onDecodeFailure(ctx)
withMetrics <- response match {
case Some(response) =>
Expand Down Expand Up @@ -129,7 +129,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
}
}

private def handleResponseExceptions[T](r: F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
private def handleResponseExceptions[T](r: => F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
r.handleError { case ex: Exception => collectExceptionMetrics(e, ex) }

private def collectExceptionMetrics[T](e: AnyEndpoint, ex: Throwable)(implicit monad: MonadError[F]): F[T] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ case class NettyConfig(

def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)

def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
def withGracefulShutdownTimeout(t: FiniteDuration): NettyConfig = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown: NettyConfig = copy(gracefulShutdownTimeout = None)

def serverHeader(h: String): NettyConfig = copy(serverHeader = Some(h))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ import scala.collection.mutable.{Queue => MutableQueue}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success}
import java.util.concurrent.TimeoutException
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

/** @param unsafeRunAsync
* Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() =>
Expand Down Expand Up @@ -109,11 +112,13 @@ class NettyServerHandler[F[_]](
evt match {
case e: IdleStateEvent =>
if (e.state() == IdleState.WRITER_IDLE) {
logger.error(s"Closing connection due to exceeded response timeout of ${config.requestTimeout}")
logger.error(
s"Closing connection due to exceeded response timeout of ${config.requestTimeout.map(_.toString).getOrElse("(not set)")}"
)
writeError503ThenClose(ctx)
}
if (e.state() == IdleState.ALL_IDLE) {
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout}")
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout.map(_.toString).getOrElse("(not set)")}")
val _ = ctx.close()
}
case other =>
Expand Down Expand Up @@ -147,30 +152,42 @@ class NettyServerHandler[F[_]](
pendingResponses.enqueue(cancellationSwitch)
lastResponseSent = lastResponseSent.flatMap { _ =>
runningFuture
.andThen { case _ =>
requestTimeoutHandler.foreach(ctx.pipeline().remove)
}(eventLoopContext)
.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
.transform { result =>
try {
// #4131: the channel might be closed if the request timed out
// both timeout & response-ready events (i.e., comleting this future) are handled on the event loop's executor,
// so they won't be handled concurrently
if (ctx.channel().isOpen()) {
requestTimeoutHandler.foreach(ctx.pipeline().remove)
result match {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
}
case Failure(NonFatal(ex)) =>
writeError500(req, ex)
Failure(ex)
case Failure(fatalException) => Failure(fatalException)
}
} else {
// pendingResponses is already dequeued because the channel is closed
result match {
case Success(serverResponse) =>
val e = new TimeoutException("Request timed out")
handleResponseAfterTimeout(ctx, serverResponse, e)
Failure(e)
case Failure(e) => Failure(e)
}
}
case Failure(fatalException) => Failure(fatalException)
} finally {
val _ = releaseReq()
}
}(eventLoopContext)
}(eventLoopContext)
}
Expand Down Expand Up @@ -270,6 +287,39 @@ class NettyServerHandler[F[_]](
}
)

private def handleResponseAfterTimeout(
ctx: ChannelHandlerContext,
serverResponse: ServerResponse[NettyResponse],
timeoutException: Exception
): Unit =
serverResponse.handle(
ctx = ctx,
byteBufHandler = (channelPromise, byteBuf) => { val _ = channelPromise.setFailure(timeoutException) },
chunkedStreamHandler = (channelPromise, chunkedStream) => {
chunkedStream.close()
val _ = channelPromise.setFailure(timeoutException)
},
chunkedFileHandler = (channelPromise, chunkedFile) => {
chunkedFile.close()
val _ = channelPromise.setFailure(timeoutException)
},
reactiveStreamHandler = (channelPromise, publisher) => {
publisher.subscribe(new Subscriber[HttpContent] {
override def onSubscribe(s: Subscription): Unit = {
s.cancel()
val _ = channelPromise.setFailure(timeoutException)
}
override def onNext(t: HttpContent): Unit = ()
override def onError(t: Throwable): Unit = ()
override def onComplete(): Unit = ()
})
},
wsHandler = (responseContent) => {
val _ = responseContent.channelPromise.setFailure(timeoutException)
},
noBodyHandler = () => ()
)

private def initWsPipeline(
ctx: ChannelHandlerContext,
r: ReactiveWebSocketProcessorNettyResponseContent,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package sttp.tapir.server.netty

import sttp.tapir._
import sttp.tapir.tests.Test
import scala.concurrent.Future
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration.DurationInt
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.metrics.Metric
import sttp.tapir.server.metrics.EndpointMetric
import io.netty.channel.EventLoopGroup
import cats.effect.IO
import cats.effect.kernel.Resource
import scala.concurrent.ExecutionContext
import sttp.client3._
import sttp.capabilities.fs2.Fs2Streams
import sttp.capabilities.WebSockets
import org.scalatest.matchers.should.Matchers._
import cats.effect.unsafe.implicits.global
import sttp.model.StatusCode

class NettyFutureRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets])(implicit
ec: ExecutionContext
) {
def tests(): List[Test] = List(
Test("properly update metrics when a request times out") {
val e = endpoint.post
.in(stringBody)
.out(stringBody)
.serverLogicSuccess[Future] { body =>
Thread.sleep(2000); Future.successful(body)
}

val activeRequests = new AtomicInteger()
val totalRequests = new AtomicInteger()
val customMetrics: List[Metric[Future, AtomicInteger]] = List(
Metric(
metric = activeRequests,
onRequest = (_, metric, me) =>
me.eval {
EndpointMetric()
.onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } }
.onResponseBody { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
.onException { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
}
),
Metric(
metric = totalRequests,
onRequest = (_, metric, me) => me.eval(EndpointMetric().onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } })
)
)

val config =
NettyConfig.default
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.noGracefulShutdown
.requestTimeout(1.second)
val options = NettyFutureServerOptions.customiseInterceptors
.metricsInterceptor(new MetricsRequestInterceptor[Future](customMetrics, Seq.empty))
.options
val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addEndpoints(List(e)).start()))

Resource
.make(bind)(server => IO.fromFuture(IO.delay(server.stop())))
.map(_.port)
.use { port =>
basicRequest.post(uri"http://localhost:$port").body("test").send(backend).map { response =>
response.body should matchPattern { case Left(_) => }
response.code shouldBe StatusCode.ServiceUnavailable
// the metrics will only be updated when the endpoint's logic completes, which is 1 second after receiving the timeout response
Thread.sleep(1100)
activeRequests.get() shouldBe 0
totalRequests.get() shouldBe 1
}
}
.unsafeToFuture()
}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class NettyFutureServerTest extends TestSuite with EitherValues {

val tests =
new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests()
new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() ++
new NettyFutureRequestTimeoutTests(eventLoopGroup, backend).tests()

(tests, eventLoopGroup)
}) { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ object NettySyncServerOptions:
doLogWhenReceived = debugLog(_, None),
doLogWhenHandled = debugLog,
doLogAllDecodeFailures = debugLog,
doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex),
doLogExceptions = (msg: String, e: Throwable) =>
e match
// if server logic is interrupted (e.g. due to timeout), this isn't an error, but might still be useful for debugging,
// to know how far processing got
case _: InterruptedException => log.debug(msg, e)
case _ => log.error(msg, e),
noLog = ()
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package sttp.tapir.server.netty

import cats.effect.IO
import cats.effect.unsafe.implicits.global
import io.netty.channel.EventLoopGroup
import org.scalatest.matchers.should.Matchers.*
import ox.*
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3.*
import sttp.model.StatusCode
import sttp.tapir.*
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.metrics.{EndpointMetric, Metric}
import sttp.tapir.server.netty.sync.{NettySyncServer, NettySyncServerOptions}
import sttp.tapir.tests.Test

import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import org.slf4j.LoggerFactory

class NettySyncRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets]):
val logger = LoggerFactory.getLogger(getClass.getName)

def tests(): List[Test] = List(
Test("properly update metrics when a request times out") {
val e = endpoint.post
.in(stringBody)
.out(stringBody)
.serverLogicSuccess[Identity]: body =>
Thread.sleep(2000)
body

val activeRequests = new AtomicInteger()
val totalRequests = new AtomicInteger()
val customMetrics: List[Metric[Identity, AtomicInteger]] = List(
Metric(
metric = activeRequests,
onRequest = (_, metric, me) =>
me.eval:
EndpointMetric()
.onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
(): Identity[Unit]
.onResponseBody: (_, _) =>
val _ = metric.decrementAndGet();
.onException: (_, _) =>
val _ = metric.decrementAndGet();
),
Metric(
metric = totalRequests,
onRequest = (_, metric, me) =>
me.eval(EndpointMetric().onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
)
)
)

val config =
NettyConfig.default
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.noGracefulShutdown
.requestTimeout(1.second)
val options = NettySyncServerOptions.customiseInterceptors
.metricsInterceptor(new MetricsRequestInterceptor[Identity](customMetrics, Seq.empty))
.options

Future.successful:
supervised:
val port = useInScope(NettySyncServer(options, config).addEndpoint(e).start())(_.stop()).port
basicRequest
.post(uri"http://localhost:$port")
.body("test")
.send(backend)
.map: response =>
response.body should matchPattern { case Left(_) => }
response.code shouldBe StatusCode.ServiceUnavailable
// unlike in NettyFutureRequestTimeoutTest, here interruption works properly, and the metrics should be updated quickly
Thread.sleep(100)
activeRequests.get() shouldBe 0
totalRequests.get() shouldBe 1
.unsafeRunSync()
}
)
Loading

0 comments on commit 99b8861

Please sign in to comment.