Skip to content

Commit

Permalink
all: add cancellation causes to timeouts
Browse files Browse the repository at this point in the history
The new WithTimeoutCause and WithDeadlineCause functions
allow us to decorate contexts with metadata surrounding
a specific timeout or deadline. Combined with the automatic
discovery of the context cause in the errors and event
packages, we should get much more information about
context cancellations.
  • Loading branch information
johanbrandhorst committed Sep 20, 2024
1 parent 55a1d77 commit 4c9e1e8
Show file tree
Hide file tree
Showing 20 changed files with 146 additions and 32 deletions.
10 changes: 9 additions & 1 deletion api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ func (c *Client) NewRequest(ctx context.Context, method, requestPath string, bod
// Do takes a properly configured request and applies client configuration to
// it, returning the response.
func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error) {
const op = "api.(Client).Do"
opts := getOpts(opt...)
c.modifyLock.RLock()
limiter := c.config.Limiter
Expand Down Expand Up @@ -772,7 +773,11 @@ func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error)

if timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
ctx, cancel = context.WithTimeoutCause(
ctx,
timeout,
fmt.Errorf("%s: client configured timeout exceeded", op),
)
// This dance is just to ignore vet warnings; we don't want to cancel
// this as it will make reading the response body impossible
_ = cancel
Expand Down Expand Up @@ -841,6 +846,9 @@ func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error)
}

