Add context timeout/deadline causes #5043

Draft · wants to merge 4 commits into base: main
10 changes: 9 additions & 1 deletion api/client.go
@@ -736,6 +736,7 @@ func (c *Client) NewRequest(ctx context.Context, method, requestPath string, bod
// Do takes a properly configured request and applies client configuration to
// it, returning the response.
func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error) {
+ const op = "api.(Client).Do"
opts := getOpts(opt...)
c.modifyLock.RLock()
limiter := c.config.Limiter
@@ -772,7 +773,11 @@ func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error)

if timeout != 0 {
var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, timeout)
+ ctx, cancel = context.WithTimeoutCause(
+ ctx,
+ timeout,
+ fmt.Errorf("%s: client configured timeout exceeded", op),
+ )
// This dance is just to ignore vet warnings; we don't want to cancel
// this as it will make reading the response body impossible
_ = cancel
@@ -841,6 +846,9 @@ func (c *Client) Do(r *retryablehttp.Request, opt ...Option) (*Response, error)
}

if err != nil {
+ if ctxCause := context.Cause(ctx); ctxCause != nil {
+ return nil, fmt.Errorf("%w (%w)", err, ctxCause)
+ }
if strings.Contains(err.Error(), "tls: oversized") {
err = fmt.Errorf(
"%w\n\n"+
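
The hunks above follow one pattern: swap `context.WithTimeout` for `context.WithTimeoutCause` so a descriptive error is attached to the context, then surface it with `context.Cause` when a request fails. A minimal sketch of that behavior (not taken from this PR; the cause text is illustrative):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func main() {
	// Attach a descriptive cause to the timeout.
	cause := errors.New("client configured timeout exceeded")
	ctx, cancel := context.WithTimeoutCause(context.Background(), 10*time.Millisecond, cause)
	defer cancel()

	<-ctx.Done()

	fmt.Println(ctx.Err())          // context deadline exceeded
	fmt.Println(context.Cause(ctx)) // client configured timeout exceeded
}
```

`context.Cause` returns nil until the context is canceled, which is why the wrapping above only happens when a cause is present.
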
13 changes: 11 additions & 2 deletions api/proxy/proxy.go
@@ -80,6 +80,7 @@ type ClientProxy struct {
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, error) {
+ const op = "proxy.New"
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("could not parse options: %w", err)
@@ -142,7 +143,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
// We don't _rely_ on client-side timeout verification but this prevents us
// seeming to be ready for a connection that will immediately fail when we
// try to actually make it
- p.ctx, p.cancel = context.WithDeadline(ctx, p.expiration)
+ p.ctx, p.cancel = context.WithDeadlineCause(ctx, p.expiration, fmt.Errorf("%s: session expiration exceeded", op))

transport := cleanhttp.DefaultTransport()
transport.DisableKeepAlives = false
@@ -173,6 +174,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func (p *ClientProxy) Start(opt ...Option) (retErr error) {
+ const op = "proxy.(ClientProxy).Start"
opts, err := getOpts(opt...)
if err != nil {
return fmt.Errorf("could not parse options: %w", err)
@@ -350,9 +352,16 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
return nil
}

- ctx, cancel := context.WithTimeout(context.Background(), opts.withSessionTeardownTimeout)
+ ctx, cancel := context.WithTimeoutCause(
+ context.Background(),
+ opts.withSessionTeardownTimeout,
+ fmt.Errorf("%s: session teardown timeout exceeded", op),
+ )
defer cancel()
if err := p.sendSessionTeardown(ctx); err != nil {
+ if ctxCause := context.Cause(ctx); ctxCause != nil {
+ return fmt.Errorf("error sending session teardown request to worker: %w (%w)", err, ctxCause)
+ }
return fmt.Errorf("error sending session teardown request to worker: %w", err)
}

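
The teardown error above wraps both the original error and the context cause with two `%w` verbs, which Go 1.20+ supports; `errors.Is` can then match either one. A rough illustration with stand-in errors:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	// Stand-ins for the request error and for context.Cause(ctx).
	errDial := errors.New("dial tcp: i/o timeout")
	errCause := errors.New("session teardown timeout exceeded")

	// Two %w verbs keep both errors in the unwrap chain.
	err := fmt.Errorf("error sending session teardown request to worker: %w (%w)", errDial, errCause)

	fmt.Println(errors.Is(err, errDial))  // true
	fmt.Println(errors.Is(err, errCause)) // true
}
```
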
7 changes: 6 additions & 1 deletion internal/clientcache/cmd/cache/wrapper_register.go
@@ -80,6 +80,7 @@ func silentUi() *cli.BasicUi {
// addTokenToCache runs AddTokenCommand with the token used in, or retrieved by
// the wrapped command.
func addTokenToCache(ctx context.Context, baseCmd *base.Command, token string) bool {
+ const op = "cache.addTokenToCache"
com := AddTokenCommand{Command: base.NewCommand(baseCmd.UI)}
client, err := baseCmd.Client()
if err != nil {
@@ -95,7 +96,11 @@ func addTokenToCache(ctx context.Context, baseCmd *base.Command, token string) b

// Since the daemon might have just started, we need to wait until it can
// respond to our requests
- waitCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ waitCtx, cancel := context.WithTimeoutCause(
+ ctx,
+ 3*time.Second,
+ fmt.Errorf("%s: daemon startup timeout exceeded", op),
+ )
defer cancel()
if err := waitForDaemon(waitCtx); err != nil {
// TODO: Print the result of this out into a log in the dot directory
6 changes: 5 additions & 1 deletion internal/clientcache/internal/cache/refresh.go
@@ -154,7 +154,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
const op = "cache.(RefreshService).RefreshForSearch"
if r.maxSearchRefreshTimeout > 0 {
var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, r.maxSearchRefreshTimeout)
+ ctx, cancel = context.WithTimeoutCause(
+ ctx,
+ r.maxSearchRefreshTimeout,
+ fmt.Errorf("%s: search refresh timeout exceeded", op),
+ )
defer cancel()
}
at, err := r.repo.LookupToken(ctx, authTokenid)
8 changes: 6 additions & 2 deletions internal/clientcache/internal/daemon/server.go
@@ -139,11 +139,15 @@ func (s *CacheServer) Shutdown(ctx context.Context) error {
if s.conf.ContextCancel != nil {
s.conf.ContextCancel()
}
- srvCtx, srvCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ srvCtx, srvCancel := context.WithTimeoutCause(
+ context.Background(),
+ 5*time.Second,
+ fmt.Errorf("%s: http server shutdown timeout exceeded", op),
+ )
defer srvCancel()
err := s.httpSrv.Shutdown(srvCtx)
if err != nil {
- shutdownErr = fmt.Errorf("error shutting down server: %w", err)
+ shutdownErr = errors.Wrap(ctx, err, op, errors.WithMsg("error shutting down server"), errors.WithoutEvent())
return
}
s.tickerWg.Wait()
16 changes: 14 additions & 2 deletions internal/cmd/commands/server/server.go
@@ -172,6 +172,7 @@ func (c *Command) AutocompleteFlags() complete.Flags {
}

func (c *Command) Run(args []string) int {
+ const op = "server.(Command).Run"
c.CombineLogs = c.flagCombineLogs

defer func() {
@@ -479,12 +480,23 @@ func (c *Command) Run(args []string) int {
// 1 second is chosen so the shutdown is still responsive and this is a mostly
// non critical step since the lock should be released when the session with the
// database is closed.
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ ctx, cancel := context.WithTimeoutCause(
+ context.Background(),
+ 1*time.Second,
+ fmt.Errorf("%s: database lock release timeout exceeded", op),
+ )
defer cancel()

err := c.schemaManager.Close(ctx)
if err != nil {
- c.UI.Error(fmt.Errorf("Unable to release shared lock to the database: %w", err).Error())
+ // Use errors.E to capture the context cause if there is one
+ c.UI.Error(errors.Wrap(
+ ctx,
+ err,
+ op,
+ errors.WithMsg("Unable to release shared lock to the database"),
+ errors.WithoutEvent(),
+ ).Error())
}
}()

13 changes: 9 additions & 4 deletions internal/cmd/ops/server.go
@@ -7,7 +7,7 @@ package ops

import (
"context"
- "errors"
+ stderrors "errors"
"fmt"
"net"
"net/http"
@@ -17,6 +17,7 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/hashicorp/boundary/internal/daemon/worker"
+ "github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
@@ -94,18 +95,22 @@ func (s *Server) Shutdown() error {
return fmt.Errorf("%s: missing bundle, listener or its fields", op)
}

- ctx, cancel := context.WithTimeout(context.Background(), b.ln.Config.MaxRequestDuration)
+ ctx, cancel := context.WithTimeoutCause(
+ context.Background(),
+ b.ln.Config.MaxRequestDuration,
+ fmt.Errorf("%s: max request duration exceeded", op),
+ )
defer cancel()

err := b.ln.HTTPServer.Shutdown(ctx)
if err != nil {
- errors.Join(closeErrors, fmt.Errorf("%s: failed to shutdown http server: %w", op, err))
+ closeErrors = stderrors.Join(closeErrors, errors.Wrap(ctx, err, op, errors.WithMsg("failed to shutdown http server")))
}

err = b.ln.OpsListener.Close()
err = listenerCloseErrorCheck(b.ln.Config.Type, err)
if err != nil {
- errors.Join(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err))
+ closeErrors = stderrors.Join(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err))
}
}

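
The ops server change also fixes a separate bug: `errors.Join` returns the combined error rather than modifying its arguments, so the old calls that discarded the return value dropped the shutdown errors entirely. A small sketch of the corrected accumulation, using only the standard library:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	var closeErrors error

	// Join returns a new combined error; it must be assigned back.
	closeErrors = errors.Join(closeErrors, errors.New("failed to shutdown http server"))
	closeErrors = errors.Join(closeErrors, errors.New("failed to close listener mux"))

	// Each joined error prints on its own line.
	fmt.Println(closeErrors)
}
```
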
6 changes: 5 additions & 1 deletion internal/daemon/controller/handler.go
@@ -476,7 +476,11 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp
w.Header().Set("Cache-Control", "no-store")

// Start with the request context and our timeout
- ctx, cancelFunc := context.WithTimeout(r.Context(), maxRequestDuration)
+ ctx, cancelFunc := context.WithTimeoutCause(
+ r.Context(),
+ maxRequestDuration,
+ fmt.Errorf("%s: max request duration exceeded", op),
+ )
defer cancelFunc()

// Add a size limiter if desired
8 changes: 6 additions & 2 deletions internal/daemon/controller/handlers/targets/target_service.go
@@ -1064,7 +1064,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
if retErr != nil {
// Delete created session in case of errors.
// Use new context for deletion in case error is because of context cancellation.
- deleteCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ deleteCtx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, stderrors.New("session deletion timeout exceeded"))
defer cancel()
_, err := sessionRepo.DeleteSession(deleteCtx, sess.PublicId)
retErr = stderrors.Join(retErr, err)
@@ -1095,7 +1095,11 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
if retErr != nil {
// Revoke issued credentials in case of errors.
// Use new context for deletion in case error is because of context cancellation.
- deleteCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ deleteCtx, cancel := context.WithTimeoutCause(
+ context.Background(),
+ time.Minute,
+ fmt.Errorf("%s: credential revocation timeout exceeded", op),
+ )
defer cancel()
err := credRepo.Revoke(deleteCtx, sess.PublicId)
retErr = stderrors.Join(retErr, err)
6 changes: 5 additions & 1 deletion internal/daemon/controller/interceptor.go
@@ -495,7 +495,11 @@ func eventsResponseInterceptor(
func requestMaxDurationInterceptor(_ context.Context, maxRequestDuration time.Duration) grpc.UnaryServerInterceptor {
const op = "controller.requestMaxDurationInterceptor"
return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
- withTimeout, cancel := context.WithTimeout(interceptorCtx, maxRequestDuration)
+ withTimeout, cancel := context.WithTimeoutCause(
+ interceptorCtx,
+ maxRequestDuration,
+ fmt.Errorf("%s: max request duration exceeded", op),
+ )
defer cancel()
return handler(withTimeout, req)
}
9 changes: 7 additions & 2 deletions internal/daemon/controller/listeners.go
@@ -299,15 +299,20 @@ func (c *Controller) stopClusterGrpcServerAndListener() error {
}

func (c *Controller) stopHttpServersAndListeners() error {
+ const op = "controller.Controller.stopHttpServersAndListeners"
var closeErrors error
for i := range c.apiListeners {
ln := c.apiListeners[i]
if ln.HTTPServer == nil {
continue
}

- ctx, cancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration)
- ln.HTTPServer.Shutdown(ctx)
+ ctx, cancel := context.WithTimeoutCause(
+ c.baseContext,
+ ln.Config.MaxRequestDuration,
+ fmt.Errorf("%s: max request duration exceeded", op),
+ )
+ _ = ln.HTTPServer.Shutdown(ctx)
cancel()

err := ln.ApiListener.Close() // The HTTP Shutdown call should close this, but just in case.
6 changes: 5 additions & 1 deletion internal/daemon/controller/testing.go
@@ -888,7 +888,11 @@ func (tc *TestController) WaitForNextWorkerStatusUpdate(workerStatusName string)
ctx := context.TODO()
event.WriteSysEvent(ctx, op, "waiting for next status report from worker", "worker", workerStatusName)
waitStatusStart := time.Now()
- ctx, cancel := context.WithTimeout(tc.ctx, time.Duration(tc.c.workerStatusGracePeriod.Load()))
+ ctx, cancel := context.WithTimeoutCause(
+ tc.ctx,
+ time.Duration(tc.c.workerStatusGracePeriod.Load()),
+ fmt.Errorf("%s: worker status grace period exceeded", op),
+ )
defer cancel()
var err error
for {
2 changes: 1 addition & 1 deletion internal/daemon/worker/handler.go
@@ -136,7 +136,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
// Later calls will cause this to noop if they return a different status
defer conn.Close(websocket.StatusNormalClosure, "done")

- connCtx, connCancel := context.WithDeadline(ctx, sess.GetExpiration())
+ connCtx, connCancel := context.WithDeadlineCause(ctx, sess.GetExpiration(), fmt.Errorf("%s: session expiration exceeded", op))
defer connCancel()

var handshake proxy.ClientHandshake
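
`context.WithDeadlineCause` is the deadline-based counterpart used for session expirations here and in `api/proxy/proxy.go`; once the deadline passes, `context.Cause` reports the attached error. A brief sketch with an assumed expiration a few milliseconds out:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func main() {
	// Hypothetical session expiration used only for illustration.
	expiration := time.Now().Add(5 * time.Millisecond)
	cause := errors.New("session expiration exceeded")

	ctx, cancel := context.WithDeadlineCause(context.Background(), expiration, cause)
	defer cancel()

	<-ctx.Done()
	fmt.Println(context.Cause(ctx)) // session expiration exceeded
}
```
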
7 changes: 6 additions & 1 deletion internal/daemon/worker/listeners.go
@@ -321,6 +321,7 @@ func (w *Worker) stopServersAndListeners() error {
}

func (w *Worker) stopHttpServer() error {
+ const op = "worker.stopHttpServer"
if w.proxyListener == nil {
return nil
}
@@ -329,7 +330,11 @@ func (w *Worker) stopHttpServer() error {
return nil
}

- ctx, cancel := context.WithTimeout(w.baseContext, w.proxyListener.Config.MaxRequestDuration)
+ ctx, cancel := context.WithTimeoutCause(
+ w.baseContext,
+ w.proxyListener.Config.MaxRequestDuration,
+ fmt.Errorf("%s: max request duration exceeded", op),
+ )
w.proxyListener.HTTPServer.Shutdown(ctx)
cancel()

6 changes: 5 additions & 1 deletion internal/daemon/worker/session/session.go
@@ -503,7 +503,11 @@ func closeConnections(ctx context.Context, sessClient pbs.SessionServiceClient,
// bit of formalization in terms of how we handle timeouts. For now, this
// just ensures consistency with the same status call in that it times out
// within an adequate period of time.
- closeConnCtx, closeConnCancel := context.WithTimeout(ctx, time.Duration(CloseCallTimeout.Load()))
+ closeConnCtx, closeConnCancel := context.WithTimeoutCause(
+ ctx,
+ time.Duration(CloseCallTimeout.Load()),
+ fmt.Errorf("%s: close call timeout exceeded", op),
+ )
defer closeConnCancel()
response, err := closeConnection(closeConnCtx, sessClient, makeCloseConnectionRequest(closeInfo))
if err != nil {
18 changes: 15 additions & 3 deletions internal/daemon/worker/status.go
@@ -93,7 +93,11 @@ func (w *Worker) LastStatusSuccess() *LastStatusInformation {
func (w *Worker) WaitForNextSuccessfulStatusUpdate() error {
const op = "worker.(Worker).WaitForNextSuccessfulStatusUpdate"
waitStatusStart := time.Now()
- ctx, cancel := context.WithTimeout(w.baseContext, time.Duration(w.successfulStatusGracePeriod.Load()))
+ ctx, cancel := context.WithTimeoutCause(
+ w.baseContext,
+ time.Duration(w.successfulStatusGracePeriod.Load()),
+ fmt.Errorf("%s: status grace period exceeded", op),
+ )
defer cancel()
event.WriteSysEvent(ctx, op, "waiting for next status report to controller")
for {
@@ -189,7 +193,11 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess
if w.updateTags.Load() {
tags = w.tags.Load().([]*pb.TagPair)
}
- statusCtx, statusCancel := context.WithTimeout(cancelCtx, time.Duration(w.statusCallTimeoutDuration.Load()))
+ statusCtx, statusCancel := context.WithTimeoutCause(
+ cancelCtx,
+ time.Duration(w.statusCallTimeoutDuration.Load()),
+ fmt.Errorf("%s: status call timeout exceeded", op),
+ )
defer statusCancel()

keyId := w.WorkerAuthCurrentKeyId.Load()
@@ -312,7 +320,11 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess
}
} else if checkHCPBUpstreams != nil && checkHCPBUpstreams(w) {
// This is a worker that is one hop away from managed workers, so attempt to get that list
- hcpbWorkersCtx, hcpbWorkersCancel := context.WithTimeout(cancelCtx, time.Duration(w.statusCallTimeoutDuration.Load()))
+ hcpbWorkersCtx, hcpbWorkersCancel := context.WithTimeoutCause(
+ cancelCtx,
+ time.Duration(w.statusCallTimeoutDuration.Load()),
+ fmt.Errorf("%s: status call timeout exceeded", op),
+ )
defer hcpbWorkersCancel()
workersResp, err := client.ListHcpbWorkers(hcpbWorkersCtx, &pbs.ListHcpbWorkersRequest{})
if err != nil {
12 changes: 10 additions & 2 deletions internal/daemon/worker/worker.go
@@ -725,7 +725,11 @@ func (w *Worker) Shutdown() error {
// at our default liveness value, which is also our default status grace
// period timeout
waitStatusStart := time.Now()
- nextStatusCtx, nextStatusCancel := context.WithTimeout(w.baseContext, server.DefaultLiveness)
+ nextStatusCtx, nextStatusCancel := context.WithTimeoutCause(
+ w.baseContext,
+ server.DefaultLiveness,
+ fmt.Errorf("%s: liveness timeout exceeded", op),
+ )
defer nextStatusCancel()
for {
if err := nextStatusCtx.Err(); err != nil {
@@ -813,7 +817,11 @@ func (w *Worker) getSessionTls(sessionManager session.Manager) func(hello *tls.C
return nil, fmt.Errorf("no last status information found at session acceptance time")
}

- timeoutContext, cancel := context.WithTimeout(w.baseContext, session.ValidateSessionTimeout)
+ timeoutContext, cancel := context.WithTimeoutCause(
+ w.baseContext,
+ session.ValidateSessionTimeout,
+ fmt.Errorf("%s: session validation timeout exceeded", op),
+ )
defer cancel()
sess, err := sessionManager.LoadLocalSession(timeoutContext, sessionId, lastSuccess.GetWorkerId())
if err != nil {