diff --git a/server.go b/server.go index 2f15830..117f08f 100644 --- a/server.go +++ b/server.go @@ -52,18 +52,19 @@ const defaultPingInterval = 5 * time.Second // API. Plugins can register with an extension manager, which handles the // communication with the osquery process. type ExtensionManagerServer struct { - name string - version string - sockPath string - serverClient ExtensionManager - registry map[string](map[string]OsqueryPlugin) - server thrift.TServer - transport thrift.TServerTransport - timeout time.Duration - pingInterval time.Duration // How often to ping osquery server - mutex sync.Mutex - uuid osquery.ExtensionRouteUUID - started bool // Used to ensure tests wait until the server is actually started + name string + version string + sockPath string + serverClient ExtensionManager + serverClientShouldShutdown bool // Whether to shutdown the client during server shutdown + registry map[string](map[string]OsqueryPlugin) + server thrift.TServer + transport thrift.TServerTransport + timeout time.Duration + pingInterval time.Duration // How often to ping osquery server + mutex sync.Mutex + uuid osquery.ExtensionRouteUUID + started bool // Used to ensure tests wait until the server is actually started } // validRegistryNames contains the allowable RegistryName() values. If a plugin @@ -157,6 +158,7 @@ func NewExtensionManagerServer(name string, sockPath string, opts ...ServerOptio return nil, err } manager.serverClient = serverClient + manager.serverClientShouldShutdown = true } return manager, nil @@ -253,6 +255,11 @@ func (s *ExtensionManagerServer) Run() error { for { time.Sleep(s.pingInterval) + // can't ping if s.Shutdown has already happened + if s.serverClient == nil { + break + } + status, err := s.serverClient.Ping() if err != nil { errc <- errors.Wrap(err, "extension ping failed") @@ -326,6 +333,12 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) { }() } + // Shutdown the client, if appropriate + if s.serverClientShouldShutdown && s.serverClient != nil { + s.serverClient.Close() + s.serverClient = nil + } + return }