if err != nil {
if ctxCause := context.Cause(ctx); ctxCause != nil {
return nil, fmt.Errorf("%w (%w)", err, ctxCause)
}
if strings.Contains(err.Error(), "tls: oversized") {
err = fmt.Errorf(
"%w\n\n"+
Expand Down
13 changes: 11 additions & 2 deletions api/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type ClientProxy struct {
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, error) {
const op = "proxy.New"
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("could not parse options: %w", err)
Expand Down Expand Up @@ -142,7 +143,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
// We don't _rely_ on client-side timeout verification, but this prevents us
// from seeming to be ready for a connection that will immediately fail when we
// try to actually make it.
p.ctx, p.cancel = context.WithDeadline(ctx, p.expiration)
p.ctx, p.cancel = context.WithDeadlineCause(ctx, p.expiration, fmt.Errorf("%s: session expiration exceeded", op))

transport := cleanhttp.DefaultTransport()
transport.DisableKeepAlives = false
Expand Down Expand Up @@ -173,6 +174,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func (p *ClientProxy) Start(opt ...Option) (retErr error) {
const op = "proxy.(ClientProxy).Start"
opts, err := getOpts(opt...)
if err != nil {
return fmt.Errorf("could not parse options: %w", err)
Expand Down Expand Up @@ -350,9 +352,16 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
return nil
}

ctx, cancel := context.WithTimeout(context.Background(), opts.withSessionTeardownTimeout)
ctx, cancel := context.WithTimeoutCause(
context.Background(),
opts.withSessionTeardownTimeout,
fmt.Errorf("%s: session teardown timeout exceeded", op),
)
defer cancel()
if err := p.sendSessionTeardown(ctx); err != nil {
if ctxCause := ctx.Err(); ctxCause != nil {
return fmt.Errorf("error sending session teardown request to worker: %w (%w)", err, ctxCause)
}
return fmt.Errorf("error sending session teardown request to worker: %w", err)
}

Expand Down
7 changes: 6 additions & 1 deletion internal/clientcache/cmd/cache/wrapper_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func silentUi() *cli.BasicUi {
// addTokenToCache runs AddTokenCommand with the token used in, or retrieved by
// the wrapped command.
func addTokenToCache(ctx context.Context, baseCmd *base.Command, token string) bool {
const op = "cache.addTokenToCache"
com := AddTokenCommand{Command: base.NewCommand(baseCmd.UI)}
client, err := baseCmd.Client()
if err != nil {
Expand All @@ -95,7 +96,11 @@ func addTokenToCache(ctx context.Context, baseCmd *base.Command, token string) b

// Since the daemon might have just started, we need to wait until it can
// respond to our requests
waitCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
waitCtx, cancel := context.WithTimeoutCause(
ctx,
3*time.Second,
fmt.Errorf("%s: daemon startup timeout exceeded", op),
)
defer cancel()
if err := waitForDaemon(waitCtx); err != nil {
// TODO: Print the result of this out into a log in the dot directory
Expand Down
6 changes: 5 additions & 1 deletion internal/clientcache/internal/cache/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
const op = "cache.(RefreshService).RefreshForSearch"
if r.maxSearchRefreshTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.maxSearchRefreshTimeout)
ctx, cancel = context.WithTimeoutCause(
ctx,
r.maxSearchRefreshTimeout,
fmt.Errorf("%s: search refresh timeout exceeded", op),
)
defer cancel()
}
at, err := r.repo.LookupToken(ctx, authTokenid)
Expand Down
8 changes: 6 additions & 2 deletions internal/clientcache/internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,15 @@ func (s *CacheServer) Shutdown(ctx context.Context) error {
if s.conf.ContextCancel != nil {
s.conf.ContextCancel()
}
srvCtx, srvCancel := context.WithTimeout(context.Background(), 5*time.Second)
srvCtx, srvCancel := context.WithTimeoutCause(
context.Background(),
5*time.Second,
fmt.Errorf("%s: http server shutdown timeout exceeded", op),
)
defer srvCancel()
err := s.httpSrv.Shutdown(srvCtx)
if err != nil {
shutdownErr = fmt.Errorf("error shutting down server: %w", err)
shutdownErr = errors.Wrap(ctx, err, op, errors.WithMsg("error shutting down server"), errors.WithoutEvent())
return
}
s.tickerWg.Wait()
Expand Down
16 changes: 14 additions & 2 deletions internal/cmd/commands/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func (c *Command) AutocompleteFlags() complete.Flags {
}

func (c *Command) Run(args []string) int {
const op = "server.(Command).Run"
c.CombineLogs = c.flagCombineLogs

defer func() {
Expand Down Expand Up @@ -479,12 +480,23 @@ func (c *Command) Run(args []string) int {
// 1 second is chosen so the shutdown is still responsive; this is a mostly
// non-critical step since the lock should be released when the session with the
// database is closed.
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
ctx, cancel := context.WithTimeoutCause(
context.Background(),
1*time.Second,
fmt.Errorf("%s: database lock release timeout exceeded", op),
)
defer cancel()

err := c.schemaManager.Close(ctx)
if err != nil {
c.UI.Error(fmt.Errorf("Unable to release shared lock to the database: %w", err).Error())
// Use errors.Wrap with ctx to capture the context cause if there is one
c.UI.Error(errors.Wrap(
ctx,
err,
op,
errors.WithMsg("Unable to release shared lock to the database"),
errors.WithoutEvent(),
).Error())
}
}()

Expand Down
13 changes: 9 additions & 4 deletions internal/cmd/ops/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package ops

import (
"context"
"errors"
stderrors "errors"
"fmt"
"net"
"net/http"
Expand All @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/hashicorp/boundary/internal/daemon/worker"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
Expand Down Expand Up @@ -94,18 +95,22 @@ func (s *Server) Shutdown() error {
return fmt.Errorf("%s: missing bundle, listener or its fields", op)
}

ctx, cancel := context.WithTimeout(context.Background(), b.ln.Config.MaxRequestDuration)
ctx, cancel := context.WithTimeoutCause(
context.Background(),
b.ln.Config.MaxRequestDuration,
fmt.Errorf("%s: max request duration exceeded", op),
)
defer cancel()

err := b.ln.HTTPServer.Shutdown(ctx)
if err != nil {
errors.Join(closeErrors, fmt.Errorf("%s: failed to shutdown http server: %w", op, err))
closeErrors = stderrors.Join(closeErrors, errors.Wrap(ctx, err, op, errors.WithMsg("failed to shutdown http server")))
}

err = b.ln.OpsListener.Close()
err = listenerCloseErrorCheck(b.ln.Config.Type, err)
if err != nil {
errors.Join(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err))
closeErrors = stderrors.Join(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err))
}
}

Expand Down
6 changes: 5 additions & 1 deletion internal/daemon/controller/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,11 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp
w.Header().Set("Cache-Control", "no-store")

// Start with the request context and our timeout
ctx, cancelFunc := context.WithTimeout(r.Context(), maxRequestDuration)
ctx, cancelFunc := context.WithTimeoutCause(
r.Context(),
maxRequestDuration,
fmt.Errorf("%s: max request duration exceeded", op),
)
defer cancelFunc()

// Add a size limiter if desired
Expand Down
8 changes: 6 additions & 2 deletions internal/daemon/controller/handlers/targets/target_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
if retErr != nil {
// Delete created session in case of errors.
// Use new context for deletion in case error is because of context cancellation.
deleteCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
deleteCtx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, stderrors.New("session deletion timeout exceeded"))
defer cancel()
_, err := sessionRepo.DeleteSession(deleteCtx, sess.PublicId)
retErr = stderrors.Join(retErr, err)
Expand Down Expand Up @@ -1095,7 +1095,11 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
if retErr != nil {
// Revoke issued credentials in case of errors.
// Use new context for deletion in case error is because of context cancellation.
deleteCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
deleteCtx, cancel := context.WithTimeoutCause(
context.Background(),
time.Minute,
fmt.Errorf("%s: credential revocation timeout exceeded", op),
)
defer cancel()
err := credRepo.Revoke(deleteCtx, sess.PublicId)
retErr = stderrors.Join(retErr, err)
Expand Down
6 changes: 5 additions & 1 deletion internal/daemon/controller/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ func eventsResponseInterceptor(
func requestMaxDurationInterceptor(_ context.Context, maxRequestDuration time.Duration) grpc.UnaryServerInterceptor {
const op = "controller.requestMaxDurationInterceptor"
return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
withTimeout, cancel := context.WithTimeout(interceptorCtx, maxRequestDuration)
withTimeout, cancel := context.WithTimeoutCause(
interceptorCtx,
maxRequestDuration,
fmt.Errorf("%s: max request duration exceeded", op),
)
defer cancel()
return handler(withTimeout, req)
}
Expand Down
9 changes: 7 additions & 2 deletions internal/daemon/controller/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,20 @@ func (c *Controller) stopClusterGrpcServerAndListener() error {
}

func (c *Controller) stopHttpServersAndListeners() error {
const op = "controller.Controller.stopHttpServersAndListeners"
var closeErrors error
for i := range c.apiListeners {
ln := c.apiListeners[i]
if ln.HTTPServer == nil {
continue
}

ctx, cancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration)
ln.HTTPServer.Shutdown(ctx)
ctx, cancel := context.WithTimeoutCause(
c.baseContext,
ln.Config.MaxRequestDuration,
fmt.Errorf("%s: max request duration exceeded", op),
)
_ = ln.HTTPServer.Shutdown(ctx)
cancel()

err := ln.ApiListener.Close() // The HTTP Shutdown call should close this, but just in case.
Expand Down
6 changes: 5 additions & 1 deletion internal/daemon/controller/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,11 @@ func (tc *TestController) WaitForNextWorkerStatusUpdate(workerStatusName string)
ctx := context.TODO()
event.WriteSysEvent(ctx, op, "waiting for next status report from worker", "worker", workerStatusName)
waitStatusStart := time.Now()
ctx, cancel := context.WithTimeout(tc.ctx, time.Duration(tc.c.workerStatusGracePeriod.Load()))
ctx, cancel := context.WithTimeoutCause(
tc.ctx,
time.Duration(tc.c.workerStatusGracePeriod.Load()),
fmt.Errorf("%s: worker status grace period exceeded", op),
)
defer cancel()
var err error
for {
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/worker/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
// Later calls will cause this to noop if they return a different status
defer conn.Close(websocket.StatusNormalClosure, "done")

connCtx, connCancel := context.WithDeadline(ctx, sess.GetExpiration())
connCtx, connCancel := context.WithDeadlineCause(ctx, sess.GetExpiration(), fmt.Errorf("%s: session expiration exceeded", op))
defer connCancel()

var handshake proxy.ClientHandshake
Expand Down
7 changes: 6 additions & 1 deletion internal/daemon/worker/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ func (w *Worker) stopServersAndListeners() error {
}

func (w *Worker) stopHttpServer() error {
const op = "worker.stopHttpServer"
if w.proxyListener == nil {
return nil
}
Expand All @@ -329,7 +330,11 @@ func (w *Worker) stopHttpServer() error {
return nil
}

ctx, cancel := context.WithTimeout(w.baseContext, w.proxyListener.Config.MaxRequestDuration)
ctx, cancel := context.WithTimeoutCause(
w.baseContext,
w.proxyListener.Config.MaxRequestDuration,
fmt.Errorf("%s: max request duration exceeded", op),
)
w.proxyListener.HTTPServer.Shutdown(ctx)
cancel()

Expand Down
6 changes: 5 additions & 1 deletion internal/daemon/worker/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,11 @@ func closeConnections(ctx context.Context, sessClient pbs.SessionServiceClient,
// bit of formalization in terms of how we handle timeouts. For now, this
// just ensures consistency with the same status call in that it times out
// within an adequate period of time.
closeConnCtx, closeConnCancel := context.WithTimeout(ctx, time.Duration(CloseCallTimeout.Load()))
closeConnCtx, closeConnCancel := context.WithTimeoutCause(
ctx,
time.Duration(CloseCallTimeout.Load()),
fmt.Errorf("%s: close call timeout exceeded", op),
)
defer closeConnCancel()
response, err := closeConnection(closeConnCtx, sessClient, makeCloseConnectionRequest(closeInfo))
if err != nil {
Expand Down
18 changes: 15 additions & 3 deletions internal/daemon/worker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ func (w *Worker) LastStatusSuccess() *LastStatusInformation {
func (w *Worker) WaitForNextSuccessfulStatusUpdate() error {
const op = "worker.(Worker).WaitForNextSuccessfulStatusUpdate"
waitStatusStart := time.Now()
ctx, cancel := context.WithTimeout(w.baseContext, time.Duration(w.successfulStatusGracePeriod.Load()))
ctx, cancel := context.WithTimeoutCause(
w.baseContext,
time.Duration(w.successfulStatusGracePeriod.Load()),
fmt.Errorf("%s: status grace period exceeded", op),
)
defer cancel()
event.WriteSysEvent(ctx, op, "waiting for next status report to controller")
for {
Expand Down Expand Up @@ -189,7 +193,11 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess
if w.updateTags.Load() {
tags = w.tags.Load().([]*pb.TagPair)
}
statusCtx, statusCancel := context.WithTimeout(cancelCtx, time.Duration(w.statusCallTimeoutDuration.Load()))
statusCtx, statusCancel := context.WithTimeoutCause(
cancelCtx,
time.Duration(w.statusCallTimeoutDuration.Load()),
fmt.Errorf("%s: status call timeout exceeded", op),
)
defer statusCancel()

keyId := w.WorkerAuthCurrentKeyId.Load()
Expand Down Expand Up @@ -312,7 +320,11 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess
}
} else if checkHCPBUpstreams != nil && checkHCPBUpstreams(w) {
// This is a worker that is one hop away from managed workers, so attempt to get that list
hcpbWorkersCtx, hcpbWorkersCancel := context.WithTimeout(cancelCtx, time.Duration(w.statusCallTimeoutDuration.Load()))
hcpbWorkersCtx, hcpbWorkersCancel := context.WithTimeoutCause(
cancelCtx,
time.Duration(w.statusCallTimeoutDuration.Load()),
fmt.Errorf("%s: status call timeout exceeded", op),
)
defer hcpbWorkersCancel()
workersResp, err := client.ListHcpbWorkers(hcpbWorkersCtx, &pbs.ListHcpbWorkersRequest{})
if err != nil {
Expand Down
12 changes: 10 additions & 2 deletions internal/daemon/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,11 @@ func (w *Worker) Shutdown() error {
// at our default liveness value, which is also our default status grace
// period timeout
waitStatusStart := time.Now()
nextStatusCtx, nextStatusCancel := context.WithTimeout(w.baseContext, server.DefaultLiveness)
nextStatusCtx, nextStatusCancel := context.WithTimeoutCause(
w.baseContext,
server.DefaultLiveness,
fmt.Errorf("%s: liveness timeout exceeded", op),
)
defer nextStatusCancel()
for {
if err := nextStatusCtx.Err(); err != nil {
Expand Down Expand Up @@ -813,7 +817,11 @@ func (w *Worker) getSessionTls(sessionManager session.Manager) func(hello *tls.C
return nil, fmt.Errorf("no last status information found at session acceptance time")
}

timeoutContext, cancel := context.WithTimeout(w.baseContext, session.ValidateSessionTimeout)
timeoutContext, cancel := context.WithTimeoutCause(
w.baseContext,
session.ValidateSessionTimeout,
fmt.Errorf("%s: session validation timeout exceeded", op),
)
defer cancel()
sess, err := sessionManager.LoadLocalSession(timeoutContext, sessionId, lastSuccess.GetWorkerId())
if err != nil {
Expand Down
Loading

0 comments on commit 4c9e1e8

Please sign in to comment.