From 02f9cbb5555acaf3ca88ebd7525ef1f62f70a113 Mon Sep 17 00:00:00 2001 From: vbergeron Date: Thu, 21 Mar 2024 22:50:10 +0100 Subject: [PATCH] TCP Pipelining: PerparedQuery Bind+Exec prefetch --- .../shared/src/main/scala/PreparedQuery.scala | 6 +- .../shared/src/main/scala/net/Protocol.scala | 2 +- .../main/scala/net/protocol/BindExecute.scala | 75 ++++++++++++++++--- .../src/main/scala/net/protocol/Execute.scala | 5 +- .../src/main/scala/net/protocol/Prepare.scala | 23 +++--- .../src/main/scala/net/protocol/Unroll.scala | 19 +++++ 6 files changed, 102 insertions(+), 28 deletions(-) diff --git a/modules/core/shared/src/main/scala/PreparedQuery.scala b/modules/core/shared/src/main/scala/PreparedQuery.scala index 82aae713..eeee6483 100644 --- a/modules/core/shared/src/main/scala/PreparedQuery.scala +++ b/modules/core/shared/src/main/scala/PreparedQuery.scala @@ -63,7 +63,7 @@ object PreparedQuery { new PreparedQuery[F, A, B] { override def cursor(args: A)(implicit or: Origin): Resource[F, Cursor[F, B]] = - proto.bind(args, or).map { p => + proto.bind(args, or, None).map { p => new Cursor[F, B] { override def fetch(maxRows: Int): F[(List[B], Boolean)] = p.execute(maxRows) @@ -71,7 +71,7 @@ object PreparedQuery { } override def stream(args: A, chunkSize: Int)(implicit or: Origin): Stream[F, B] = - Stream.resource(proto.bind(args, or)).flatMap { cursor => + Stream.resource(proto.bind(args, or, Some(chunkSize))).flatMap { cursor => def chunks: Stream[F, B] = Stream.eval(cursor.execute(chunkSize)).flatMap { case (bs, more) => val s = Stream.chunk(Chunk.from(bs)) @@ -84,7 +84,7 @@ object PreparedQuery { // We have a few operations that only want the first row. In order to do this AND // know if there are more we need to ask for 2 rows. private def fetch2(args: A)(implicit or: Origin): F[(List[B], Boolean)] = - cursor(args).use(_.fetch(2)) + proto.bind(args, or, Some(2)).use(_.execute(2)) override def option(args: A)(implicit or: Origin): F[Option[B]] = fetch2(args).flatMap { case (bs, _) => diff --git a/modules/core/shared/src/main/scala/net/Protocol.scala b/modules/core/shared/src/main/scala/net/Protocol.scala index afcff04a..d6f4cc7e 100644 --- a/modules/core/shared/src/main/scala/net/Protocol.scala +++ b/modules/core/shared/src/main/scala/net/Protocol.scala @@ -163,7 +163,7 @@ object Protocol { val rowDescription: TypedRowDescription ) extends PreparedStatement[F, A] { def statement: Statement[A] = query - def bind(args: A, argsOrigin: Origin): Resource[F, QueryPortal[F, A, B]] + def bind(args: A, argsOrigin: Origin, maxRows: Option[Int]): Resource[F, QueryPortal[F, A, B]] } /** diff --git a/modules/core/shared/src/main/scala/net/protocol/BindExecute.scala b/modules/core/shared/src/main/scala/net/protocol/BindExecute.scala index 7ed09d11..092b1554 100644 --- a/modules/core/shared/src/main/scala/net/protocol/BindExecute.scala +++ b/modules/core/shared/src/main/scala/net/protocol/BindExecute.scala @@ -6,9 +6,10 @@ package skunk.net.protocol import cats.effect.Resource import cats.syntax.all._ -import cats.MonadError +import cats.effect.Concurrent +import skunk.~ import skunk.exception._ -import skunk.net.message.{ Bind => BindMessage, Close => _, _ } +import skunk.net.message.{ Bind => BindMessage, Execute => ExecuteMessage, Close => _, _ } import skunk.net.MessageSocket import skunk.net.Protocol.PortalId import skunk.util.{ Origin, Namer } @@ -18,6 +19,7 @@ import skunk.RedactionStrategy import skunk.net.Protocol import skunk.data.Completion import skunk.net.protocol.exchange +import cats.effect.kernel.Deferred trait BindExecute[F[_]] { @@ -28,21 +30,28 @@ trait BindExecute[F[_]] { redactionStrategy: RedactionStrategy ): Resource[F, Protocol.CommandPortal[F, A]] + def query[A, B]( + statement: Protocol.PreparedQuery[F, A, B], + args: A, + argsOrigin: Origin, + redactionStrategy: RedactionStrategy, + initialSize: Int + ): Resource[F, Protocol.QueryPortal[F, A, B]] } object BindExecute { def apply[F[_]: Exchange: MessageSocket: Namer: Tracer]( - implicit ev: MonadError[F, Throwable] + implicit ev: Concurrent[F] ): BindExecute[F] = - new BindExecute[F] { - - def command[A]( - statement: Protocol.PreparedCommand[F, A], - args: A, + new Unroll[F] with BindExecute[F] { + + def bindExchange[A]( + statement: Protocol.PreparedStatement[F, A], + args: A, argsOrigin: Origin, redactionStrategy: RedactionStrategy - ): Resource[F, Protocol.CommandPortal[F, A]] = { + ):(Span[F] => F[PortalId], F[Unit]) = { val ea = statement.statement.encoder.encode(args) // encoded args def preBind(span: Span[F]): F[PortalId] = for { @@ -71,8 +80,17 @@ object BindExecute { ) } yield a } + (preBind, postBind) + } + + def command[A]( + statement: Protocol.PreparedCommand[F, A], + args: A, + argsOrigin: Origin, + redactionStrategy: RedactionStrategy + ): Resource[F, Protocol.CommandPortal[F, A]] = { - def preExec(portal: PortalId):F[Unit] = send(Execute(portal.value, 0)) + val (preBind, postBind) = bindExchange(statement, args, argsOrigin, redactionStrategy) val postExec: F[Completion] = flatExpect { case CommandComplete(c) => send(Sync) *> expect { case ReadyForQuery(_) => c } // https://github.com/tpolecat/skunk/issues/210 @@ -118,7 +136,7 @@ object BindExecute { exchange("bind+execute"){ (span: Span[F]) => for { pn <- preBind(span) - _ <- preExec(pn) + _ <- send(ExecuteMessage(pn.value, 0)) _ <- send(Flush) _ <- postBind c <- postExec @@ -128,8 +146,41 @@ object BindExecute { } } { portal => Close[F].apply(portal.id)} - } } + def query[A, B]( + statement: Protocol.PreparedQuery[F, A, B], + args: A, + argsOrigin: Origin, + redactionStrategy: RedactionStrategy, + initialSize: Int + ): Resource[F, Protocol.QueryPortal[F, A, B]] = { + val (preBind, postBind) = bindExchange(statement, args, argsOrigin, redactionStrategy) + Resource.eval(Deferred[F, Unit]).flatMap { prefetch => + Resource.make { + exchange("bind+execute"){ (span: Span[F]) => + for { + pn <- preBind(span) + _ <- span.addAttributes( + Attribute("max-rows", initialSize.toLong), + Attribute("portal-id", pn.value) + ) + _ <- send(ExecuteMessage(pn.value, initialSize)) + _ <- send(Flush) + rs <- unroll(statement, args, argsOrigin, redactionStrategy) + _ <- postBind + } yield new Protocol.QueryPortal[F, A, B](pn, statement, args, argsOrigin, redactionStrategy) { + def execute(maxRows: Int): F[List[B] ~ Boolean] = + prefetch.tryGet.flatMap { + case None => rs.pure <* prefetch.complete(()) + case Some(()) => Execute[F].apply(this, maxRows) + } + } + } + } { portal => Close[F].apply(portal.id)} + } + } + } + } diff --git a/modules/core/shared/src/main/scala/net/protocol/Execute.scala b/modules/core/shared/src/main/scala/net/protocol/Execute.scala index a3613d22..c4c67c76 100644 --- a/modules/core/shared/src/main/scala/net/protocol/Execute.scala +++ b/modules/core/shared/src/main/scala/net/protocol/Execute.scala @@ -9,13 +9,12 @@ import cats.MonadError import skunk.~ import skunk.net.{ Protocol, MessageSocket } import skunk.net.message.{ Execute => ExecuteMessage, _ } -import skunk.util.Typer import org.typelevel.otel4s.Attribute import org.typelevel.otel4s.trace.Span import org.typelevel.otel4s.trace.Tracer trait Execute[F[_]] { - def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int, ty: Typer): F[List[B] ~ Boolean] + def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int): F[List[B] ~ Boolean] } object Execute { @@ -25,7 +24,7 @@ object Execute { ): Execute[F] = new Unroll[F] with Execute[F] { - override def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int, ty: Typer): F[List[B] ~ Boolean] = + override def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int): F[List[B] ~ Boolean] = exchange("execute") { (span: Span[F]) => for { _ <- span.addAttributes( diff --git a/modules/core/shared/src/main/scala/net/protocol/Prepare.scala b/modules/core/shared/src/main/scala/net/protocol/Prepare.scala index 5206008a..3f6fb197 100644 --- a/modules/core/shared/src/main/scala/net/protocol/Prepare.scala +++ b/modules/core/shared/src/main/scala/net/protocol/Prepare.scala @@ -5,6 +5,7 @@ package skunk.net.protocol import cats.effect.Resource +import cats.effect.Concurrent import cats.syntax.functor._ import skunk.~ import skunk.RedactionStrategy @@ -13,7 +14,6 @@ import skunk.net.Protocol.{ PreparedCommand, PreparedQuery, CommandPortal, Query import skunk.util.{ Origin, Namer } import skunk.util.Typer import org.typelevel.otel4s.trace.Tracer -import cats.effect.kernel.MonadCancel trait Prepare[F[_]] { def apply[A](command: skunk.Command[A], ty: Typer): F[PreparedCommand[F, A]] @@ -23,7 +23,7 @@ trait Prepare[F[_]] { object Prepare { def apply[F[_]: Exchange: MessageSocket: Namer: Tracer](describeCache: Describe.Cache[F], parseCache: Parse.Cache[F], redactionStrategy: RedactionStrategy)( - implicit ev: MonadCancel[F, Throwable] + implicit ev: Concurrent[F] ): Prepare[F] = new Prepare[F] { @@ -38,13 +38,18 @@ object Prepare { override def apply[A, B](query: skunk.Query[A, B], ty: Typer): F[PreparedQuery[F, A, B]] = ParseDescribe[F](describeCache, parseCache).apply(query, ty).map { case (id, rd) => new PreparedQuery[F, A, B](id, query, rd) { pq => - def bind(args: A, origin: Origin): Resource[F, QueryPortal[F, A, B]] = - Bind[F].apply(this, args, origin, redactionStrategy).map { - new QueryPortal[F, A, B](_, pq, args, origin, redactionStrategy) { - def execute(maxRows: Int): F[List[B] ~ Boolean] = - Execute[F].apply(this, maxRows, ty) - } - } + def bind(args: A, origin: Origin, maxRows: Option[Int]): Resource[F, QueryPortal[F, A, B]] = + maxRows match { + case None => + Bind[F].apply(this, args, origin, redactionStrategy).map { + new QueryPortal[F, A, B](_, pq, args, origin, redactionStrategy) { + def execute(maxRows: Int): F[List[B] ~ Boolean] = + Execute[F].apply(this, maxRows) + } + } + case Some(initialSize) => + BindExecute[F].query(this, args, origin, redactionStrategy, initialSize) + } } } diff --git a/modules/core/shared/src/main/scala/net/protocol/Unroll.scala b/modules/core/shared/src/main/scala/net/protocol/Unroll.scala index f7da42b1..c8d26766 100644 --- a/modules/core/shared/src/main/scala/net/protocol/Unroll.scala +++ b/modules/core/shared/src/main/scala/net/protocol/Unroll.scala @@ -11,6 +11,7 @@ import skunk.exception.DecodeException import skunk.net.message._ import skunk.net.MessageSocket import skunk.net.Protocol.QueryPortal +import skunk.net.Protocol.PreparedQuery import skunk.util.Origin import skunk.data.TypedRowDescription import org.typelevel.otel4s.Attribute @@ -41,6 +42,24 @@ private[protocol] class Unroll[F[_]: MessageSocket: Tracer]( decoder = portal.preparedQuery.query.decoder, redactionStrategy = portal.redactionStrategy ) + + def unroll[A, B]( + preparedQuery: PreparedQuery[F, A, B], + arguments: A, + argumentsOrigin: Origin, + redactionStrategy: RedactionStrategy + ): F[(List[B], Boolean)] = + unroll( + extended = true, + sql = preparedQuery.query.sql, + sqlOrigin = preparedQuery.query.origin, + args = arguments, + argsOrigin = Some(argumentsOrigin), + encoder = preparedQuery.query.encoder, + rowDescription = preparedQuery.rowDescription, + decoder = preparedQuery.query.decoder, + redactionStrategy = redactionStrategy + ) // When we do a quick query there's no statement to hang onto all the error-reporting context // so we have to pass everything in manually.