Skip to content

Commit

Permalink
Make Shutdown idempotent
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasmrod committed Nov 7, 2023
1 parent dbeefc0 commit 458478a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 32 deletions.
14 changes: 9 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,16 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
stat, err := s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)

if s.serverClient != nil {
var stat *osquery.ExtensionStatus
stat, err = s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
}
}
s.serverClient.Close()

if s.server != nil {
server := s.server
s.server = nil
Expand Down
79 changes: 52 additions & 27 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// Verify that an error in server.Start will return an error instead of deadlock.
func TestNoDeadlockOnError(t *testing.T) {
registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mut := sync.Mutex{}
Expand All @@ -42,8 +42,9 @@ func TestNoDeadlockOnError(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
}

log := func(ctx context.Context, typ logger.LogType, logText string) error {
Expand All @@ -62,8 +63,12 @@ func TestNoDeadlockOnError(t *testing.T) {
// Ensure that the extension server will shutdown and return if the osquery
// instance it is talking to stops responding to pings.
func TestShutdownWhenPingFails(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mock := &MockExtensionManager{
Expand All @@ -80,11 +85,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
pingInterval: 1 * time.Second,
sockPath: tempPath.Name(),
}

err := server.Run()
err = server.Run()
assert.Error(t, err)
assert.Contains(t, err.Error(), "broken pipe")
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
Expand All @@ -104,6 +112,7 @@ func TestShutdownDeadlock(t *testing.T) {
})
}
}

func testShutdownDeadlock(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
Expand All @@ -119,7 +128,7 @@ func testShutdownDeadlock(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name(), serverClientShouldShutdown: true}

wait := sync.WaitGroup{}

Expand Down Expand Up @@ -177,9 +186,13 @@ func testShutdownDeadlock(t *testing.T) {
}

func TestShutdownBasic(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())
dir := t.TempDir()

tempPath := func() string {
tmp, err := os.CreateTemp(dir, "")
require.NoError(t, err)
return tmp.Name()
}

retUUID := osquery.ExtensionRouteUUID(0)
mock := &MockExtensionManager{
Expand All @@ -191,26 +204,38 @@ func TestShutdownBasic(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

completed := make(chan struct{})
go func() {
err := server.Start()
for _, server := range []*ExtensionManagerServer{
// Create the extension manager without using NewExtensionManagerServer.
{serverClient: mock, sockPath: tempPath()},
// Create the extension manager using ExtensionManagerServer.
{serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true},
} {
completed := make(chan struct{})
go func() {
err := server.Start()
require.NoError(t, err)
close(completed)
}()

server.waitStarted()

err := server.Shutdown(context.Background())
require.NoError(t, err)
close(completed)
}()

server.waitStarted()
err = server.Shutdown(context.Background())
require.NoError(t, err)
// Test that server.Shutdown is idempotent.
err = server.Shutdown(context.Background())
require.NoError(t, err)

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}
}

Expand Down

0 comments on commit 458478a

Please sign in to comment.