From 99b88616036672fc645a7d2dfd06002d269d4cfd Mon Sep 17 00:00:00 2001
From: Adam Warski
Date: Tue, 12 Nov 2024 16:44:39 +0100
Subject: [PATCH] Fix cleanup in Netty handler after a request timeout (#4156)

---
 .scalafix.conf                                |   3 +-
 build.sbt                                     |   2 +
 project/Versions.scala                        |   2 +-
 .../metrics/MetricsEndpointInterceptor.scala  |   8 +-
 .../sttp/tapir/server/netty/NettyConfig.scala |   4 +-
 .../netty/internal/NettyServerHandler.scala   | 100 +++++++++++++-----
 .../NettyFutureRequestTimeoutTests.scala      |  81 ++++++++++++++
 .../server/netty/NettyFutureServerTest.scala  |   3 +-
 .../netty/sync/NettySyncServerOptions.scala   |   7 +-
 .../sync/NettySyncRequestTimeoutTests.scala   |  87 +++++++++++++++
 .../netty/sync/NettySyncServerTest.scala      |   4 +-
 11 files changed, 265 insertions(+), 36 deletions(-)
 create mode 100644 server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala
 create mode 100644 server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncRequestTimeoutTests.scala

diff --git a/.scalafix.conf b/.scalafix.conf
index 25e991dbe3..40c578c6f6 100644
--- a/.scalafix.conf
+++ b/.scalafix.conf
@@ -1,2 +1,3 @@
 OrganizeImports.groupedImports = AggressiveMerge
-OrganizeImports.targetDialect = Scala3
\ No newline at end of file
+OrganizeImports.targetDialect = Scala3
+OrganizeImports.removeUnused = false
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 30215bfcb6..6c5c0eeb3f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -85,6 +85,8 @@ val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
       case _ => Seq("-Xmax-inlines", "64")
     }
   },
+  Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.Assertion:s",
+  Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.compatible.Assertion:s",
   evictionErrorLevel := Level.Info
 )
diff --git a/project/Versions.scala b/project/Versions.scala
index 9f6ebd0ee7..d1a8e9e852 100644
--- a/project/Versions.scala
+++ b/project/Versions.scala
@@ -10,7 +10,7 @@ object Versions {
   val helidon = "4.0.10"
   val sttp = "3.10.1"
   val sttpModel = "1.7.11"
-  val sttpShared = "1.3.22"
+  val sttpShared = "1.4.0"
   val sttpApispec = "0.11.3"
   val akkaHttp = "10.2.10"
   val akkaStreams = "2.6.20"
diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala b/server/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala
index 57f8fc047d..069232e153 100644
--- a/server/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala
+++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala
@@ -47,7 +47,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
   )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
     if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeSuccess(ctx)
     else {
-      val responseWithMetrics: F[ServerResponse[B]] = for {
+      def responseWithMetrics: F[ServerResponse[B]] = for {
         _ <- collectRequestMetrics(ctx.endpoint)
         response <- endpointHandler.onDecodeSuccess(ctx)
         _ <- collectResponseHeadersMetrics(ctx.endpoint, response)
@@ -64,7 +64,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
   )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
     if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onSecurityFailure(ctx)
     else {
-      val responseWithMetrics: F[ServerResponse[B]] = for {
+      def responseWithMetrics: F[ServerResponse[B]] = for {
         _ <- collectRequestMetrics(ctx.endpoint)
         response <- endpointHandler.onSecurityFailure(ctx)
         _ <- collectResponseHeadersMetrics(ctx.endpoint, response)
@@ -83,7 +83,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
   )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = {
     if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeFailure(ctx)
     else {
-      val responseWithMetrics: F[Option[ServerResponse[B]]] = for {
+      def responseWithMetrics: F[Option[ServerResponse[B]]] = for {
         response <- endpointHandler.onDecodeFailure(ctx)
         withMetrics <- response match {
           case Some(response) =>
@@ -129,7 +129,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
     }
   }
 
-  private def handleResponseExceptions[T](r: F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
+  private def handleResponseExceptions[T](r: => F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
     r.handleError { case ex: Exception => collectExceptionMetrics(e, ex) }
 
   private def collectExceptionMetrics[T](e: AnyEndpoint, ex: Throwable)(implicit monad: MonadError[F]): F[T] =
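Note on the `MetricsEndpointInterceptor` changes above: turning `responseWithMetrics` from a `val` into a `def`, together with the by-name parameter of `handleResponseExceptions`, defers constructing the effect until it is evaluated inside `handleError`. Otherwise an exception thrown while *building* the effect (as opposed to a failed effect) would escape before the exception metrics could record it. A minimal, dependency-free sketch of the difference, using `Try` in place of `F[_]` (names here are illustrative, not tapir's):

```scala
import scala.util.{Failure, Try}

object ByNameHandling extends App {
  // Strict parameter: `r` is evaluated at the call site, before this method can
  // wrap it, so an exception thrown while constructing it escapes the handler.
  def handleStrict[T](r: Try[T]): Try[T] =
    r.recoverWith { case ex: Exception => Failure(new RuntimeException("recorded", ex)) }

  // By-name parameter: evaluation is deferred until inside the handler, mirroring
  // `handleResponseExceptions[T](r: => F[T], ...)` in the patch above.
  def handleByName[T](r: => Try[T]): Try[T] =
    Try(r).flatten.recoverWith { case ex: Exception => Failure(new RuntimeException("recorded", ex)) }

  def build(): Try[Int] = throw new IllegalStateException("boom while building the effect")

  println(Try(handleStrict(build()))) // Failure(IllegalStateException): the handler never ran
  println(handleByName(build()))      // Failure(RuntimeException: recorded): the handler saw it
}
```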
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
index dc799c362a..c8eee8395b 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
@@ -104,8 +104,8 @@ case class NettyConfig(
   def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)
 
-  def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
-  def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
+  def withGracefulShutdownTimeout(t: FiniteDuration): NettyConfig = copy(gracefulShutdownTimeout = Some(t))
+  def noGracefulShutdown: NettyConfig = copy(gracefulShutdownTimeout = None)
 
   def serverHeader(h: String): NettyConfig = copy(serverHeader = Some(h))
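The explicit `: NettyConfig` return types make the fluent builder methods self-documenting. For context, a configuration along the lines of what the new tests exercise (the timeout values below are illustrative, not defaults):

```scala
import scala.concurrent.duration._
import sttp.tapir.server.netty.NettyConfig

// Sketch of a typical setup: a per-request timeout plus a graceful-shutdown window.
val config: NettyConfig =
  NettyConfig.default
    .requestTimeout(1.second)               // reply 503 and close if the logic takes longer
    .withGracefulShutdownTimeout(5.seconds) // wait up to 5s for in-flight requests on stop
```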
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout.map(_.toString).getOrElse("(not set)")}") val _ = ctx.close() } case other => @@ -147,30 +152,42 @@ class NettyServerHandler[F[_]]( pendingResponses.enqueue(cancellationSwitch) lastResponseSent = lastResponseSent.flatMap { _ => runningFuture - .andThen { case _ => - requestTimeoutHandler.foreach(ctx.pipeline().remove) - }(eventLoopContext) - .transform { - case Success(serverResponse) => - pendingResponses.dequeue() - try { - handleResponse(ctx, req, serverResponse) - Success(()) - } catch { - case NonFatal(ex) => - writeError500(req, ex) - Failure(ex) - } finally { - val _ = releaseReq() - } - case Failure(NonFatal(ex)) => - try { - writeError500(req, ex) - Failure(ex) - } finally { - val _ = releaseReq() + .transform { result => + try { + // #4131: the channel might be closed if the request timed out + // both timeout & response-ready events (i.e., comleting this future) are handled on the event loop's executor, + // so they won't be handled concurrently + if (ctx.channel().isOpen()) { + requestTimeoutHandler.foreach(ctx.pipeline().remove) + result match { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) + Success(()) + } catch { + case NonFatal(ex) => + writeError500(req, ex) + Failure(ex) + } + case Failure(NonFatal(ex)) => + writeError500(req, ex) + Failure(ex) + case Failure(fatalException) => Failure(fatalException) + } + } else { + // pendingResponses is already dequeued because the channel is closed + result match { + case Success(serverResponse) => + val e = new TimeoutException("Request timed out") + handleResponseAfterTimeout(ctx, serverResponse, e) + Failure(e) + case Failure(e) => Failure(e) + } } - case Failure(fatalException) => Failure(fatalException) + } finally { + val _ = releaseReq() + } }(eventLoopContext) }(eventLoopContext) } @@ -270,6 +287,39 @@ class NettyServerHandler[F[_]]( } ) + private def handleResponseAfterTimeout( + ctx: ChannelHandlerContext, + serverResponse: ServerResponse[NettyResponse], + timeoutException: Exception + ): Unit = + serverResponse.handle( + ctx = ctx, + byteBufHandler = (channelPromise, byteBuf) => { val _ = channelPromise.setFailure(timeoutException) }, + chunkedStreamHandler = (channelPromise, chunkedStream) => { + chunkedStream.close() + val _ = channelPromise.setFailure(timeoutException) + }, + chunkedFileHandler = (channelPromise, chunkedFile) => { + chunkedFile.close() + val _ = channelPromise.setFailure(timeoutException) + }, + reactiveStreamHandler = (channelPromise, publisher) => { + publisher.subscribe(new Subscriber[HttpContent] { + override def onSubscribe(s: Subscription): Unit = { + s.cancel() + val _ = channelPromise.setFailure(timeoutException) + } + override def onNext(t: HttpContent): Unit = () + override def onError(t: Throwable): Unit = () + override def onComplete(): Unit = () + }) + }, + wsHandler = (responseContent) => { + val _ = responseContent.channelPromise.setFailure(timeoutException) + }, + noBodyHandler = () => () + ) + private def initWsPipeline( ctx: ChannelHandlerContext, r: ReactiveWebSocketProcessorNettyResponseContent, diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala new file mode 100644 index 0000000000..d65308176c --- /dev/null +++ 
diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala
new file mode 100644
index 0000000000..d65308176c
--- /dev/null
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureRequestTimeoutTests.scala
@@ -0,0 +1,81 @@
+package sttp.tapir.server.netty
+
+import sttp.tapir._
+import sttp.tapir.tests.Test
+import scala.concurrent.Future
+import java.util.concurrent.atomic.AtomicInteger
+import scala.concurrent.duration.DurationInt
+import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
+import sttp.tapir.server.metrics.Metric
+import sttp.tapir.server.metrics.EndpointMetric
+import io.netty.channel.EventLoopGroup
+import cats.effect.IO
+import cats.effect.kernel.Resource
+import scala.concurrent.ExecutionContext
+import sttp.client3._
+import sttp.capabilities.fs2.Fs2Streams
+import sttp.capabilities.WebSockets
+import org.scalatest.matchers.should.Matchers._
+import cats.effect.unsafe.implicits.global
+import sttp.model.StatusCode
+
+class NettyFutureRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets])(implicit
+    ec: ExecutionContext
+) {
+  def tests(): List[Test] = List(
+    Test("properly update metrics when a request times out") {
+      val e = endpoint.post
+        .in(stringBody)
+        .out(stringBody)
+        .serverLogicSuccess[Future] { body =>
+          Thread.sleep(2000); Future.successful(body)
+        }
+
+      val activeRequests = new AtomicInteger()
+      val totalRequests = new AtomicInteger()
+      val customMetrics: List[Metric[Future, AtomicInteger]] = List(
+        Metric(
+          metric = activeRequests,
+          onRequest = (_, metric, me) =>
+            me.eval {
+              EndpointMetric()
+                .onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } }
+                .onResponseBody { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
+                .onException { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
+            }
+        ),
+        Metric(
+          metric = totalRequests,
+          onRequest = (_, metric, me) => me.eval(EndpointMetric().onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } })
+        )
+      )
+
+      val config =
+        NettyConfig.default
+          .eventLoopGroup(eventLoopGroup)
+          .randomPort
+          .withDontShutdownEventLoopGroupOnClose
+          .noGracefulShutdown
+          .requestTimeout(1.second)
+      val options = NettyFutureServerOptions.customiseInterceptors
+        .metricsInterceptor(new MetricsRequestInterceptor[Future](customMetrics, Seq.empty))
+        .options
+      val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addEndpoints(List(e)).start()))
+
+      Resource
+        .make(bind)(server => IO.fromFuture(IO.delay(server.stop())))
+        .map(_.port)
+        .use { port =>
+          basicRequest.post(uri"http://localhost:$port").body("test").send(backend).map { response =>
+            response.body should matchPattern { case Left(_) => }
+            response.code shouldBe StatusCode.ServiceUnavailable
+            // the metrics will only be updated when the endpoint's logic completes, which is 1 second after receiving the timeout response
+            Thread.sleep(1100)
+            activeRequests.get() shouldBe 0
+            totalRequests.get() shouldBe 1
+          }
+        }
+        .unsafeToFuture()
+    }
+  )
+}
diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala
index f8f263706e..401940ff69 100644
--- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala
@@ -23,7 +23,8 @@ class NettyFutureServerTest extends TestSuite with EitherValues {
 
       val tests =
         new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
-          new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests()
+          new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() ++
+          new NettyFutureRequestTimeoutTests(eventLoopGroup, backend).tests()
 
       (tests, eventLoopGroup)
     }) { case (_, eventLoopGroup) =>
diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServerOptions.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServerOptions.scala
index 097c20b24e..1d69fa3e11 100644
--- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServerOptions.scala
+++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServerOptions.scala
@@ -47,7 +47,12 @@ object NettySyncServerOptions:
       doLogWhenReceived = debugLog(_, None),
       doLogWhenHandled = debugLog,
      doLogAllDecodeFailures = debugLog,
-      doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex),
+      doLogExceptions = (msg: String, e: Throwable) =>
+        e match
+          // if server logic is interrupted (e.g. due to timeout), this isn't an error, but might still be useful for debugging,
+          // to know how far processing got
+          case _: InterruptedException => log.debug(msg, e)
+          case _                       => log.error(msg, e),
       noLog = ()
     )
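The `doLogExceptions` change reflects how cancellation surfaces in the synchronous, virtual-thread-based server: when a request times out, the thread running the server logic is interrupted, so `InterruptedException` is expected control flow rather than an error. An illustrative sketch of the pattern (not tapir code):

```scala
// Illustrative only: blocking server logic cancelled via thread interruption.
def serverLogic(body: String): String =
  try
    Thread.sleep(2000) // interrupted here if the request times out first
    body
  catch
    case e: InterruptedException =>
      // rethrow: the interpreter treats this as cancellation, logged at debug rather than error
      throw e
```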
diff --git a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncRequestTimeoutTests.scala b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncRequestTimeoutTests.scala
new file mode 100644
index 0000000000..65769d9ba6
--- /dev/null
+++ b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncRequestTimeoutTests.scala
@@ -0,0 +1,87 @@
+package sttp.tapir.server.netty
+
+import cats.effect.IO
+import cats.effect.unsafe.implicits.global
+import io.netty.channel.EventLoopGroup
+import org.scalatest.matchers.should.Matchers.*
+import ox.*
+import sttp.capabilities.WebSockets
+import sttp.capabilities.fs2.Fs2Streams
+import sttp.client3.*
+import sttp.model.StatusCode
+import sttp.tapir.*
+import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
+import sttp.tapir.server.metrics.{EndpointMetric, Metric}
+import sttp.tapir.server.netty.sync.{NettySyncServer, NettySyncServerOptions}
+import sttp.tapir.tests.Test
+
+import java.util.concurrent.atomic.AtomicInteger
+import scala.concurrent.Future
+import scala.concurrent.duration.DurationInt
+import org.slf4j.LoggerFactory
+
+class NettySyncRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets]):
+  val logger = LoggerFactory.getLogger(getClass.getName)
+
+  def tests(): List[Test] = List(
+    Test("properly update metrics when a request times out") {
+      val e = endpoint.post
+        .in(stringBody)
+        .out(stringBody)
+        .serverLogicSuccess[Identity]: body =>
+          Thread.sleep(2000)
+          body
+
+      val activeRequests = new AtomicInteger()
+      val totalRequests = new AtomicInteger()
+      val customMetrics: List[Metric[Identity, AtomicInteger]] = List(
+        Metric(
+          metric = activeRequests,
+          onRequest = (_, metric, me) =>
+            me.eval:
+              EndpointMetric()
+                .onEndpointRequest: _ =>
+                  val _ = metric.incrementAndGet();
+                  (): Identity[Unit]
+                .onResponseBody: (_, _) =>
+                  val _ = metric.decrementAndGet();
+                .onException: (_, _) =>
+                  val _ = metric.decrementAndGet();
+        ),
+        Metric(
+          metric = totalRequests,
+          onRequest = (_, metric, me) =>
+            me.eval(EndpointMetric().onEndpointRequest: _ =>
+              val _ = metric.incrementAndGet();
+            )
+        )
+      )
+
+      val config =
+        NettyConfig.default
+          .eventLoopGroup(eventLoopGroup)
+          .randomPort
+          .withDontShutdownEventLoopGroupOnClose
+          .noGracefulShutdown
+          .requestTimeout(1.second)
+      val options = NettySyncServerOptions.customiseInterceptors
+        .metricsInterceptor(new MetricsRequestInterceptor[Identity](customMetrics, Seq.empty))
+        .options
+
+      Future.successful:
+        supervised:
+          val port = useInScope(NettySyncServer(options, config).addEndpoint(e).start())(_.stop()).port
+          basicRequest
+            .post(uri"http://localhost:$port")
+            .body("test")
+            .send(backend)
+            .map: response =>
+              response.body should matchPattern { case Left(_) => }
+              response.code shouldBe StatusCode.ServiceUnavailable
+              // unlike in NettyFutureRequestTimeoutTests, here interruption works properly, and the metrics should be updated quickly
+              Thread.sleep(100)
+              activeRequests.get() shouldBe 0
+              totalRequests.get() shouldBe 1
+            .unsafeRunSync()
+    }
+  )
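The sync test relies on Ox's structured concurrency: `supervised` opens a scope, and `useInScope` registers the server with a release action that runs when the scope ends, so the server is stopped even if an assertion fails. A minimal sketch of just those two constructs (assuming the `ox` library as imported above; the resource is a stand-in):

```scala
import ox.{supervised, useInScope}

// A toy resource standing in for the Netty server (illustrative names).
class ToyServer:
  def stop(): Unit = println("server stopped")

@main def scopeDemo(): Unit =
  supervised:
    val server = useInScope(ToyServer())(_.stop()) // release runs when the scope ends
    println(s"using $server")
  // by this point the release action has run, even if the body had thrown
```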
diff --git a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala
index ef3ff0295c..19ded1bbe7 100644
--- a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala
+++ b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala
@@ -25,6 +25,7 @@ import scala.concurrent.Future
 import scala.concurrent.duration.FiniteDuration
 import ox.flow.Flow
 import scala.annotation.nowarn
+import sttp.tapir.server.netty.NettySyncRequestTimeoutTests
 
 class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
 
@@ -44,7 +45,8 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
         new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) {
           override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = _.map(f)
           override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Flow.empty
-        }.tests()
+        }.tests() ++
+          NettySyncRequestTimeoutTests(eventLoopGroup, backend).tests()
 
       tests.foreach { t =>
         if (testNameFilter.forall(filter => t.name.contains(filter))) {