Skip to content

Commit

Permalink
TCP Pipelining: PerparedQuery Bind+Exec prefetch
Browse files Browse the repository at this point in the history
  • Loading branch information
vbergeron-ledger committed Mar 21, 2024
1 parent fa9f282 commit 02f9cbb
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 28 deletions.
6 changes: 3 additions & 3 deletions modules/core/shared/src/main/scala/PreparedQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ object PreparedQuery {
new PreparedQuery[F, A, B] {

override def cursor(args: A)(implicit or: Origin): Resource[F, Cursor[F, B]] =
proto.bind(args, or).map { p =>
proto.bind(args, or, None).map { p =>
new Cursor[F, B] {
override def fetch(maxRows: Int): F[(List[B], Boolean)] =
p.execute(maxRows)
}
}

override def stream(args: A, chunkSize: Int)(implicit or: Origin): Stream[F, B] =
Stream.resource(proto.bind(args, or)).flatMap { cursor =>
Stream.resource(proto.bind(args, or, Some(chunkSize))).flatMap { cursor =>
def chunks: Stream[F, B] =
Stream.eval(cursor.execute(chunkSize)).flatMap { case (bs, more) =>
val s = Stream.chunk(Chunk.from(bs))
Expand All @@ -84,7 +84,7 @@ object PreparedQuery {
// We have a few operations that only want the first row. In order to do this AND
// know if there are more we need to ask for 2 rows.
private def fetch2(args: A)(implicit or: Origin): F[(List[B], Boolean)] =
cursor(args).use(_.fetch(2))
proto.bind(args, or, Some(2)).use(_.execute(2))

override def option(args: A)(implicit or: Origin): F[Option[B]] =
fetch2(args).flatMap { case (bs, _) =>
Expand Down
2 changes: 1 addition & 1 deletion modules/core/shared/src/main/scala/net/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ object Protocol {
val rowDescription: TypedRowDescription
) extends PreparedStatement[F, A] {
def statement: Statement[A] = query
def bind(args: A, argsOrigin: Origin): Resource[F, QueryPortal[F, A, B]]
def bind(args: A, argsOrigin: Origin, maxRows: Option[Int]): Resource[F, QueryPortal[F, A, B]]
}

/**
Expand Down
75 changes: 63 additions & 12 deletions modules/core/shared/src/main/scala/net/protocol/BindExecute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package skunk.net.protocol

import cats.effect.Resource
import cats.syntax.all._
import cats.MonadError
import cats.effect.Concurrent
import skunk.~
import skunk.exception._
import skunk.net.message.{ Bind => BindMessage, Close => _, _ }
import skunk.net.message.{ Bind => BindMessage, Execute => ExecuteMessage, Close => _, _ }
import skunk.net.MessageSocket
import skunk.net.Protocol.PortalId
import skunk.util.{ Origin, Namer }
Expand All @@ -18,6 +19,7 @@ import skunk.RedactionStrategy
import skunk.net.Protocol
import skunk.data.Completion
import skunk.net.protocol.exchange
import cats.effect.kernel.Deferred

trait BindExecute[F[_]] {

Expand All @@ -28,21 +30,28 @@ trait BindExecute[F[_]] {
redactionStrategy: RedactionStrategy
): Resource[F, Protocol.CommandPortal[F, A]]

def query[A, B](
statement: Protocol.PreparedQuery[F, A, B],
args: A,
argsOrigin: Origin,
redactionStrategy: RedactionStrategy,
initialSize: Int
): Resource[F, Protocol.QueryPortal[F, A, B]]
}

object BindExecute {

def apply[F[_]: Exchange: MessageSocket: Namer: Tracer](
implicit ev: MonadError[F, Throwable]
implicit ev: Concurrent[F]
): BindExecute[F] =
new BindExecute[F] {

def command[A](
statement: Protocol.PreparedCommand[F, A],
args: A,
new Unroll[F] with BindExecute[F] {
def bindExchange[A](
statement: Protocol.PreparedStatement[F, A],
args: A,
argsOrigin: Origin,
redactionStrategy: RedactionStrategy
): Resource[F, Protocol.CommandPortal[F, A]] = {
):(Span[F] => F[PortalId], F[Unit]) = {
val ea = statement.statement.encoder.encode(args) // encoded args

def preBind(span: Span[F]): F[PortalId] = for {
Expand Down Expand Up @@ -71,8 +80,17 @@ object BindExecute {
)
} yield a
}
(preBind, postBind)
}

def command[A](
statement: Protocol.PreparedCommand[F, A],
args: A,
argsOrigin: Origin,
redactionStrategy: RedactionStrategy
): Resource[F, Protocol.CommandPortal[F, A]] = {

def preExec(portal: PortalId):F[Unit] = send(Execute(portal.value, 0))
val (preBind, postBind) = bindExchange(statement, args, argsOrigin, redactionStrategy)

val postExec: F[Completion] = flatExpect {
case CommandComplete(c) => send(Sync) *> expect { case ReadyForQuery(_) => c } // https://github.com/tpolecat/skunk/issues/210
Expand Down Expand Up @@ -118,7 +136,7 @@ object BindExecute {
exchange("bind+execute"){ (span: Span[F]) =>
for {
pn <- preBind(span)
_ <- preExec(pn)
_ <- send(ExecuteMessage(pn.value, 0))
_ <- send(Flush)
_ <- postBind
c <- postExec
Expand All @@ -128,8 +146,41 @@ object BindExecute {
}
} { portal => Close[F].apply(portal.id)}

}
}

def query[A, B](
statement: Protocol.PreparedQuery[F, A, B],
args: A,
argsOrigin: Origin,
redactionStrategy: RedactionStrategy,
initialSize: Int
): Resource[F, Protocol.QueryPortal[F, A, B]] = {
val (preBind, postBind) = bindExchange(statement, args, argsOrigin, redactionStrategy)
Resource.eval(Deferred[F, Unit]).flatMap { prefetch =>
Resource.make {
exchange("bind+execute"){ (span: Span[F]) =>
for {
pn <- preBind(span)
_ <- span.addAttributes(
Attribute("max-rows", initialSize.toLong),
Attribute("portal-id", pn.value)
)
_ <- send(ExecuteMessage(pn.value, initialSize))
_ <- send(Flush)
rs <- unroll(statement, args, argsOrigin, redactionStrategy)
_ <- postBind
} yield new Protocol.QueryPortal[F, A, B](pn, statement, args, argsOrigin, redactionStrategy) {
def execute(maxRows: Int): F[List[B] ~ Boolean] =
prefetch.tryGet.flatMap {
case None => rs.pure <* prefetch.complete(())
case Some(()) => Execute[F].apply(this, maxRows)
}
}
}
} { portal => Close[F].apply(portal.id)}
}
}
}


}
5 changes: 2 additions & 3 deletions modules/core/shared/src/main/scala/net/protocol/Execute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import cats.MonadError
import skunk.~
import skunk.net.{ Protocol, MessageSocket }
import skunk.net.message.{ Execute => ExecuteMessage, _ }
import skunk.util.Typer
import org.typelevel.otel4s.Attribute
import org.typelevel.otel4s.trace.Span
import org.typelevel.otel4s.trace.Tracer

trait Execute[F[_]] {
def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int, ty: Typer): F[List[B] ~ Boolean]
def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int): F[List[B] ~ Boolean]
}

object Execute {
Expand All @@ -25,7 +24,7 @@ object Execute {
): Execute[F] =
new Unroll[F] with Execute[F] {

override def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int, ty: Typer): F[List[B] ~ Boolean] =
override def apply[A, B](portal: Protocol.QueryPortal[F, A, B], maxRows: Int): F[List[B] ~ Boolean] =
exchange("execute") { (span: Span[F]) =>
for {
_ <- span.addAttributes(
Expand Down
23 changes: 14 additions & 9 deletions modules/core/shared/src/main/scala/net/protocol/Prepare.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package skunk.net.protocol

import cats.effect.Resource
import cats.effect.Concurrent
import cats.syntax.functor._
import skunk.~
import skunk.RedactionStrategy
Expand All @@ -13,7 +14,6 @@ import skunk.net.Protocol.{ PreparedCommand, PreparedQuery, CommandPortal, Query
import skunk.util.{ Origin, Namer }
import skunk.util.Typer
import org.typelevel.otel4s.trace.Tracer
import cats.effect.kernel.MonadCancel

trait Prepare[F[_]] {
def apply[A](command: skunk.Command[A], ty: Typer): F[PreparedCommand[F, A]]
Expand All @@ -23,7 +23,7 @@ trait Prepare[F[_]] {
object Prepare {

def apply[F[_]: Exchange: MessageSocket: Namer: Tracer](describeCache: Describe.Cache[F], parseCache: Parse.Cache[F], redactionStrategy: RedactionStrategy)(
implicit ev: MonadCancel[F, Throwable]
implicit ev: Concurrent[F]
): Prepare[F] =
new Prepare[F] {

Expand All @@ -38,13 +38,18 @@ object Prepare {
override def apply[A, B](query: skunk.Query[A, B], ty: Typer): F[PreparedQuery[F, A, B]] =
ParseDescribe[F](describeCache, parseCache).apply(query, ty).map { case (id, rd) =>
new PreparedQuery[F, A, B](id, query, rd) { pq =>
def bind(args: A, origin: Origin): Resource[F, QueryPortal[F, A, B]] =
Bind[F].apply(this, args, origin, redactionStrategy).map {
new QueryPortal[F, A, B](_, pq, args, origin, redactionStrategy) {
def execute(maxRows: Int): F[List[B] ~ Boolean] =
Execute[F].apply(this, maxRows, ty)
}
}
def bind(args: A, origin: Origin, maxRows: Option[Int]): Resource[F, QueryPortal[F, A, B]] =
maxRows match {
case None =>
Bind[F].apply(this, args, origin, redactionStrategy).map {
new QueryPortal[F, A, B](_, pq, args, origin, redactionStrategy) {
def execute(maxRows: Int): F[List[B] ~ Boolean] =
Execute[F].apply(this, maxRows)
}
}
case Some(initialSize) =>
BindExecute[F].query(this, args, origin, redactionStrategy, initialSize)
}
}
}

Expand Down
19 changes: 19 additions & 0 deletions modules/core/shared/src/main/scala/net/protocol/Unroll.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import skunk.exception.DecodeException
import skunk.net.message._
import skunk.net.MessageSocket
import skunk.net.Protocol.QueryPortal
import skunk.net.Protocol.PreparedQuery
import skunk.util.Origin
import skunk.data.TypedRowDescription
import org.typelevel.otel4s.Attribute
Expand Down Expand Up @@ -41,6 +42,24 @@ private[protocol] class Unroll[F[_]: MessageSocket: Tracer](
decoder = portal.preparedQuery.query.decoder,
redactionStrategy = portal.redactionStrategy
)

def unroll[A, B](
preparedQuery: PreparedQuery[F, A, B],
arguments: A,
argumentsOrigin: Origin,
redactionStrategy: RedactionStrategy
): F[(List[B], Boolean)] =
unroll(
extended = true,
sql = preparedQuery.query.sql,
sqlOrigin = preparedQuery.query.origin,
args = arguments,
argsOrigin = Some(argumentsOrigin),
encoder = preparedQuery.query.encoder,
rowDescription = preparedQuery.rowDescription,
decoder = preparedQuery.query.decoder,
redactionStrategy = redactionStrategy
)

// When we do a quick query there's no statement to hang onto all the error-reporting context
// so we have to pass everything in manually.
Expand Down

0 comments on commit 02f9cbb

Please sign in to comment.