diff --git a/Makefile b/Makefile index 4bdd35028..5cf18447b 100644 --- a/Makefile +++ b/Makefile @@ -102,7 +102,7 @@ build-all: GOOS=windows go build -o receptor.exe ./cmd/receptor-cl && \ GOOS=darwin go build -o receptor.app ./cmd/receptor-cl && \ go build example/*.go && \ - go build -o receptor --tags no_backends,no_services,no_tls_config,no_workceptor,no_cert_auth ./cmd/receptor-cl && \ + go build -o receptor --tags no_backends,no_services,no_tls_config ./cmd/receptor-cl && \ go build -o receptor ./cmd/receptor-cl DIST := receptor_$(shell echo '$(VERSION)' | sed 's/^v//')_$(GOOS)_$(GOARCH) diff --git a/pkg/workceptor/command.go b/pkg/workceptor/command.go index 664a87004..fa9394490 100644 --- a/pkg/workceptor/command.go +++ b/pkg/workceptor/command.go @@ -4,6 +4,7 @@ package workceptor import ( + "context" "flag" "fmt" "os" @@ -11,6 +12,7 @@ import ( "os/signal" "path" "strings" + "sync" "syscall" "time" @@ -18,17 +20,44 @@ import ( "github.com/google/shlex" ) +type BaseWorkUnitForWorkUnit interface { + CancelContext() + ID() string + Init(w *Workceptor, unitID string, workType string, fs FileSystemer, watcher WatcherWrapper) + LastUpdateError() error + Load() error + MonitorLocalStatus() + Release(force bool) error + Save() error + SetFromParams(_ map[string]string) error + Status() *StatusFileData + StatusFileName() string + StdoutFileName() string + UnitDir() string + UnredactedStatus() *StatusFileData + UpdateBasicStatus(state int, detail string, stdoutSize int64) + UpdateFullStatus(statusFunc func(*StatusFileData)) + GetStatusCopy() StatusFileData + GetStatusWithoutExtraData() *StatusFileData + SetStatusExtraData(interface{}) + GetStatusLock() *sync.RWMutex + GetWorkceptor() *Workceptor + SetWorkceptor(*Workceptor) + GetContext() context.Context + GetCancel() context.CancelFunc +} + // commandUnit implements the WorkUnit interface for the Receptor command worker plugin. 
type commandUnit struct { - BaseWorkUnit + BaseWorkUnitForWorkUnit command string baseParams string allowRuntimeParams bool done bool } -// commandExtraData is the content of the ExtraData JSON field for a command worker. -type commandExtraData struct { +// CommandExtraData is the content of the ExtraData JSON field for a command worker. +type CommandExtraData struct { Pid int Params string } @@ -60,7 +89,7 @@ func cmdWaiter(cmd *exec.Cmd, doneChan chan bool) { // commandRunner is run in a separate process, to monitor the subprocess and report back metadata. func commandRunner(command string, params string, unitdir string) error { status := StatusFileData{} - status.ExtraData = &commandExtraData{} + status.ExtraData = &CommandExtraData{} statusFilename := path.Join(unitdir, "status") err := status.UpdateBasicStatus(statusFilename, WorkStatePending, "Not started yet", 0) if err != nil { @@ -169,7 +198,7 @@ func (cw *commandUnit) SetFromParams(params map[string]string) error { if cmdParams != "" && !cw.allowRuntimeParams { return fmt.Errorf("extra params provided but not allowed") } - cw.status.ExtraData.(*commandExtraData).Params = combineParams(cw.baseParams, cmdParams) + cw.GetStatusCopy().ExtraData.(*CommandExtraData).Params = combineParams(cw.baseParams, cmdParams) return nil } @@ -181,10 +210,10 @@ func (cw *commandUnit) Status() *StatusFileData { // UnredactedStatus returns a copy of the status currently loaded in memory, including secrets. 
func (cw *commandUnit) UnredactedStatus() *StatusFileData { - cw.statusLock.RLock() - defer cw.statusLock.RUnlock() - status := cw.getStatus() - ed, ok := cw.status.ExtraData.(*commandExtraData) + cw.GetStatusLock().RLock() + defer cw.GetStatusLock().RUnlock() + status := cw.GetStatusWithoutExtraData() + ed, ok := cw.GetStatusCopy().ExtraData.(*CommandExtraData) if ok { edCopy := *ed status.ExtraData = &edCopy @@ -206,9 +235,9 @@ func (cw *commandUnit) runCommand(cmd *exec.Cmd) error { } cw.UpdateFullStatus(func(status *StatusFileData) { if status.ExtraData == nil { - status.ExtraData = &commandExtraData{} + status.ExtraData = &CommandExtraData{} } - status.ExtraData.(*commandExtraData).Pid = cmd.Process.Pid + status.ExtraData.(*CommandExtraData).Pid = cmd.Process.Pid }) doneChan := make(chan bool) go func() { @@ -226,8 +255,8 @@ func (cw *commandUnit) runCommand(cmd *exec.Cmd) error { // Start launches a job with given parameters. func (cw *commandUnit) Start() error { - level := cw.w.nc.GetLogger().GetLogLevel() - levelName, _ := cw.w.nc.GetLogger().LogLevelToName(level) + level := cw.GetWorkceptor().nc.GetLogger().GetLogLevel() + levelName, _ := cw.GetWorkceptor().nc.GetLogger().LogLevelToName(level) cw.UpdateBasicStatus(WorkStatePending, "Launching command runner", 0) // TODO: This is another place where we rely on a pre-built binary for testing. @@ -243,7 +272,7 @@ func (cw *commandUnit) Start() error { "--log-level", levelName, "--command-runner", fmt.Sprintf("command=%s", cw.command), - fmt.Sprintf("params=%s", cw.Status().ExtraData.(*commandExtraData).Params), + fmt.Sprintf("params=%s", cw.Status().ExtraData.(*CommandExtraData).Params), fmt.Sprintf("unitdir=%s", cw.UnitDir())) return cw.runCommand(cmd) @@ -270,9 +299,9 @@ func (cw *commandUnit) Restart() error { // Cancel stops a running job. 
func (cw *commandUnit) Cancel() error { - cw.cancel() + cw.CancelContext() status := cw.Status() - ced, ok := status.ExtraData.(*commandExtraData) + ced, ok := status.ExtraData.(*CommandExtraData) if !ok || ced.Pid <= 0 { return nil } @@ -304,7 +333,7 @@ func (cw *commandUnit) Release(force bool) error { return err } - return cw.BaseWorkUnit.Release(force) + return cw.BaseWorkUnitForWorkUnit.Release(force) } // ************************************************************************** @@ -320,18 +349,22 @@ type CommandWorkerCfg struct { VerifySignature bool `description:"Verify a signed work submission" default:"false"` } -func (cfg CommandWorkerCfg) NewWorker(w *Workceptor, unitID string, workType string) WorkUnit { - cw := &commandUnit{ - BaseWorkUnit: BaseWorkUnit{ +func (cfg CommandWorkerCfg) NewWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit { + if bwu == nil { + bwu = &BaseWorkUnit{ status: StatusFileData{ - ExtraData: &commandExtraData{}, + ExtraData: &CommandExtraData{}, }, - }, - command: cfg.Command, - baseParams: cfg.Params, - allowRuntimeParams: cfg.AllowRuntimeParams, + } + } + + cw := &commandUnit{ + BaseWorkUnitForWorkUnit: bwu, + command: cfg.Command, + baseParams: cfg.Params, + allowRuntimeParams: cfg.AllowRuntimeParams, } - cw.BaseWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) + cw.BaseWorkUnitForWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) return cw } diff --git a/pkg/workceptor/command_test.go b/pkg/workceptor/command_test.go new file mode 100644 index 000000000..813dfaf27 --- /dev/null +++ b/pkg/workceptor/command_test.go @@ -0,0 +1,472 @@ +package workceptor_test + +import ( + "context" + "errors" + "fmt" + "os/exec" + "sync" + "testing" + "time" + + "github.com/ansible/receptor/pkg/workceptor" + "github.com/ansible/receptor/pkg/workceptor/mock_workceptor" + "github.com/golang/mock/gomock" +) + +func statusExpectCalls(mockBaseWorkUnit *mock_workceptor.MockBaseWorkUnitForWorkUnit) { + 
statusLock := &sync.RWMutex{} + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) +} + +func createCommandTestSetup(t *testing.T) (workceptor.WorkUnit, *mock_workceptor.MockBaseWorkUnitForWorkUnit, *mock_workceptor.MockNetceptorForWorkceptor, *workceptor.Workceptor) { + ctrl := gomock.NewController(t) + ctx := context.Background() + + mockBaseWorkUnit := mock_workceptor.NewMockBaseWorkUnitForWorkUnit(ctrl) + mockNetceptor := mock_workceptor.NewMockNetceptorForWorkceptor(ctrl) + mockNetceptor.EXPECT().NodeID().Return("NodeID") + + w, err := workceptor.New(ctx, mockNetceptor, "/tmp") + if err != nil { + t.Errorf("Error while creating Workceptor: %v", err) + } + + cwc := &workceptor.CommandWorkerCfg{} + mockBaseWorkUnit.EXPECT().Init(w, "", "", workceptor.FileSystem{}, nil) + workUnit := cwc.NewWorker(mockBaseWorkUnit, w, "", "") + + return workUnit, mockBaseWorkUnit, mockNetceptor, w +} + +func TestCommandSetFromParams(t *testing.T) { + wu, mockBaseWorkUnit, _, _ := createCommandTestSetup(t) + + paramsTestCases := []struct { + name string + params map[string]string + expectedCalls func() + errorCatch func(error, *testing.T) + }{ + { + name: "no params with no error", + params: map[string]string{"": ""}, + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + { + name: "params with error", + params: map[string]string{"params": "param"}, + expectedCalls: func() { + }, + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + }, + } + + for _, testCase := range paramsTestCases { + t.Run(testCase.name, func(t 
*testing.T) { + testCase.expectedCalls() + err := wu.SetFromParams(testCase.params) + testCase.errorCatch(err, t) + }) + } +} + +func TestUnredactedStatus(t *testing.T) { + t.Parallel() + wu, mockBaseWorkUnit, _, _ := createCommandTestSetup(t) + restartTestCases := []struct { + name string + }{ + {name: "test1"}, + {name: "test2"}, + } + + statusLock := &sync.RWMutex{} + for _, testCase := range restartTestCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + wu.UnredactedStatus() + }) + } +} + +func TestStart(t *testing.T) { + wu, mockBaseWorkUnit, mockNetceptor, w := createCommandTestSetup(t) + + mockBaseWorkUnit.EXPECT().GetWorkceptor().Return(w).Times(2) + mockNetceptor.EXPECT().GetLogger().Times(2) + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()) + statusExpectCalls(mockBaseWorkUnit) + + mockBaseWorkUnit.EXPECT().UnitDir() + mockBaseWorkUnit.EXPECT().UpdateFullStatus(gomock.Any()) + mockBaseWorkUnit.EXPECT().MonitorLocalStatus().AnyTimes() + mockBaseWorkUnit.EXPECT().UpdateFullStatus(gomock.Any()).AnyTimes() + wu.Start() +} + +func TestRestart(t *testing.T) { + wu, mockBaseWorkUnit, _, _ := createCommandTestSetup(t) + + restartTestCases := []struct { + name string + expectedCalls func() + errorCatch func(error, *testing.T) + }{ + { + name: "load error", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().Load().Return(errors.New("terminated")) + }, + errorCatch: func(err error, t *testing.T) { + if err.Error() != "terminated" { + t.Error(err) + } + }, + }, + { + name: "job complete with no error", + expectedCalls: func() { + statusFile := &workceptor.StatusFileData{State: 2} + 
mockBaseWorkUnit.EXPECT().Load().Return(nil) + statusLock := &sync.RWMutex{} + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(statusFile) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + { + name: "restart successful", + expectedCalls: func() { + statusFile := &workceptor.StatusFileData{State: 0} + mockBaseWorkUnit.EXPECT().Load().Return(nil) + statusLock := &sync.RWMutex{} + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(statusFile) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()) + mockBaseWorkUnit.EXPECT().UnitDir() + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + } + + for _, testCase := range restartTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + mockBaseWorkUnit.EXPECT().MonitorLocalStatus().AnyTimes() + err := wu.Restart() + testCase.errorCatch(err, t) + }) + } +} + +func TestCancel(t *testing.T) { + wu, mockBaseWorkUnit, _, _ := createCommandTestSetup(t) + + paramsTestCases := []struct { + name string + expectedCalls func() + errorCatch func(error, *testing.T) + }{ + { + name: "not a valid pid no error", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().CancelContext() + statusExpectCalls(mockBaseWorkUnit) + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + { + name: "process interrupt error", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().CancelContext() + 
mockBaseWorkUnit.EXPECT().GetStatusLock().Return(&sync.RWMutex{}).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{ + Pid: 1, + }, + }) + }, + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + }, + { + name: "process already finished", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().CancelContext() + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(&sync.RWMutex{}).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + + c := exec.Command("ls", "/tmp") + processPid := make(chan int) + + go func(c *exec.Cmd, processPid chan int) { + c.Run() + processPid <- c.Process.Pid + }(c, processPid) + + time.Sleep(200 * time.Millisecond) + + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{ + Pid: <-processPid, + }, + }) + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + { + name: "cancelled process successfully", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().CancelContext() + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(&sync.RWMutex{}).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()) + + c := exec.Command("sleep", "30") + processPid := make(chan int) + + go func(c *exec.Cmd, processPid chan int) { + err := c.Start() + if err != nil { + fmt.Println(err) + } + processPid <- c.Process.Pid + }(c, processPid) + time.Sleep(200 * time.Millisecond) + + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{ + Pid: <-processPid, + }, + }) + }, + errorCatch: func(err error, t *testing.T) { + if 
err != nil { + t.Error(err) + } + }, + }, + } + + for _, testCase := range paramsTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + err := wu.Cancel() + testCase.errorCatch(err, t) + }) + } +} + +func TestRelease(t *testing.T) { + wu, mockBaseWorkUnit, _, _ := createCommandTestSetup(t) + + releaseTestCases := []struct { + name string + expectedCalls func() + errorCatch func(error, *testing.T) + force bool + }{ + { + name: "cancel error", + expectedCalls: func() {}, + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + force: false, + }, + { + name: "released successfully", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().Release(gomock.Any()) + }, + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + force: true, + }, + } + for _, testCase := range releaseTestCases { + t.Run(testCase.name, func(t *testing.T) { + mockBaseWorkUnit.EXPECT().CancelContext() + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(&sync.RWMutex{}).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{ + Pid: 1, + }, + }) + testCase.expectedCalls() + err := wu.Release(testCase.force) + testCase.errorCatch(err, t) + }) + } +} + +func TestSigningKeyPrepare(t *testing.T) { + privateKey := workceptor.SigningKeyPrivateCfg{} + err := privateKey.Prepare() + + if err == nil { + t.Error(err) + } +} + +func TestPrepareSigningKeyPrivateCfg(t *testing.T) { + signingKeyTestCases := []struct { + name string + errorCatch func(error, *testing.T) + privateKey string + tokenExpiration string + }{ + { + name: "file does not exist error", + privateKey: "does_not_exist.txt", + tokenExpiration: "", + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + }, + { + name: "failed to parse token expiration", + 
privateKey: "/etc/hosts", + tokenExpiration: "random_input", + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + }, + { + name: "duration no error", + privateKey: "/etc/hosts", + tokenExpiration: "3h", + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + { + name: "no duration no error", + privateKey: "/etc/hosts", + tokenExpiration: "", + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + } + + for _, testCase := range signingKeyTestCases { + t.Run(testCase.name, func(t *testing.T) { + privateKey := workceptor.SigningKeyPrivateCfg{ + PrivateKey: testCase.privateKey, + TokenExpiration: testCase.tokenExpiration, + } + _, err := privateKey.PrepareSigningKeyPrivateCfg() + testCase.errorCatch(err, t) + }) + } +} + +func TestVerifyingKeyPrepare(t *testing.T) { + publicKey := workceptor.VerifyingKeyPublicCfg{} + err := publicKey.Prepare() + + if err == nil { + t.Error(err) + } +} + +func TestPrepareVerifyingKeyPrivateCfg(t *testing.T) { + verifyingKeyTestCases := []struct { + name string + errorCatch func(error, *testing.T) + publicKey string + }{ + { + name: "file does not exist", + publicKey: "does_not_exist.txt", + errorCatch: func(err error, t *testing.T) { + if err == nil { + t.Error(err) + } + }, + }, + { + name: "prepared successfully", + publicKey: "/etc/hosts", + errorCatch: func(err error, t *testing.T) { + if err != nil { + t.Error(err) + } + }, + }, + } + + for _, testCase := range verifyingKeyTestCases { + t.Run(testCase.name, func(t *testing.T) { + publicKey := workceptor.VerifyingKeyPublicCfg{ + PublicKey: testCase.publicKey, + } + err := publicKey.PrepareVerifyingKeyPublicCfg() + testCase.errorCatch(err, t) + }) + } +} diff --git a/pkg/workceptor/controlsvc.go b/pkg/workceptor/controlsvc.go index 452d4804f..0d032bcd0 100644 --- a/pkg/workceptor/controlsvc.go +++ b/pkg/workceptor/controlsvc.go @@ -211,7 +211,7 @@ func (c 
*workceptorCommand) processSignature(workType, signature string, connIsU } func getSignWorkFromStatus(status *StatusFileData) bool { - red, ok := status.ExtraData.(*remoteExtraData) + red, ok := status.ExtraData.(*RemoteExtraData) if ok { return red.SignWork } diff --git a/pkg/workceptor/interfaces.go b/pkg/workceptor/interfaces.go index ca7e14ad6..3b16bc251 100644 --- a/pkg/workceptor/interfaces.go +++ b/pkg/workceptor/interfaces.go @@ -23,11 +23,11 @@ type WorkUnit interface { type WorkerConfig interface { GetWorkType() string GetVerifySignature() bool - NewWorker(w *Workceptor, unitID string, workType string) WorkUnit + NewWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit } // NewWorkerFunc represents a factory of WorkUnit instances. -type NewWorkerFunc func(w *Workceptor, unitID string, workType string) WorkUnit +type NewWorkerFunc func(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit // StatusFileData is the structure of the JSON data saved to a status file. // This struct should only contain value types, except for ExtraData. 
diff --git a/pkg/workceptor/json_test.go b/pkg/workceptor/json_test.go index de4d1f4da..bf6a720de 100644 --- a/pkg/workceptor/json_test.go +++ b/pkg/workceptor/json_test.go @@ -11,18 +11,18 @@ import ( "github.com/ansible/receptor/pkg/netceptor" ) -func newCommandWorker(w *Workceptor, unitID string, workType string) WorkUnit { +func newCommandWorker(_ BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit { cw := &commandUnit{ - BaseWorkUnit: BaseWorkUnit{ + BaseWorkUnitForWorkUnit: &BaseWorkUnit{ status: StatusFileData{ - ExtraData: &commandExtraData{}, + ExtraData: &CommandExtraData{}, }, }, command: "echo", baseParams: "foo", allowRuntimeParams: true, } - cw.BaseWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) + cw.BaseWorkUnitForWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) return cw } @@ -47,7 +47,7 @@ func TestWorkceptorJson(t *testing.T) { t.Fatal(err) } cw.UpdateFullStatus(func(status *StatusFileData) { - ed, ok := status.ExtraData.(*commandExtraData) + ed, ok := status.ExtraData.(*CommandExtraData) if !ok { t.Fatal("ExtraData type assertion failed") } @@ -57,12 +57,12 @@ func TestWorkceptorJson(t *testing.T) { if err != nil { t.Fatal(err) } - cw2 := newCommandWorker(w, cw.ID(), "command") + cw2 := newCommandWorker(nil, w, cw.ID(), "command") err = cw2.Load() if err != nil { t.Fatal(err) } - ed2, ok := cw2.Status().ExtraData.(*commandExtraData) + ed2, ok := cw2.Status().ExtraData.(*CommandExtraData) if !ok { t.Fatal("ExtraData type assertion failed") } diff --git a/pkg/workceptor/kubernetes.go b/pkg/workceptor/kubernetes.go index 14c279139..d173f2870 100644 --- a/pkg/workceptor/kubernetes.go +++ b/pkg/workceptor/kubernetes.go @@ -38,7 +38,7 @@ import ( // kubeUnit implements the WorkUnit interface. 
type kubeUnit struct { - BaseWorkUnit + BaseWorkUnitForWorkUnit authMethod string streamMethod string baseParams string @@ -141,11 +141,11 @@ func (kw *kubeUnit) kubeLoggingConnectionHandler(timestamps bool) (io.ReadCloser ) // get logstream, with retry for retries := 5; retries > 0; retries-- { - logStream, err = logReq.Stream(kw.ctx) + logStream, err = logReq.Stream(kw.GetContext()) if err == nil { break } - kw.Warning( + kw.GetWorkceptor().nc.GetLogger().Warning( "Error opening log stream for pod %s/%s. Will retry %d more times. Error: %s", podNamespace, podName, @@ -156,7 +156,7 @@ func (kw *kubeUnit) kubeLoggingConnectionHandler(timestamps bool) (io.ReadCloser } if err != nil { errMsg := fmt.Sprintf("Error opening log stream for pod %s/%s. Error: %s", podNamespace, podName, err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return nil, err @@ -180,7 +180,7 @@ func (kw *kubeUnit) kubeLoggingNoReconnect(streamWait *sync.WaitGroup, stdout *S _, *stdoutErr = io.Copy(stdout, logStream) if *stdoutErr != nil { - kw.Error( + kw.GetWorkceptor().nc.GetLogger().Error( "Error streaming pod logs to stdout for pod %s/%s. Error: %s", podNamespace, podName, @@ -208,11 +208,11 @@ func (kw *kubeUnit) kubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout // get pod, with retry for retries := 5; retries > 0; retries-- { - kw.pod, err = kw.clientset.CoreV1().Pods(podNamespace).Get(kw.ctx, podName, metav1.GetOptions{}) + kw.pod, err = kw.clientset.CoreV1().Pods(podNamespace).Get(kw.GetContext(), podName, metav1.GetOptions{}) if err == nil { break } - kw.Warning( + kw.GetWorkceptor().nc.GetLogger().Warning( "Error getting pod %s/%s. Will retry %d more times. Error: %s", podNamespace, podName, @@ -223,7 +223,7 @@ func (kw *kubeUnit) kubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout } if err != nil { errMsg := fmt.Sprintf("Error getting pod %s/%s. 
Error: %s", podNamespace, podName, err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) break @@ -239,15 +239,15 @@ func (kw *kubeUnit) kubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout for *stdinErr == nil { // check between every line read to see if we need to stop reading line, err := streamReader.ReadString('\n') if err != nil { - if kw.ctx.Err() == context.Canceled { - kw.Info( + if kw.GetContext().Err() == context.Canceled { + kw.GetWorkceptor().nc.GetLogger().Info( "Context was canceled while reading logs for pod %s/%s. Assuming pod has finished", podNamespace, podName) return } - kw.Info( + kw.GetWorkceptor().nc.GetLogger().Info( "Detected Error: %s for pod %s/%s. Will retry %d more times.", err, podNamespace, @@ -263,7 +263,7 @@ func (kw *kubeUnit) kubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout break } *stdoutErr = err - kw.Error("Error reading from pod %s/%s: %s", podNamespace, podName, err) + kw.GetWorkceptor().nc.GetLogger().Error("Error reading from pod %s/%s: %s", podNamespace, podName, err) return } @@ -278,7 +278,7 @@ func (kw *kubeUnit) kubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout _, err = stdout.Write([]byte(msg)) if err != nil { *stdoutErr = fmt.Errorf("writing to stdout: %s", err) - kw.Error("Error writing to stdout: %s", err) + kw.GetWorkceptor().nc.GetLogger().Error("Error writing to stdout: %s", err) return } @@ -374,13 +374,13 @@ func (kw *kubeUnit) createPod(env map[string]string) error { } // get pod and store to kw.pod - kw.pod, err = kw.clientset.CoreV1().Pods(ked.KubeNamespace).Create(kw.ctx, pod, metav1.CreateOptions{}) + kw.pod, err = kw.clientset.CoreV1().Pods(ked.KubeNamespace).Create(kw.GetContext(), pod, metav1.CreateOptions{}) if err != nil { return err } select { - case <-kw.ctx.Done(): + case <-kw.GetContext().Done(): return fmt.Errorf("cancelled") default: } @@ -398,18 +398,18 @@ func (kw *kubeUnit) createPod(env 
map[string]string) error { ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { options.FieldSelector = fieldSelector - return kw.clientset.CoreV1().Pods(ked.KubeNamespace).List(kw.ctx, options) + return kw.clientset.CoreV1().Pods(ked.KubeNamespace).List(kw.GetContext(), options) }, WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { options.FieldSelector = fieldSelector - return kw.clientset.CoreV1().Pods(ked.KubeNamespace).Watch(kw.ctx, options) + return kw.clientset.CoreV1().Pods(ked.KubeNamespace).Watch(kw.GetContext(), options) }, } - ctxPodReady := kw.ctx + ctxPodReady := kw.GetContext() if kw.podPendingTimeout != time.Duration(0) { - ctxPodReady, _ = context.WithTimeout(kw.ctx, kw.podPendingTimeout) + ctxPodReady, _ = context.WithTimeout(kw.GetContext(), kw.podPendingTimeout) } time.Sleep(2 * time.Second) @@ -441,7 +441,7 @@ func (kw *kubeUnit) createPod(env map[string]string) error { stdout, err2 := NewStdoutWriter(FileSystem{}, kw.UnitDir()) if err2 != nil { errMsg := fmt.Sprintf("Error opening stdout file: %s", err2) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return fmt.Errorf(errMsg) @@ -492,7 +492,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { if err := kw.createPod(nil); err != nil { if err != ErrPodCompleted { errMsg := fmt.Sprintf("Error creating pod: %s", err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return @@ -509,7 +509,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { errMsg := fmt.Sprintf("Error creating pod: pod namespace is empty for pod %s", podName, ) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return @@ -520,19 +520,19 @@ func (kw *kubeUnit) runWorkUsingLogger() { for retries := 5; retries > 0; retries-- { // check if the kw.ctx is already cancel select { - case <-kw.ctx.Done(): - 
errMsg := fmt.Sprintf("Context Done while getting pod %s/%s. Error: %s", podNamespace, podName, kw.ctx.Err()) - kw.Warning(errMsg) + case <-kw.GetContext().Done(): + errMsg := fmt.Sprintf("Context Done while getting pod %s/%s. Error: %s", podNamespace, podName, kw.GetContext().Err()) + kw.GetWorkceptor().nc.GetLogger().Warning(errMsg) return default: } - kw.pod, err = kw.clientset.CoreV1().Pods(podNamespace).Get(kw.ctx, podName, metav1.GetOptions{}) + kw.pod, err = kw.clientset.CoreV1().Pods(podNamespace).Get(kw.GetContext(), podName, metav1.GetOptions{}) if err == nil { break } - kw.Warning( + kw.GetWorkceptor().nc.GetLogger().Warning( "Error getting pod %s/%s. Will retry %d more times. Retrying: %s", podNamespace, podName, @@ -543,7 +543,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { } if err != nil { errMsg := fmt.Sprintf("Error getting pod %s/%s. Error: %s", podNamespace, podName, err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return @@ -601,7 +601,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { skipStdin = true } else { errMsg := fmt.Sprintf("Error opening stdin file: %s", err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return @@ -610,7 +610,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { // goroutine to cancel stdin reader go func() { select { - case <-kw.ctx.Done(): + case <-kw.GetContext().Done(): stdin.reader.Close() return @@ -626,7 +626,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { stdout, err := NewStdoutWriter(FileSystem{}, kw.UnitDir()) if err != nil { errMsg := fmt.Sprintf("Error opening stdout file: %s", err) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) return @@ -635,7 +635,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { // goroutine to cancel stdout stream go func() { select { - case <-kw.ctx.Done(): + case 
<-kw.GetContext().Done(): stdout.writer.Close() return @@ -665,13 +665,13 @@ func (kw *kubeUnit) runWorkUsingLogger() { var err error for retries := 5; retries > 0; retries-- { - err = exec.StreamWithContext(kw.ctx, remotecommand.StreamOptions{ + err = exec.StreamWithContext(kw.GetContext(), remotecommand.StreamOptions{ Stdin: stdin, Tty: false, }) if err != nil { // NOTE: io.EOF for stdin is handled by remotecommand and will not trigger this - kw.Warning( + kw.GetWorkceptor().nc.GetLogger().Warning( "Error streaming stdin to pod %s/%s. Will retry %d more times. Error: %s", podNamespace, podName, @@ -692,7 +692,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { podName, err, ) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, stdout.Size()) close(stdinErrChan) // signal STDOUT goroutine to stop @@ -702,7 +702,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { } else { // this is probably not possible... errMsg := fmt.Sprintf("Error reading stdin: %s", stdin.Error()) - kw.Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, stdout.Size()) close(stdinErrChan) // signal STDOUT goroutine to stop @@ -713,10 +713,10 @@ func (kw *kubeUnit) runWorkUsingLogger() { stdoutWithReconnect := shouldUseReconnect(kw) if stdoutWithReconnect && stdoutErr == nil { - kw.Debug("streaming stdout with reconnect support") + kw.GetWorkceptor().nc.GetLogger().Debug("streaming stdout with reconnect support") go kw.kubeLoggingWithReconnect(&streamWait, stdout, &stdinErr, &stdoutErr) } else { - kw.Debug("streaming stdout with no reconnect support") + kw.GetWorkceptor().nc.GetLogger().Debug("streaming stdout with no reconnect support") go kw.kubeLoggingNoReconnect(&streamWait, stdout, &stdoutErr) } @@ -734,14 +734,14 @@ func (kw *kubeUnit) runWorkUsingLogger() { errDetail = fmt.Sprintf("Error running pod. 
stdin: %s, stdout: %s", stdinErr, stdoutErr) } - if kw.ctx.Err() != context.Canceled { + if kw.GetContext().Err() != context.Canceled { kw.UpdateBasicStatus(WorkStateFailed, errDetail, stdout.Size()) } return } - if kw.ctx.Err() != context.Canceled { + if kw.GetContext().Err() != context.Canceled { kw.UpdateBasicStatus(WorkStateSucceeded, "Finished", stdout.Size()) } } @@ -749,7 +749,7 @@ func (kw *kubeUnit) runWorkUsingLogger() { func isCompatibleK8S(kw *kubeUnit, versionStr string) bool { semver, err := version.ParseSemantic(versionStr) if err != nil { - kw.w.nc.GetLogger().Warning("could parse Kubernetes server version %s, will not use reconnect support", versionStr) + kw.GetWorkceptor().nc.GetLogger().Warning("could parse Kubernetes server version %s, will not use reconnect support", versionStr) return false } @@ -773,11 +773,11 @@ func isCompatibleK8S(kw *kubeUnit, versionStr string) bool { } if semver.AtLeast(version.MustParseSemantic(compatibleVer)) { - kw.w.nc.GetLogger().Debug("Kubernetes version %s is at least %s, using reconnect support", semver, compatibleVer) + kw.GetWorkceptor().nc.GetLogger().Debug("Kubernetes version %s is at least %s, using reconnect support", semver, compatibleVer) return true } - kw.w.nc.GetLogger().Debug("Kubernetes version %s not at least %s, not using reconnect support", semver, compatibleVer) + kw.GetWorkceptor().nc.GetLogger().Debug("Kubernetes version %s not at least %s, not using reconnect support", semver, compatibleVer) return false } @@ -814,7 +814,7 @@ func shouldUseReconnect(kw *kubeUnit) bool { serverVerInfo, err := kw.clientset.ServerVersion() if err != nil { - kw.w.nc.GetLogger().Warning("could not detect Kubernetes server version, will not use reconnect support") + kw.GetWorkceptor().nc.GetLogger().Warning("could not detect Kubernetes server version, will not use reconnect support") return false } @@ -864,7 +864,7 @@ func getDefaultInterface() (string, error) { func (kw *kubeUnit) runWorkUsingTCP() { // Create 
local cancellable context - ctx, cancel := kw.ctx, kw.cancel + ctx, cancel := kw.GetContext(), kw.GetCancel() defer cancel() // Create the TCP listener @@ -884,7 +884,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { if err != nil { errMsg := fmt.Sprintf("Error listening: %s", err) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) return } @@ -908,7 +908,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { if err != nil { errMsg := fmt.Sprintf("Error accepting: %s", err) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) cancel() return @@ -921,7 +921,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { if err != nil { errMsg := fmt.Sprintf("Error creating pod: %s", err) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) cancel() return @@ -940,7 +940,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { stdin, err = NewStdinReader(FileSystem{}, kw.UnitDir()) if err != nil { errMsg := fmt.Sprintf("Error opening stdin file: %s", err) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) cancel() @@ -951,7 +951,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { stdout, err := NewStdoutWriter(FileSystem{}, kw.UnitDir()) if err != nil { errMsg := fmt.Sprintf("Error opening stdout file: %s", err) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) cancel() @@ -969,7 +969,7 @@ func (kw *kubeUnit) runWorkUsingTCP() { _ = conn.CloseWrite() if err != nil { errMsg := fmt.Sprintf("Error sending stdin to pod: %s", err) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) cancel() @@ -1001,7 +1001,7 @@ func (kw *kubeUnit) 
runWorkUsingTCP() { } if err != nil { errMsg := fmt.Sprintf("Error reading stdout from pod: %s", err) - kw.w.nc.GetLogger().Error(errMsg) + kw.GetWorkceptor().nc.GetLogger().Error(errMsg) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) cancel() @@ -1093,14 +1093,14 @@ func (kw *kubeUnit) connectToKube() error { qps, err := strconv.Atoi(envQPS) if err != nil { // ignore error, use default - kw.Warning("Invalid value for RECEPTOR_KUBE_CLIENTSET_QPS: %s. Ignoring", envQPS) + kw.GetWorkceptor().nc.GetLogger().Warning("Invalid value for RECEPTOR_KUBE_CLIENTSET_QPS: %s. Ignoring", envQPS) } else { kw.config.QPS = float32(qps) kw.config.Burst = qps * 10 } } - kw.Debug("RECEPTOR_KUBE_CLIENTSET_QPS: %s", envQPS) + kw.GetWorkceptor().nc.GetLogger().Debug("RECEPTOR_KUBE_CLIENTSET_QPS: %s", envQPS) // RECEPTOR_KUBE_CLIENTSET_BURST // default: 10 x QPS @@ -1108,15 +1108,15 @@ func (kw *kubeUnit) connectToKube() error { if ok { burst, err := strconv.Atoi(envBurst) if err != nil { - kw.Warning("Invalid value for RECEPTOR_KUBE_CLIENTSET_BURST: %s. Ignoring", envQPS) + kw.GetWorkceptor().nc.GetLogger().Warning("Invalid value for RECEPTOR_KUBE_CLIENTSET_BURST: %s. 
Ignoring", envBurst) } else { kw.config.Burst = burst } } - kw.Debug("RECEPTOR_KUBE_CLIENTSET_BURST: %s", envBurst) + kw.GetWorkceptor().nc.GetLogger().Debug("RECEPTOR_KUBE_CLIENTSET_BURST: %s", envBurst) - kw.Debug("Initializing Kubernetes clientset") + kw.GetWorkceptor().nc.GetLogger().Debug("Initializing Kubernetes clientset") // RECEPTOR_KUBE_CLIENTSET_RATE_LIMITER // default: tokenbucket // options: never, always, tokenbucket @@ -1129,10 +1129,10 @@ func (kw *kubeUnit) connectToKube() error { kw.config.RateLimiter = flowcontrol.NewFakeAlwaysRateLimiter() default: } - kw.Debug("RateLimiter: %s", envRateLimiter) + kw.GetWorkceptor().nc.GetLogger().Debug("RateLimiter: %s", envRateLimiter) } - kw.Debug("QPS: %f, Burst: %d", kw.config.QPS, kw.config.Burst) + kw.GetWorkceptor().nc.GetLogger().Debug("QPS: %f, Burst: %d", kw.config.QPS, kw.config.Burst) kw.clientset, err = kubernetes.NewForConfig(kw.config) if err != nil { return err @@ -1156,7 +1156,7 @@ func readFileToString(filename string) (string, error) { // SetFromParams sets the in-memory state from parameters. func (kw *kubeUnit) SetFromParams(params map[string]string) error { - ked := kw.status.ExtraData.(*kubeExtraData) + ked := kw.GetStatusCopy().ExtraData.(*kubeExtraData) type value struct { name string permission bool @@ -1217,7 +1217,7 @@ func (kw *kubeUnit) SetFromParams(params map[string]string) error { if podPendingTimeoutString != "" { podPendingTimeout, err := time.ParseDuration(podPendingTimeoutString) if err != nil { - kw.w.nc.GetLogger().Error("Failed to parse pod_pending_timeout -- valid examples include '1.5h', '30m', '30m10s'") + kw.GetWorkceptor().nc.GetLogger().Error("Failed to parse pod_pending_timeout -- valid examples include '1.5h', '30m', '30m10s'") return err } @@ -1256,10 +1256,10 @@ func (kw *kubeUnit) Status() *StatusFileData { // Status returns a copy of the status currently loaded in memory. 
func (kw *kubeUnit) UnredactedStatus() *StatusFileData { - kw.statusLock.RLock() - defer kw.statusLock.RUnlock() - status := kw.getStatus() - ked, ok := kw.status.ExtraData.(*kubeExtraData) + kw.GetStatusLock().RLock() + defer kw.GetStatusLock().RUnlock() + status := kw.GetStatusWithoutExtraData() + ked, ok := kw.GetStatusCopy().ExtraData.(*kubeExtraData) if ok { kedCopy := *ked status.ExtraData = &kedCopy @@ -1300,11 +1300,11 @@ func (kw *kubeUnit) Restart() error { if kw.deletePodOnRestart { err := kw.connectToKube() if err != nil { - kw.w.nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) + kw.GetWorkceptor().nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) } else { err := kw.clientset.CoreV1().Pods(ked.KubeNamespace).Delete(context.Background(), ked.PodName, metav1.DeleteOptions{}) if err != nil { - kw.w.nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) + kw.GetWorkceptor().nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) } } } @@ -1324,16 +1324,16 @@ func (kw *kubeUnit) Start() error { // Cancel releases resources associated with a job, including cancelling it if running. 
func (kw *kubeUnit) Cancel() error { - kw.cancel() + kw.CancelContext() kw.UpdateBasicStatus(WorkStateCanceled, "Canceled", -1) if kw.pod != nil { err := kw.clientset.CoreV1().Pods(kw.pod.Namespace).Delete(context.Background(), kw.pod.Name, metav1.DeleteOptions{}) if err != nil { - kw.w.nc.GetLogger().Error("Error deleting pod %s: %s", kw.pod.Name, err) + kw.GetWorkceptor().nc.GetLogger().Error("Error deleting pod %s: %s", kw.pod.Name, err) } } - if kw.cancel != nil { - kw.cancel() + if kw.GetCancel() != nil { + kw.CancelContext() } return nil @@ -1346,7 +1346,7 @@ func (kw *kubeUnit) Release(force bool) error { return err } - return kw.BaseWorkUnit.Release(force) + return kw.BaseWorkUnitForWorkUnit.Release(force) } // ************************************************************************** @@ -1373,9 +1373,9 @@ type KubeWorkerCfg struct { } // NewWorker is a factory to produce worker instances. -func (cfg KubeWorkerCfg) NewWorker(w *Workceptor, unitID string, workType string) WorkUnit { - ku := &kubeUnit{ - BaseWorkUnit: BaseWorkUnit{ +func (cfg KubeWorkerCfg) NewWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit { + if bwu == nil { + bwu = &BaseWorkUnit{ status: StatusFileData{ ExtraData: &kubeExtraData{ Image: cfg.Image, @@ -1385,18 +1385,22 @@ func (cfg KubeWorkerCfg) NewWorker(w *Workceptor, unitID string, workType string KubeConfig: cfg.KubeConfig, }, }, - }, - authMethod: strings.ToLower(cfg.AuthMethod), - streamMethod: strings.ToLower(cfg.StreamMethod), - baseParams: cfg.Params, - allowRuntimeAuth: cfg.AllowRuntimeAuth, - allowRuntimeCommand: cfg.AllowRuntimeCommand, - allowRuntimeParams: cfg.AllowRuntimeParams, - allowRuntimePod: cfg.AllowRuntimePod, - deletePodOnRestart: cfg.DeletePodOnRestart, - namePrefix: fmt.Sprintf("%s-", strings.ToLower(cfg.WorkType)), - } - ku.BaseWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) + } + } + + ku := &kubeUnit{ + BaseWorkUnitForWorkUnit: bwu, + authMethod: 
strings.ToLower(cfg.AuthMethod), + streamMethod: strings.ToLower(cfg.StreamMethod), + baseParams: cfg.Params, + allowRuntimeAuth: cfg.AllowRuntimeAuth, + allowRuntimeCommand: cfg.AllowRuntimeCommand, + allowRuntimeParams: cfg.AllowRuntimeParams, + allowRuntimePod: cfg.AllowRuntimePod, + deletePodOnRestart: cfg.DeletePodOnRestart, + namePrefix: fmt.Sprintf("%s-", strings.ToLower(cfg.WorkType)), + } + ku.BaseWorkUnitForWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) return ku } diff --git a/pkg/workceptor/kubernetes_test.go b/pkg/workceptor/kubernetes_test.go index 47a04ff77..f5818bdd1 100644 --- a/pkg/workceptor/kubernetes_test.go +++ b/pkg/workceptor/kubernetes_test.go @@ -13,7 +13,9 @@ func Test_isCompatibleK8S(t *testing.T) { isCompatible bool } - kw := &kubeUnit{} + kw := &kubeUnit{ + BaseWorkUnitForWorkUnit: &BaseWorkUnit{}, + } // Create Netceptor node using external backends n1 := netceptor.New(context.Background(), "node1") @@ -29,7 +31,7 @@ func Test_isCompatibleK8S(t *testing.T) { if err != nil { t.Fatal(err) } - kw.w = w + kw.SetWorkceptor(w) tests := []args{ // K8S compatible versions diff --git a/pkg/workceptor/mock_workceptor/baseworkunit.go b/pkg/workceptor/mock_workceptor/baseworkunit.go new file mode 100644 index 000000000..da19ba9ae --- /dev/null +++ b/pkg/workceptor/mock_workceptor/baseworkunit.go @@ -0,0 +1,427 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ansible/receptor/pkg/workceptor (interfaces: BaseWorkUnitForWorkUnit) + +// Package mock_workceptor is a generated GoMock package. +package mock_workceptor + +import ( + context "context" + reflect "reflect" + sync "sync" + + workceptor "github.com/ansible/receptor/pkg/workceptor" + gomock "github.com/golang/mock/gomock" +) + +// MockBaseWorkUnitForWorkUnit is a mock of BaseWorkUnitForWorkUnit interface. 
+type MockBaseWorkUnitForWorkUnit struct { + ctrl *gomock.Controller + recorder *MockBaseWorkUnitForWorkUnitMockRecorder +} + +// MockBaseWorkUnitForWorkUnitMockRecorder is the mock recorder for MockBaseWorkUnitForWorkUnit. +type MockBaseWorkUnitForWorkUnitMockRecorder struct { + mock *MockBaseWorkUnitForWorkUnit +} + +// NewMockBaseWorkUnitForWorkUnit creates a new mock instance. +func NewMockBaseWorkUnitForWorkUnit(ctrl *gomock.Controller) *MockBaseWorkUnitForWorkUnit { + mock := &MockBaseWorkUnitForWorkUnit{ctrl: ctrl} + mock.recorder = &MockBaseWorkUnitForWorkUnitMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBaseWorkUnitForWorkUnit) EXPECT() *MockBaseWorkUnitForWorkUnitMockRecorder { + return m.recorder +} + +// CancelContext mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) CancelContext() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelContext") +} + +// CancelContext indicates an expected call of CancelContext. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) CancelContext() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelContext", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).CancelContext)) +} + +// Debug mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Debug(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debug", varargs...) +} + +// Debug indicates an expected call of Debug. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Debug(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Debug), varargs...) +} + +// Error mocks base method. 
+func (m *MockBaseWorkUnitForWorkUnit) Error(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Error", varargs...) +} + +// Error indicates an expected call of Error. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Error(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Error), varargs...) +} + +// GetCancel mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetCancel() context.CancelFunc { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCancel") + ret0, _ := ret[0].(context.CancelFunc) + return ret0 +} + +// GetCancel indicates an expected call of GetCancel. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetCancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCancel", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetCancel)) +} + +// GetContext mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetContext() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetContext") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// GetContext indicates an expected call of GetContext. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetContext() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetContext", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetContext)) +} + +// GetStatusCopy mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetStatusCopy() workceptor.StatusFileData { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatusCopy") + ret0, _ := ret[0].(workceptor.StatusFileData) + return ret0 +} + +// GetStatusCopy indicates an expected call of GetStatusCopy. 
+func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetStatusCopy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatusCopy", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetStatusCopy)) +} + +// GetStatusLock mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetStatusLock() *sync.RWMutex { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatusLock") + ret0, _ := ret[0].(*sync.RWMutex) + return ret0 +} + +// GetStatusLock indicates an expected call of GetStatusLock. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetStatusLock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatusLock", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetStatusLock)) +} + +// GetStatusWithoutExtraData mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetStatusWithoutExtraData() *workceptor.StatusFileData { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatusWithoutExtraData") + ret0, _ := ret[0].(*workceptor.StatusFileData) + return ret0 +} + +// GetStatusWithoutExtraData indicates an expected call of GetStatusWithoutExtraData. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetStatusWithoutExtraData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatusWithoutExtraData", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetStatusWithoutExtraData)) +} + +// GetWorkceptor mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) GetWorkceptor() *workceptor.Workceptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkceptor") + ret0, _ := ret[0].(*workceptor.Workceptor) + return ret0 +} + +// GetWorkceptor indicates an expected call of GetWorkceptor. 
+func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) GetWorkceptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkceptor", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).GetWorkceptor)) +} + +// ID mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).ID)) +} + +// Info mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Info(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Info", varargs...) +} + +// Info indicates an expected call of Info. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Info(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Info), varargs...) +} + +// Init mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Init(arg0 *workceptor.Workceptor, arg1, arg2 string, arg3 workceptor.FileSystemer, arg4 workceptor.WatcherWrapper) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Init", arg0, arg1, arg2, arg3, arg4) +} + +// Init indicates an expected call of Init. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Init(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Init), arg0, arg1, arg2, arg3, arg4) +} + +// LastUpdateError mocks base method. 
+func (m *MockBaseWorkUnitForWorkUnit) LastUpdateError() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastUpdateError") + ret0, _ := ret[0].(error) + return ret0 +} + +// LastUpdateError indicates an expected call of LastUpdateError. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) LastUpdateError() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUpdateError", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).LastUpdateError)) +} + +// Load mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Load() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Load") + ret0, _ := ret[0].(error) + return ret0 +} + +// Load indicates an expected call of Load. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Load() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Load)) +} + +// MonitorLocalStatus mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) MonitorLocalStatus() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "MonitorLocalStatus") +} + +// MonitorLocalStatus indicates an expected call of MonitorLocalStatus. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) MonitorLocalStatus() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MonitorLocalStatus", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).MonitorLocalStatus)) +} + +// Release mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Release(arg0 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Release", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Release indicates an expected call of Release. 
+func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Release(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Release), arg0) +} + +// Save mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Save() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Save") + ret0, _ := ret[0].(error) + return ret0 +} + +// Save indicates an expected call of Save. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Save() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Save)) +} + +// SetFromParams mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) SetFromParams(arg0 map[string]string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetFromParams", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetFromParams indicates an expected call of SetFromParams. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) SetFromParams(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFromParams", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).SetFromParams), arg0) +} + +// SetStatusExtraData mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) SetStatusExtraData(arg0 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetStatusExtraData", arg0) +} + +// SetStatusExtraData indicates an expected call of SetStatusExtraData. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) SetStatusExtraData(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatusExtraData", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).SetStatusExtraData), arg0) +} + +// SetWorkceptor mocks base method. 
+func (m *MockBaseWorkUnitForWorkUnit) SetWorkceptor(arg0 *workceptor.Workceptor) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetWorkceptor", arg0) +} + +// SetWorkceptor indicates an expected call of SetWorkceptor. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) SetWorkceptor(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkceptor", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).SetWorkceptor), arg0) +} + +// Status mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Status() *workceptor.StatusFileData { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Status") + ret0, _ := ret[0].(*workceptor.StatusFileData) + return ret0 +} + +// Status indicates an expected call of Status. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Status() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Status", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Status)) +} + +// StatusFileName mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) StatusFileName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StatusFileName") + ret0, _ := ret[0].(string) + return ret0 +} + +// StatusFileName indicates an expected call of StatusFileName. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) StatusFileName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StatusFileName", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).StatusFileName)) +} + +// StdoutFileName mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) StdoutFileName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StdoutFileName") + ret0, _ := ret[0].(string) + return ret0 +} + +// StdoutFileName indicates an expected call of StdoutFileName. 
+func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) StdoutFileName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StdoutFileName", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).StdoutFileName)) +} + +// UnitDir mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) UnitDir() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnitDir") + ret0, _ := ret[0].(string) + return ret0 +} + +// UnitDir indicates an expected call of UnitDir. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) UnitDir() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnitDir", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).UnitDir)) +} + +// UnredactedStatus mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) UnredactedStatus() *workceptor.StatusFileData { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnredactedStatus") + ret0, _ := ret[0].(*workceptor.StatusFileData) + return ret0 +} + +// UnredactedStatus indicates an expected call of UnredactedStatus. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) UnredactedStatus() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnredactedStatus", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).UnredactedStatus)) +} + +// UpdateBasicStatus mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) UpdateBasicStatus(arg0 int, arg1 string, arg2 int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateBasicStatus", arg0, arg1, arg2) +} + +// UpdateBasicStatus indicates an expected call of UpdateBasicStatus. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) UpdateBasicStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateBasicStatus", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).UpdateBasicStatus), arg0, arg1, arg2) +} + +// UpdateFullStatus mocks base method. 
+func (m *MockBaseWorkUnitForWorkUnit) UpdateFullStatus(arg0 func(*workceptor.StatusFileData)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateFullStatus", arg0) +} + +// UpdateFullStatus indicates an expected call of UpdateFullStatus. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) UpdateFullStatus(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateFullStatus", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).UpdateFullStatus), arg0) +} + +// Warning mocks base method. +func (m *MockBaseWorkUnitForWorkUnit) Warning(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warning", varargs...) +} + +// Warning indicates an expected call of Warning. +func (mr *MockBaseWorkUnitForWorkUnitMockRecorder) Warning(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warning", reflect.TypeOf((*MockBaseWorkUnitForWorkUnit)(nil).Warning), varargs...) +} diff --git a/pkg/workceptor/python.go b/pkg/workceptor/python.go index 98a363fdf..bc4124a2d 100644 --- a/pkg/workceptor/python.go +++ b/pkg/workceptor/python.go @@ -26,7 +26,7 @@ func (pw *pythonUnit) Start() error { for k, v := range pw.config { config[k] = v } - config["params"] = pw.Status().ExtraData.(*commandExtraData).Params + config["params"] = pw.Status().ExtraData.(*CommandExtraData).Params configJSON, err := json.Marshal(config) if err != nil { return err @@ -50,12 +50,12 @@ type workPythonCfg struct { } // NewWorker is a factory to produce worker instances. 
-func (cfg workPythonCfg) NewWorker(w *Workceptor, unitID string, workType string) WorkUnit { +func (cfg workPythonCfg) NewWorker(_ BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string) WorkUnit { cw := &pythonUnit{ commandUnit: commandUnit{ - BaseWorkUnit: BaseWorkUnit{ + BaseWorkUnitForWorkUnit: &BaseWorkUnit{ status: StatusFileData{ - ExtraData: &commandExtraData{}, + ExtraData: &CommandExtraData{}, }, }, }, @@ -63,7 +63,7 @@ func (cfg workPythonCfg) NewWorker(w *Workceptor, unitID string, workType string function: cfg.Function, config: cfg.Config, } - cw.BaseWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) + cw.BaseWorkUnitForWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) return cw } diff --git a/pkg/workceptor/remote_work.go b/pkg/workceptor/remote_work.go index 69acfd3a1..a48c95c77 100644 --- a/pkg/workceptor/remote_work.go +++ b/pkg/workceptor/remote_work.go @@ -23,13 +23,13 @@ import ( // remoteUnit implements the WorkUnit interface for the Receptor remote worker plugin. type remoteUnit struct { - BaseWorkUnit + BaseWorkUnitForWorkUnit topJC *utils.JobContext logger *logger.ReceptorLogger } -// remoteExtraData is the content of the ExtraData JSON field for a remote work unit. -type remoteExtraData struct { +// RemoteExtraData is the content of the ExtraData JSON field for a remote work unit. +type RemoteExtraData struct { RemoteNode string RemoteWorkType string RemoteParams map[string]string @@ -47,15 +47,15 @@ type actionFunc func(context.Context, net.Conn, *bufio.Reader) error // connectToRemote establishes a control socket connection to a remote node. 
func (rw *remoteUnit) connectToRemote(ctx context.Context) (net.Conn, *bufio.Reader, error) { status := rw.Status() - red, ok := status.ExtraData.(*remoteExtraData) + red, ok := status.ExtraData.(*RemoteExtraData) if !ok { return nil, nil, fmt.Errorf("remote ExtraData missing") } - tlsConfig, err := rw.w.nc.GetClientTLSConfig(red.TLSClient, red.RemoteNode, netceptor.ExpectedHostnameTypeReceptor) + tlsConfig, err := rw.GetWorkceptor().nc.GetClientTLSConfig(red.TLSClient, red.RemoteNode, netceptor.ExpectedHostnameTypeReceptor) if err != nil { return nil, nil, err } - conn, err := rw.w.nc.DialContext(ctx, red.RemoteNode, "control", tlsConfig) + conn, err := rw.GetWorkceptor().nc.DialContext(ctx, red.RemoteNode, "control", tlsConfig) if err != nil { return nil, nil, err } @@ -85,14 +85,14 @@ func (rw *remoteUnit) getConnection(ctx context.Context) (net.Conn, *bufio.Reade if err == nil { return conn, reader } - rw.w.nc.GetLogger().Debug("Connection to %s failed with error: %s", - rw.Status().ExtraData.(*remoteExtraData).RemoteNode, err) + rw.GetWorkceptor().nc.GetLogger().Debug("Connection to %s failed with error: %s", + rw.Status().ExtraData.(*RemoteExtraData).RemoteNode, err) errStr := err.Error() if strings.Contains(errStr, "CRYPTO_ERROR") { shouldExit := false rw.UpdateFullStatus(func(status *StatusFileData) { status.Detail = fmt.Sprintf("TLS error connecting to remote service: %s", errStr) - if !status.ExtraData.(*remoteExtraData).RemoteStarted { + if !status.ExtraData.(*RemoteExtraData).RemoteStarted { shouldExit = true status.State = WorkStateFailed } @@ -146,7 +146,7 @@ func (rw *remoteUnit) getConnectionAndRun(ctx context.Context, firstTimeSync boo // startRemoteUnit makes a single attempt to start a remote unit. 
func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { defer conn.(interface{ CloseConnection() error }).CloseConnection() - red := rw.UnredactedStatus().ExtraData.(*remoteExtraData) + red := rw.UnredactedStatus().ExtraData.(*RemoteExtraData) workSubmitCmd := make(map[string]interface{}) for k, v := range red.RemoteParams { workSubmitCmd[k] = v @@ -157,7 +157,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader workSubmitCmd["worktype"] = red.RemoteWorkType workSubmitCmd["tlsclient"] = red.TLSClient if red.SignWork { - signature, err := rw.w.createSignature(red.RemoteNode) + signature, err := rw.GetWorkceptor().createSignature(red.RemoteNode) if err != nil { return err } @@ -185,7 +185,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader } red.RemoteUnitID = string(match[1]) rw.UpdateFullStatus(func(status *StatusFileData) { - ed := status.ExtraData.(*remoteExtraData) + ed := status.ExtraData.(*RemoteExtraData) ed.RemoteUnitID = red.RemoteUnitID }) stdin, err := os.Open(path.Join(rw.UnitDir(), "stdin")) @@ -210,7 +210,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader return fmt.Errorf("error from remote: %s", match[1]) } rw.UpdateFullStatus(func(status *StatusFileData) { - ed := status.ExtraData.(*remoteExtraData) + ed := status.ExtraData.(*RemoteExtraData) ed.RemoteStarted = true }) @@ -222,7 +222,7 @@ func (rw *remoteUnit) cancelOrReleaseRemoteUnit(ctx context.Context, conn net.Co release bool, ) error { defer conn.(interface{ CloseConnection() error }).CloseConnection() - red := rw.Status().ExtraData.(*remoteExtraData) + red := rw.Status().ExtraData.(*RemoteExtraData) var workCmd string if release { workCmd = "release" @@ -234,7 +234,7 @@ func (rw *remoteUnit) cancelOrReleaseRemoteUnit(ctx context.Context, conn net.Co workSubmitCmd["subcommand"] = workCmd workSubmitCmd["unitid"] = red.RemoteUnitID if red.SignWork 
{ - signature, err := rw.w.createSignature(red.RemoteNode) + signature, err := rw.GetWorkceptor().createSignature(red.RemoteNode) if err != nil { return err } @@ -267,9 +267,9 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) mw.WorkerDone() }() status := rw.Status() - red, ok := status.ExtraData.(*remoteExtraData) + red, ok := status.ExtraData.(*RemoteExtraData) if !ok { - rw.w.nc.GetLogger().Error("remote ExtraData missing") + rw.GetWorkceptor().nc.GetLogger().Error("remote ExtraData missing") return } @@ -294,7 +294,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) } _, err := conn.Write([]byte(fmt.Sprintf("work status %s\n", remoteUnitID))) if err != nil { - rw.w.nc.GetLogger().Debug("Write error sending to %s: %s\n", remoteUnitID, err) + rw.GetWorkceptor().nc.GetLogger().Debug("Write error sending to %s: %s\n", remoteUnitID, err) _ = conn.(interface{ CloseConnection() error }).CloseConnection() conn = nil @@ -302,7 +302,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) } status, err := utils.ReadStringContext(mw, reader, '\n') if err != nil { - rw.w.nc.GetLogger().Debug("Read error reading from %s: %s\n", remoteNode, err) + rw.GetWorkceptor().nc.GetLogger().Debug("Read error reading from %s: %s\n", remoteNode, err) _ = conn.(interface{ CloseConnection() error }).CloseConnection() conn = nil @@ -311,7 +311,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) if status[:5] == "ERROR" { if strings.Contains(status, "unknown work unit") { if !forRelease { - rw.w.nc.GetLogger().Debug("Work unit %s on node %s is gone.\n", remoteUnitID, remoteNode) + rw.GetWorkceptor().nc.GetLogger().Debug("Work unit %s on node %s is gone.\n", remoteUnitID, remoteNode) rw.UpdateFullStatus(func(status *StatusFileData) { status.State = WorkStateFailed status.Detail = "Remote work unit is gone" @@ -320,14 +320,14 @@ func (rw *remoteUnit) 
monitorRemoteStatus(mw *utils.JobContext, forRelease bool) return } - rw.w.nc.GetLogger().Error("Remote error: %s\n", strings.TrimRight(status[6:], "\n")) + rw.GetWorkceptor().nc.GetLogger().Error("Remote error: %s\n", strings.TrimRight(status[6:], "\n")) return } si := StatusFileData{} err = json.Unmarshal([]byte(status), &si) if err != nil { - rw.w.nc.GetLogger().Error("Error unmarshalling JSON: %s\n", status) + rw.GetWorkceptor().nc.GetLogger().Error("Error unmarshalling JSON: %s\n", status) return } @@ -335,7 +335,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) if rw.LastUpdateError() != nil { writeStatusFailures++ if writeStatusFailures > 3 { - rw.w.nc.GetLogger().Error("Exceeded retries for updating status file for work unit %s", rw.unitID) + rw.GetWorkceptor().nc.GetLogger().Error("Exceeded retries for updating status file for work unit %s", rw.ID()) return } @@ -343,7 +343,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) writeStatusFailures = 0 } if err != nil { - rw.w.nc.GetLogger().Error("Error saving local status file: %s\n", err) + rw.GetWorkceptor().nc.GetLogger().Error("Error saving local status file: %s\n", err) return } @@ -361,20 +361,20 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { }() firstTime := true status := rw.Status() - red, ok := status.ExtraData.(*remoteExtraData) + red, ok := status.ExtraData.(*RemoteExtraData) if !ok { - rw.w.nc.GetLogger().Error("remote ExtraData missing") + rw.GetWorkceptor().nc.GetLogger().Error("remote ExtraData missing") return } remoteNode := red.RemoteNode remoteUnitID := red.RemoteUnitID - stdout, err := os.OpenFile(rw.stdoutFileName, os.O_CREATE+os.O_APPEND+os.O_WRONLY, 0o600) + stdout, err := os.OpenFile(rw.StdoutFileName(), os.O_CREATE+os.O_APPEND+os.O_WRONLY, 0o600) if err == nil { err = stdout.Close() } if err != nil { - rw.w.nc.GetLogger().Error("Could not open stdout file %s: %s\n", rw.stdoutFileName, err) + 
rw.GetWorkceptor().nc.GetLogger().Error("Could not open stdout file %s: %s\n", rw.StdoutFileName(), err) return } @@ -389,7 +389,7 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { } err := rw.Load() if err != nil { - rw.w.nc.GetLogger().Error("Could not read status file %s: %s\n", rw.statusFileName, err) + rw.GetWorkceptor().nc.GetLogger().Error("Could not read status file %s: %s\n", rw.StatusFileName(), err) return } @@ -414,9 +414,9 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { workSubmitCmd["unitid"] = remoteUnitID workSubmitCmd["startpos"] = diskStdoutSize if red.SignWork { - signature, err := rw.w.createSignature(red.RemoteNode) + signature, err := rw.GetWorkceptor().createSignature(red.RemoteNode) if err != nil { - rw.w.nc.GetLogger().Error("could not create signature to get results") + rw.GetWorkceptor().nc.GetLogger().Error("could not create signature to get results") return } @@ -424,31 +424,31 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { } wscBytes, err := json.Marshal(workSubmitCmd) if err != nil { - rw.w.nc.GetLogger().Error("error constructing work results command: %s", err) + rw.GetWorkceptor().nc.GetLogger().Error("error constructing work results command: %s", err) return } wscBytes = append(wscBytes, '\n') _, err = conn.Write(wscBytes) if err != nil { - rw.w.nc.GetLogger().Warning("Write error sending to %s: %s\n", remoteNode, err) + rw.GetWorkceptor().nc.GetLogger().Warning("Write error sending to %s: %s\n", remoteNode, err) continue } status, err := utils.ReadStringContext(mw, reader, '\n') if err != nil { - rw.w.nc.GetLogger().Warning("Read error reading from %s: %s\n", remoteNode, err) + rw.GetWorkceptor().nc.GetLogger().Warning("Read error reading from %s: %s\n", remoteNode, err) continue } if !strings.Contains(status, "Streaming results") { - rw.w.nc.GetLogger().Warning("Remote node %s did not stream results\n", remoteNode) + rw.GetWorkceptor().nc.GetLogger().Warning("Remote 
node %s did not stream results\n", remoteNode) continue } - stdout, err := os.OpenFile(rw.stdoutFileName, os.O_CREATE+os.O_APPEND+os.O_WRONLY, 0o600) + stdout, err := os.OpenFile(rw.StdoutFileName(), os.O_CREATE+os.O_APPEND+os.O_WRONLY, 0o600) if err != nil { - rw.w.nc.GetLogger().Error("Could not open stdout file %s: %s\n", rw.stdoutFileName, err) + rw.GetWorkceptor().nc.GetLogger().Error("Could not open stdout file %s: %s\n", rw.StdoutFileName(), err) return } @@ -476,7 +476,7 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { } else { errmsg = err.Error() } - rw.w.nc.GetLogger().Warning("Could not copy to stdout file %s: %s\n", rw.stdoutFileName, errmsg) + rw.GetWorkceptor().nc.GetLogger().Warning("Could not copy to stdout file %s: %s\n", rw.StdoutFileName(), errmsg) continue } @@ -501,7 +501,7 @@ func (rw *remoteUnit) monitorRemoteUnit(ctx context.Context, forRelease bool) { // SetFromParams sets the in-memory state from parameters. func (rw *remoteUnit) SetFromParams(params map[string]string) error { for k, v := range params { - rw.status.ExtraData.(*remoteExtraData).RemoteParams[k] = v + rw.GetStatusCopy().ExtraData.(*RemoteExtraData).RemoteParams[k] = v } return nil @@ -510,7 +510,7 @@ func (rw *remoteUnit) SetFromParams(params map[string]string) error { // Status returns a copy of the status currently loaded in memory. func (rw *remoteUnit) Status() *StatusFileData { status := rw.UnredactedStatus() - ed, ok := status.ExtraData.(*remoteExtraData) + ed, ok := status.ExtraData.(*RemoteExtraData) if ok { keysToDelete := make([]string, 0) for k := range ed.RemoteParams { @@ -528,10 +528,10 @@ func (rw *remoteUnit) Status() *StatusFileData { // UnredactedStatus returns a copy of the status currently loaded in memory, including secrets. 
func (rw *remoteUnit) UnredactedStatus() *StatusFileData { - rw.statusLock.RLock() - defer rw.statusLock.RUnlock() - status := rw.getStatus() - ed, ok := rw.status.ExtraData.(*remoteExtraData) + rw.GetStatusLock().RLock() + defer rw.GetStatusLock().RUnlock() + status := rw.GetStatusWithoutExtraData() + ed, ok := rw.GetStatusCopy().ExtraData.(*RemoteExtraData) if ok { edCopy := *ed edCopy.RemoteParams = make(map[string]string) @@ -556,9 +556,9 @@ func (rw *remoteUnit) runAndMonitor(mw *utils.JobContext, forRelease bool, actio go func() { rw.monitorRemoteUnit(ctx, forRelease) if forRelease { - err := rw.BaseWorkUnit.Release(false) + err := rw.BaseWorkUnitForWorkUnit.Release(false) if err != nil { - rw.w.nc.GetLogger().Error("Error releasing unit %s: %s", rw.UnitDir(), err) + rw.GetWorkceptor().nc.GetLogger().Error("Error releasing unit %s: %s", rw.UnitDir(), err) } } mw.WorkerDone() @@ -571,12 +571,12 @@ func (rw *remoteUnit) runAndMonitor(mw *utils.JobContext, forRelease bool, actio } func (rw *remoteUnit) setExpiration(mw *utils.JobContext) { - red := rw.Status().ExtraData.(*remoteExtraData) + red := rw.Status().ExtraData.(*RemoteExtraData) dur := time.Until(red.Expiration) select { case <-mw.Done(): case <-time.After(dur): - red := rw.Status().ExtraData.(*remoteExtraData) + red := rw.Status().ExtraData.(*RemoteExtraData) if !red.RemoteStarted { rw.UpdateFullStatus(func(status *StatusFileData) { status.Detail = fmt.Sprintf("Work unit expired on %s", red.Expiration.Format("Mon Jan 2 15:04:05")) @@ -589,11 +589,11 @@ func (rw *remoteUnit) setExpiration(mw *utils.JobContext) { // startOrRestart is a shared implementation of Start() and Restart(). 
func (rw *remoteUnit) startOrRestart(start bool) error { - red := rw.Status().ExtraData.(*remoteExtraData) + red := rw.Status().ExtraData.(*RemoteExtraData) if start && red.RemoteStarted { return fmt.Errorf("unit was already started") } - newJobStarted := rw.topJC.NewJob(rw.w.ctx, 1, true) + newJobStarted := rw.topJC.NewJob(rw.GetWorkceptor().ctx, 1, true) if !newJobStarted { return fmt.Errorf("start or monitor process already running") } @@ -624,7 +624,7 @@ func (rw *remoteUnit) Start() error { // Restart resumes monitoring a job after a Receptor restart. func (rw *remoteUnit) Restart() error { - red := rw.Status().ExtraData.(*remoteExtraData) + red := rw.Status().ExtraData.(*RemoteExtraData) if red.RemoteStarted { return rw.startOrRestart(false) } @@ -637,31 +637,31 @@ func (rw *remoteUnit) cancelOrRelease(release bool, force bool) error { // Update the status file that the unit is locally cancelled/released var remoteStarted bool rw.UpdateFullStatus(func(status *StatusFileData) { - status.ExtraData.(*remoteExtraData).LocalCancelled = true + status.ExtraData.(*RemoteExtraData).LocalCancelled = true if release { - status.ExtraData.(*remoteExtraData).LocalReleased = true + status.ExtraData.(*RemoteExtraData).LocalReleased = true } - remoteStarted = status.ExtraData.(*remoteExtraData).RemoteStarted + remoteStarted = status.ExtraData.(*RemoteExtraData).RemoteStarted }) // if remote work has not started, don't attempt to connect to remote if !remoteStarted { rw.topJC.Cancel() rw.topJC.Wait() if release { - return rw.BaseWorkUnit.Release(true) + return rw.BaseWorkUnitForWorkUnit.Release(true) } rw.UpdateBasicStatus(WorkStateFailed, "Locally Cancelled", 0) return nil } if release && force { - _ = rw.connectAndRun(rw.w.ctx, func(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { + _ = rw.connectAndRun(rw.GetWorkceptor().ctx, func(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { return rw.cancelOrReleaseRemoteUnit(ctx, conn, reader, true) 
}) - return rw.BaseWorkUnit.Release(true) + return rw.BaseWorkUnitForWorkUnit.Release(true) } - rw.topJC.NewJob(rw.w.ctx, 1, false) + rw.topJC.NewJob(rw.GetWorkceptor().ctx, 1, false) return rw.runAndMonitor(rw.topJC, release, func(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { return rw.cancelOrReleaseRemoteUnit(ctx, conn, reader, release) @@ -678,12 +678,22 @@ func (rw *remoteUnit) Release(force bool) error { return rw.cancelOrRelease(true, force) } -func newRemoteWorker(w *Workceptor, unitID, workType string) WorkUnit { - rw := &remoteUnit{logger: w.nc.GetLogger()} - rw.BaseWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) - red := &remoteExtraData{} +func NewRemoteWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID, workType string) WorkUnit { + return newRemoteWorker(bwu, w, unitID, workType) +} + +func newRemoteWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID, workType string) WorkUnit { + if bwu == nil { + bwu = &BaseWorkUnit{} + } + rw := &remoteUnit{ + BaseWorkUnitForWorkUnit: bwu, + logger: w.nc.GetLogger(), + } + rw.BaseWorkUnitForWorkUnit.Init(w, unitID, workType, FileSystem{}, nil) + red := &RemoteExtraData{} red.RemoteParams = make(map[string]string) - rw.status.ExtraData = red + rw.SetStatusExtraData(red) rw.topJC = &utils.JobContext{} return rw diff --git a/pkg/workceptor/remote_work_test.go b/pkg/workceptor/remote_work_test.go new file mode 100644 index 000000000..1f1e1fa01 --- /dev/null +++ b/pkg/workceptor/remote_work_test.go @@ -0,0 +1,56 @@ +package workceptor_test + +import ( + "context" + "sync" + "testing" + + "github.com/ansible/receptor/pkg/workceptor" + "github.com/ansible/receptor/pkg/workceptor/mock_workceptor" + "github.com/golang/mock/gomock" +) + +func createRemoteWorkTestSetup(t *testing.T) (workceptor.WorkUnit, *mock_workceptor.MockBaseWorkUnitForWorkUnit, *mock_workceptor.MockNetceptorForWorkceptor, *workceptor.Workceptor) { + ctrl := gomock.NewController(t) + ctx := context.Background() + + 
mockBaseWorkUnit := mock_workceptor.NewMockBaseWorkUnitForWorkUnit(ctrl) + mockNetceptor := mock_workceptor.NewMockNetceptorForWorkceptor(ctrl) + mockNetceptor.EXPECT().NodeID().Return("NodeID") + mockNetceptor.EXPECT().GetLogger() + + w, err := workceptor.New(ctx, mockNetceptor, "/tmp") + if err != nil { + t.Errorf("Error while creating Workceptor: %v", err) + } + + mockBaseWorkUnit.EXPECT().Init(w, "", "", workceptor.FileSystem{}, nil) + mockBaseWorkUnit.EXPECT().SetStatusExtraData(gomock.Any()) + workUnit := workceptor.NewRemoteWorker(mockBaseWorkUnit, w, "", "") + + return workUnit, mockBaseWorkUnit, mockNetceptor, w +} + +func TestRemoteWorkUnredactedStatus(t *testing.T) { + t.Parallel() + wu, mockBaseWorkUnit, _, _ := createRemoteWorkTestSetup(t) + restartTestCases := []struct { + name string + }{ + {name: "test1"}, + {name: "test2"}, + } + + statusLock := &sync.RWMutex{} + for _, testCase := range restartTestCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.RemoteExtraData{}, + }) + wu.UnredactedStatus() + }) + } +} diff --git a/pkg/workceptor/workceptor.go b/pkg/workceptor/workceptor.go index 0057f9a77..2292822a9 100644 --- a/pkg/workceptor/workceptor.go +++ b/pkg/workceptor/workceptor.go @@ -243,7 +243,7 @@ func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string) if err != nil { return nil, err } - worker := wt.newWorkerFunc(w, ident, workTypeName) + worker := wt.newWorkerFunc(nil, w, ident, workTypeName) err = worker.SetFromParams(params) if err == nil { err = worker.Save() @@ -295,7 +295,7 @@ func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, tlsClient, t expiration = time.Time{} } rw.UpdateFullStatus(func(status *StatusFileData) { - 
ed := status.ExtraData.(*remoteExtraData) + ed := status.ExtraData.(*RemoteExtraData) ed.RemoteNode = remoteNode ed.RemoteWorkType = remoteWorkType ed.TLSClient = tlsClient @@ -330,7 +330,7 @@ func (w *Workceptor) scanForUnit(unitID string) { w.workTypesLock.RUnlock() var worker WorkUnit if ok { - worker = wt.newWorkerFunc(w, ident, sfd.WorkType) + worker = wt.newWorkerFunc(nil, w, ident, sfd.WorkType) } else { worker = newUnknownWorker(w, ident, sfd.WorkType) } diff --git a/pkg/workceptor/workceptor_test.go b/pkg/workceptor/workceptor_test.go index 44b1d7a24..f74e3a36c 100644 --- a/pkg/workceptor/workceptor_test.go +++ b/pkg/workceptor/workceptor_test.go @@ -23,7 +23,7 @@ func TestAllocateUnit(t *testing.T) { logger := logger.NewReceptorLogger("") mockNetceptor.EXPECT().GetLogger().AnyTimes().Return(logger) - workFunc := func(w *workceptor.Workceptor, unitID string, workType string) workceptor.WorkUnit { + workFunc := func(bwu workceptor.BaseWorkUnitForWorkUnit, w *workceptor.Workceptor, unitID string, workType string) workceptor.WorkUnit { return mockWorkUnit } diff --git a/pkg/workceptor/workunitbase.go b/pkg/workceptor/workunitbase.go index 291e95978..97413548b 100644 --- a/pkg/workceptor/workunitbase.go +++ b/pkg/workceptor/workunitbase.go @@ -485,6 +485,38 @@ func (bwu *BaseWorkUnit) CancelContext() { bwu.cancel() } +func (bwu *BaseWorkUnit) GetStatusCopy() StatusFileData { + return bwu.status +} + +func (bwu *BaseWorkUnit) GetStatusWithoutExtraData() *StatusFileData { + return bwu.getStatus() +} + +func (bwu *BaseWorkUnit) SetStatusExtraData(ed interface{}) { + bwu.status.ExtraData = ed +} + +func (bwu *BaseWorkUnit) GetStatusLock() *sync.RWMutex { + return bwu.statusLock +} + +func (bwu *BaseWorkUnit) GetWorkceptor() *Workceptor { + return bwu.w +} + +func (bwu *BaseWorkUnit) SetWorkceptor(w *Workceptor) { + bwu.w = w +} + +func (bwu *BaseWorkUnit) GetContext() context.Context { + return bwu.ctx +} + +func (bwu *BaseWorkUnit) GetCancel() context.CancelFunc 
{ + return bwu.cancel +} + // =============================================================================================== // func newUnknownWorker(w *Workceptor, unitID string, workType string) WorkUnit {