Skip to content

Commit

Permalink
Add service enable/disable support and tests (#171)
Browse files Browse the repository at this point in the history
* Add utils for all service permutations
  • Loading branch information
sfc-gh-jchacon authored Oct 14, 2022
1 parent f782a58 commit 6e1e71d
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 61 deletions.
2 changes: 1 addition & 1 deletion proxy/proxy.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion proxy/testdata/testservice.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/ansible/ansible.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/exec/exec.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/fdb/fdb.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/healthcheck/healthcheck.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/localfile/localfile.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/packages/packages.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/process/process.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/sansshell/sansshell.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions services/service/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func setup(f *flag.FlagSet) *subcommands.Commander {
c.Register(&actionCmd{action: pb.Action_ACTION_START}, "")
c.Register(&statusCmd{}, "")
c.Register(&actionCmd{action: pb.Action_ACTION_STOP}, "")
c.Register(&actionCmd{action: pb.Action_ACTION_ENABLE}, "")
c.Register(&actionCmd{action: pb.Action_ACTION_DISABLE}, "")
return c
}

Expand Down Expand Up @@ -168,12 +170,12 @@ func (a *actionCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf
out := state.Out[resp.Index]
output := fmt.Sprintf("[%s] %s %v: OK", systemTypeString(system), serviceName, as)
if resp.Error != nil && err != io.EOF {
lastErr = fmt.Errorf("target %s (%d) returned error %w", resp.Target, resp.Index, resp.Error)
lastErr = fmt.Errorf("target %s (%d) returned error %w\n", resp.Target, resp.Index, resp.Error)
fmt.Fprint(state.Err[resp.Index], lastErr)
continue
}
if _, err := fmt.Fprintln(out, output); err != nil {
lastErr = fmt.Errorf("target %s (%d) output write error %w", resp.Target, resp.Index, err)
lastErr = fmt.Errorf("target %s (%d) output write error %w\n", resp.Target, resp.Index, err)
fmt.Fprint(state.Err[resp.Index], lastErr)
}
}
Expand Down Expand Up @@ -245,12 +247,12 @@ func (s *statusCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf
system, status := resp.Resp.GetSystemType(), resp.Resp.GetServiceStatus().GetStatus()
output := fmt.Sprintf("[%s] %s : %s", systemTypeString(system), serviceName, statusString(status))
if resp.Error != nil {
lastErr = fmt.Errorf("target %s [%d] error: %w", resp.Target, resp.Index, resp.Error)
lastErr = fmt.Errorf("target %s [%d] error: %w\n", resp.Target, resp.Index, resp.Error)
fmt.Fprint(state.Err[resp.Index], lastErr)
continue
}
if _, err := fmt.Fprintln(out, output); err != nil {
lastErr = fmt.Errorf("target %s [%d] write error: %w", resp.Target, resp.Index, err)
lastErr = fmt.Errorf("target %s [%d] write error: %w\n", resp.Target, resp.Index, err)
fmt.Fprint(state.Err[resp.Index], lastErr)
}
}
Expand Down Expand Up @@ -319,7 +321,7 @@ func (l *listCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
system := systemTypeString(resp.Resp.GetSystemType())
for _, svc := range resp.Resp.Services {
if _, err := fmt.Fprintf(out, "[%s] %s : %s\n", system, svc.GetServiceName(), statusString(svc.GetStatus())); err != nil {
lastErr = fmt.Errorf("target %s [%d] writer error: %w", resp.Target, resp.Index, err)
lastErr = fmt.Errorf("target %s [%d] writer error: %w\n", resp.Target, resp.Index, err)
fmt.Fprintln(state.Err[resp.Index], lastErr)
}
}
Expand Down
119 changes: 116 additions & 3 deletions services/service/client/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,88 @@ import (
pb "github.com/Snowflake-Labs/sansshell/services/service"
)

// RestartService is a helper function for restarting a service on a remote target
// ListRemoteServices is a helper function for listing all services on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func RestartService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
func ListRemoteServices(ctx context.Context, conn *proxy.Conn, system pb.SystemType) (*pb.ListReply, error) {
if len(conn.Targets) != 1 {
return errors.New("RestartService only supports single targets")
return nil, errors.New("ListRemoteServices only supports single targets")
}

c := pb.NewServiceClient(conn)
ret, err := c.List(ctx, &pb.ListRequest{
SystemType: system,
})
if err != nil {
return nil, fmt.Errorf("can't list services %v", err)
}
return ret, nil
}

// StatusRemoteService is a helper function for getting the status of a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func StatusRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) (*pb.StatusReply, error) {
if len(conn.Targets) != 1 {
return nil, errors.New("StatusRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
ret, err := c.Status(ctx, &pb.StatusRequest{
SystemType: system,
ServiceName: service,
})
if err != nil {
return nil, fmt.Errorf("can't get status for service %s - %v", service, err)
}
return ret, nil
}

// StartRemoteService is a helper function for starting a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func StartRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
if len(conn.Targets) != 1 {
return errors.New("StartRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
if _, err := c.Action(ctx, &pb.ActionRequest{
ServiceName: service,
SystemType: system,
Action: pb.Action_ACTION_START,
}); err != nil {
return fmt.Errorf("can't start service %s - %v", service, err)
}
return nil
}

// StopRemoteService is a helper function for stopping a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func StopRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
if len(conn.Targets) != 1 {
return errors.New("StopRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
if _, err := c.Action(ctx, &pb.ActionRequest{
ServiceName: service,
SystemType: system,
Action: pb.Action_ACTION_STOP,
}); err != nil {
return fmt.Errorf("can't stop service %s - %v", service, err)
}
return nil
}

// RestartService was the original exported name for RestartRemoteService and now
// exists for backwards compatibility.
//
// Deprecated: Use RestartRemoteService instead.
var RestartService = RestartRemoteService

// RestartRemoteService is a helper function for restarting a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func RestartRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
if len(conn.Targets) != 1 {
return errors.New("RestartRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
Expand All @@ -26,3 +103,39 @@ func RestartService(ctx context.Context, conn *proxy.Conn, system pb.SystemType,
}
return nil
}

// EnableRemoteService is a helper function for enabling a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func EnableRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
if len(conn.Targets) != 1 {
return errors.New("EnableRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
if _, err := c.Action(ctx, &pb.ActionRequest{
ServiceName: service,
SystemType: system,
Action: pb.Action_ACTION_ENABLE,
}); err != nil {
return fmt.Errorf("can't enable service %s - %v", service, err)
}
return nil
}

// DisableRemoteService is a helper function for disabling a service on a remote target
// using a proxy.Conn. If the conn is defined for >1 targets this will return an error.
func DisableRemoteService(ctx context.Context, conn *proxy.Conn, system pb.SystemType, service string) error {
if len(conn.Targets) != 1 {
return errors.New("DisableRemoteService only supports single targets")
}

c := pb.NewServiceClient(conn)
if _, err := c.Action(ctx, &pb.ActionRequest{
ServiceName: service,
SystemType: system,
Action: pb.Action_ACTION_DISABLE,
}); err != nil {
return fmt.Errorf("can't disable service %s - %v", service, err)
}
return nil
}
25 changes: 22 additions & 3 deletions services/service/server/server_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ type systemdConnection interface {
StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
DisableUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]dbus.DisableUnitFileChange, error)
EnableUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) (bool, []dbus.EnableUnitFileChange, error)
ReloadContext(ctx context.Context) error
Close()
}

Expand Down Expand Up @@ -256,6 +259,10 @@ func (s *server) Action(ctx context.Context, req *pb.ActionRequest) (*pb.ActionR
_, err = conn.RestartUnitContext(ctx, unitName, modeReplace, resultChan)
case pb.Action_ACTION_STOP:
_, err = conn.StopUnitContext(ctx, unitName, modeReplace, resultChan)
case pb.Action_ACTION_ENABLE:
_, _, err = conn.EnableUnitFilesContext(ctx, []string{unitName}, false, true)
case pb.Action_ACTION_DISABLE:
_, err = conn.DisableUnitFilesContext(ctx, []string{unitName}, false)
default:
return nil, status.Errorf(codes.InvalidArgument, "invalid action type %v", req.Action)
}
Expand All @@ -266,10 +273,22 @@ func (s *server) Action(ctx context.Context, req *pb.ActionRequest) (*pb.ActionR
// NB: delivery of a value on resultchan respects context cancellation, and will
// deliver a value of 'cancelled' if the ctx is cancelled by a client disconnect,
// so it's safe to do a simple recv.
result := <-resultChan
if result != operationResultDone {
return nil, status.Errorf(codes.Internal, "error performing action %v: %v", req.Action, result)
// Enable/disable don't use this method so we skip the channel (since it would hang)
// and instead force a reload which is what systemctl does when it enables/disables.
switch req.Action {
case pb.Action_ACTION_START, pb.Action_ACTION_RESTART, pb.Action_ACTION_STOP:
result := <-resultChan
if result != operationResultDone {
return nil, status.Errorf(codes.Internal, "error performing action %v: %v", req.Action, result)
}
case pb.Action_ACTION_ENABLE, pb.Action_ACTION_DISABLE:
if err := conn.ReloadContext(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "error reloading: %v", err)
}
default:
return nil, status.Errorf(codes.InvalidArgument, "invalid action type %v for post actions", req.Action)
}

return &pb.ActionReply{
SystemType: pb.SystemType_SYSTEM_TYPE_SYSTEMD,
ServiceName: req.GetServiceName(),
Expand Down
Loading

0 comments on commit 6e1e71d

Please sign in to comment.