diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go
index e5649480a0586..f22e82a6a17d9 100644
--- a/lib/inventory/controller.go
+++ b/lib/inventory/controller.go
@@ -62,6 +62,7 @@ const (
 	appKeepAliveOk  testEvent = "app-keep-alive-ok"
 	appKeepAliveErr testEvent = "app-keep-alive-err"
+	appKeepAliveDel testEvent = "app-keep-alive-del"
 
 	appUpsertOk  testEvent = "app-upsert-ok"
 	appUpsertErr testEvent = "app-upsert-err"
@@ -76,6 +77,8 @@ const (
 
 	handlerStart = "handler-start"
 	handlerClose = "handler-close"
+
+	keepAliveTick = "keep-alive-tick"
 )
 
 // intervalKey is used to uniquely identify the subintervals registered with the interval.MultiInterval
@@ -616,6 +619,10 @@ func (c *Controller) handleAgentMetadata(handle *upstreamHandle, m proto.Upstrea
 }
 
 func (c *Controller) keepAliveServer(handle *upstreamHandle, now time.Time) error {
+	// always fire off the 'tick' event after keepalive processing to ensure
+	// that waiting for N ticks maps intuitively to waiting for N keepalive
+	// processing steps.
+	defer c.testEvent(keepAliveTick)
 	if err := c.keepAliveSSHServer(handle, now); err != nil {
 		return trace.Wrap(err)
 	}
@@ -641,14 +648,15 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e
 			srv.keepAliveErrs++
 			handle.appServers[name] = srv
-			shouldClose := srv.keepAliveErrs > c.maxKeepAliveErrs
-
-			log.Warnf("Failed to keep alive app server %q: %v (count=%d, closing=%v).", handle.Hello().ServerID, err, srv.keepAliveErrs, shouldClose)
+			shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs
+			log.Warnf("Failed to keep alive app server %q: %v (count=%d, removing=%v).", handle.Hello().ServerID, err, srv.keepAliveErrs, shouldRemove)
 
-			if shouldClose {
-				return trace.Errorf("failed to keep alive app server: %v", err)
+			if shouldRemove {
+				c.testEvent(appKeepAliveDel)
+				delete(handle.appServers, name)
 			}
 		} else {
+			srv.keepAliveErrs = 0
 			c.testEvent(appKeepAliveOk)
 		}
 	} else if srv.retryUpsert {
@@ -691,6 +699,7 @@ func (c *Controller) keepAliveSSHServer(handle *upstreamHandle, now time.Time) e
 			return trace.Errorf("failed to keep alive ssh server: %v", err)
 		}
 	} else {
+		handle.sshServer.keepAliveErrs = 0
 		c.testEvent(sshKeepAliveOk)
 	}
 } else if handle.sshServer.retryUpsert {
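A minimal standalone sketch of the consecutive-failure accounting the controller.go hunks above apply to app servers: each failed keepalive increments a per-server counter, a successful keepalive resets it to zero, and exceeding the limit removes only that server's keepalive state rather than tearing down the whole control stream. All names here (tracker, serverState, recordKeepAlive) are invented for illustration and are not the Teleport types.

package main

import "fmt"

// serverState mirrors the idea of a per-resource keepalive record.
type serverState struct {
	keepAliveErrs int
}

// tracker holds keepalive state for a set of named servers.
type tracker struct {
	servers          map[string]*serverState
	maxKeepAliveErrs int
}

// recordKeepAlive folds one keepalive result into the server's state and
// reports whether the server was removed for exceeding the error limit.
func (t *tracker) recordKeepAlive(name string, err error) bool {
	srv, ok := t.servers[name]
	if !ok {
		return false
	}
	if err == nil {
		// A success resets the counter, so only *consecutive*
		// failures count toward removal.
		srv.keepAliveErrs = 0
		return false
	}
	srv.keepAliveErrs++
	if srv.keepAliveErrs > t.maxKeepAliveErrs {
		delete(t.servers, name)
		return true
	}
	return false
}

func main() {
	tr := &tracker{
		servers:          map[string]*serverState{"app-1": {}},
		maxKeepAliveErrs: 2,
	}
	failure := fmt.Errorf("keepalive failed")
	for i := 1; i <= 3; i++ {
		removed := tr.recordKeepAlive("app-1", failure)
		fmt.Printf("failure %d: removed=%v\n", i, removed) // removed on the 3rd
	}
}

Resetting the counter on success is what makes the limit a bound on consecutive errors; without it (the behavior the diff fixes), sporadic failures accumulate forever and eventually trip the limit.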
diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go
index 94c58aa4bc41d..8115703b919f5 100644
--- a/lib/inventory/controller_test.go
+++ b/lib/inventory/controller_test.go
@@ -289,6 +289,7 @@ func TestSSHServerBasics(t *testing.T) {
 // an app service.
 func TestAppServerBasics(t *testing.T) {
 	const serverID = "test-server"
+	const appCount = 3
 
 	t.Parallel()
@@ -324,7 +325,7 @@ func TestAppServerBasics(t *testing.T) {
 	require.Equal(t, int64(1), controller.instanceHBVariableDuration.Count())
 
 	// send a fake app server heartbeat
-	for i := 0; i < 3; i++ {
+	for i := 0; i < appCount; i++ {
 		err := downstream.Send(ctx, proto.InventoryHeartbeat{
 			AppServer: &types.AppServerV3{
 				Metadata: types.Metadata{
@@ -366,7 +367,7 @@ func TestAppServerBasics(t *testing.T) {
 		deny(appUpsertErr, handlerClose),
 	)
 
-	for i := 0; i < 3; i++ {
+	for i := 0; i < appCount; i++ {
 		err := downstream.Send(ctx, proto.InventoryHeartbeat{
 			AppServer: &types.AppServerV3{
 				Metadata: types.Metadata{
@@ -415,6 +416,38 @@ func TestAppServerBasics(t *testing.T) {
 	_, err := handle.Ping(pingCtx, 1)
 	require.NoError(t, err)
 
+	// ensure that local app keepalive states have reset to healthy by waiting
+	// on at least a full cycle's worth of keepalives without errors.
+	awaitEvents(t, events,
+		expect(keepAliveTick, keepAliveTick),
+		deny(appKeepAliveErr, handlerClose),
+	)
+
+	// set up to induce enough consecutive keepalive errors to cause removal
+	// of server-side keepalive state.
+	auth.mu.Lock()
+	auth.failKeepAlives = 3 * appCount
+	auth.mu.Unlock()
+
+	// expect that all app keepalives fail, then each app is removed.
+	var expectedEvents []testEvent
+	for i := 0; i < appCount; i++ {
+		expectedEvents = append(expectedEvents, []testEvent{appKeepAliveErr, appKeepAliveErr, appKeepAliveErr, appKeepAliveDel}...)
+	}
+
+	// wait for failed keepalives to trigger removal
+	awaitEvents(t, events,
+		expect(expectedEvents...),
+		deny(handlerClose),
+	)
+
+	// verify that further keepalive ticks do not result in attempts to keep
+	// alive apps (successful or not).
+	awaitEvents(t, events,
+		expect(keepAliveTick, keepAliveTick, keepAliveTick),
+		deny(appKeepAliveOk, appKeepAliveErr, handlerClose),
+	)
+
 	// set up to induce enough consecutive errors to cause stream closure
 	auth.mu.Lock()
 	auth.failUpserts = 5
@@ -764,7 +797,7 @@ func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) {
 		opt(&options)
 	}
 
-	timeout := time.After(time.Second * 5)
+	timeout := time.After(time.Second * 30)
 	for {
 		if len(options.expect) == 0 {
 			return
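The test additions lean on the awaitEvents/expect/deny helpers, of which only the timeout line is visible in the final hunk. Below is a rough standalone sketch of how such a helper can behave, assuming expect() counts event multiplicity and deny() fails the test the moment a denied event is observed; the real helper in lib/inventory may differ in detail.

package sketch

import (
	"testing"
	"time"
)

type testEvent string

type eventOptions struct {
	expect map[testEvent]int
	deny   map[testEvent]struct{}
}

type eventOption func(*eventOptions)

// expect registers events that must all be observed (with multiplicity).
func expect(events ...testEvent) eventOption {
	return func(o *eventOptions) {
		for _, e := range events {
			o.expect[e]++
		}
	}
}

// deny registers events that must not appear before the expectations
// are satisfied.
func deny(events ...testEvent) eventOption {
	return func(o *eventOptions) {
		for _, e := range events {
			o.deny[e] = struct{}{}
		}
	}
}

// awaitEvents consumes events until every expected event has been seen,
// failing fast on a denied event or on timeout.
func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) {
	t.Helper()
	options := eventOptions{
		expect: make(map[testEvent]int),
		deny:   make(map[testEvent]struct{}),
	}
	for _, opt := range opts {
		opt(&options)
	}
	timeout := time.After(time.Second * 30)
	for len(options.expect) > 0 {
		select {
		case event := <-ch:
			if _, denied := options.deny[event]; denied {
				t.Fatalf("unexpected event: %v", event)
			}
			options.expect[event]--
			if options.expect[event] < 1 {
				delete(options.expect, event)
			}
		case <-timeout:
			t.Fatalf("timeout waiting for events: %v", options.expect)
		}
	}
}

func TestAwaitEventsSketch(t *testing.T) {
	ch := make(chan testEvent, 2)
	ch <- "keep-alive-tick"
	ch <- "keep-alive-tick"
	awaitEvents(t, ch,
		expect("keep-alive-tick", "keep-alive-tick"),
		deny("handler-close"),
	)
}

Under these semantics, the raised 30-second timeout in the final hunk gives the slower multi-cycle expectations (e.g. 3 failures plus a removal per app) room to complete without flaking.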