From 458478a2f735447f6fec1d422b8601653cd5d3f2 Mon Sep 17 00:00:00 2001 From: Lucas Rodriguez Date: Tue, 7 Nov 2023 16:48:01 -0300 Subject: [PATCH] Make Shutdown idempotent --- server.go | 14 +++++---- server_test.go | 79 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/server.go b/server.go index 117f08f..c3cc790 100644 --- a/server.go +++ b/server.go @@ -315,12 +315,16 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) { s.mutex.Lock() defer s.mutex.Unlock() - stat, err := s.serverClient.DeregisterExtension(s.uuid) - err = errors.Wrap(err, "deregistering extension") - if err == nil && stat.Code != 0 { - err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message) + + if s.serverClient != nil { + var stat *osquery.ExtensionStatus + stat, err = s.serverClient.DeregisterExtension(s.uuid) + err = errors.Wrap(err, "deregistering extension") + if err == nil && stat.Code != 0 { + err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message) + } } - s.serverClient.Close() + if s.server != nil { server := s.server s.server = nil diff --git a/server_test.go b/server_test.go index 705f278..85dabfc 100644 --- a/server_test.go +++ b/server_test.go @@ -23,7 +23,7 @@ import ( // Verify that an error in server.Start will return an error instead of deadlock. func TestNoDeadlockOnError(t *testing.T) { registry := make(map[string](map[string]OsqueryPlugin)) - for reg, _ := range validRegistryNames { + for reg := range validRegistryNames { registry[reg] = make(map[string]OsqueryPlugin) } mut := sync.Mutex{} @@ -42,8 +42,9 @@ func TestNoDeadlockOnError(t *testing.T) { CloseFunc: func() {}, } server := &ExtensionManagerServer{ - serverClient: mock, - registry: registry, + serverClient: mock, + registry: registry, + serverClientShouldShutdown: true, } log := func(ctx context.Context, typ logger.LogType, logText string) error { @@ -62,8 +63,12 @@ func TestNoDeadlockOnError(t *testing.T) { // Ensure that the extension server will shutdown and return if the osquery // instance it is talking to stops responding to pings. func TestShutdownWhenPingFails(t *testing.T) { + tempPath, err := ioutil.TempFile("", "") + require.Nil(t, err) + defer os.Remove(tempPath.Name()) + registry := make(map[string](map[string]OsqueryPlugin)) - for reg, _ := range validRegistryNames { + for reg := range validRegistryNames { registry[reg] = make(map[string]OsqueryPlugin) } mock := &MockExtensionManager{ @@ -80,11 +85,14 @@ func TestShutdownWhenPingFails(t *testing.T) { CloseFunc: func() {}, } server := &ExtensionManagerServer{ - serverClient: mock, - registry: registry, + serverClient: mock, + registry: registry, + serverClientShouldShutdown: true, + pingInterval: 1 * time.Second, + sockPath: tempPath.Name(), } - err := server.Run() + err = server.Run() assert.Error(t, err) assert.Contains(t, err.Error(), "broken pipe") assert.True(t, mock.DeRegisterExtensionFuncInvoked) @@ -104,6 +112,7 @@ func TestShutdownDeadlock(t *testing.T) { }) } } + func testShutdownDeadlock(t *testing.T) { tempPath, err := ioutil.TempFile("", "") require.Nil(t, err) @@ -119,7 +128,7 @@ func testShutdownDeadlock(t *testing.T) { }, CloseFunc: func() {}, } - server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} + server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name(), serverClientShouldShutdown: true} wait := sync.WaitGroup{} @@ -177,9 +186,13 @@ func testShutdownDeadlock(t *testing.T) { } func TestShutdownBasic(t *testing.T) { - tempPath, err := ioutil.TempFile("", "") - require.Nil(t, err) - defer os.Remove(tempPath.Name()) + dir := t.TempDir() + + tempPath := func() string { + tmp, err := os.CreateTemp(dir, "") + require.NoError(t, err) + return tmp.Name() + } retUUID := osquery.ExtensionRouteUUID(0) mock := &MockExtensionManager{ @@ -191,26 +204,38 @@ func TestShutdownBasic(t *testing.T) { }, CloseFunc: func() {}, } - server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} - completed := make(chan struct{}) - go func() { - err := server.Start() + for _, server := range []*ExtensionManagerServer{ + // Create the extension manager without using NewExtensionManagerServer. + {serverClient: mock, sockPath: tempPath()}, + // Create the extension manager using ExtensionManagerServer. + {serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true}, + } { + completed := make(chan struct{}) + go func() { + err := server.Start() + require.NoError(t, err) + close(completed) + }() + + server.waitStarted() + + err := server.Shutdown(context.Background()) require.NoError(t, err) - close(completed) - }() - server.waitStarted() - err = server.Shutdown(context.Background()) - require.NoError(t, err) + // Test that server.Shutdown is idempotent. + err = server.Shutdown(context.Background()) + require.NoError(t, err) + + // Either indicate successful shutdown, or fatal the test because it + // hung + select { + case <-completed: + // Success. Do nothing. + case <-time.After(5 * time.Second): + t.Fatal("hung on shutdown") + } - // Either indicate successful shutdown, or fatal the test because it - // hung - select { - case <-completed: - // Success. Do nothing. - case <-time.After(5 * time.Second): - t.Fatal("hung on shutdown") } }