diff --git a/docker-compose.dpu.yml b/docker-compose.dpu.yml index 28c9600..7e9bd16 100644 --- a/docker-compose.dpu.yml +++ b/docker-compose.dpu.yml @@ -43,6 +43,9 @@ services: - dhcp-leases-folder:/var/lib/dhclient/ - /etc/os-release:/etc/os-release - /etc/ssh:/etc/ssh + - /var/lib/sztp:/var/lib/sztp + - /run/sztp:/run/sztp + privileged: true networks: - opi command: ['/opi-sztp-agent', 'daemon', diff --git a/scripts/run_agent.sh b/scripts/run_agent.sh index 7b31a63..2c35965 100755 --- a/scripts/run_agent.sh +++ b/scripts/run_agent.sh @@ -21,6 +21,9 @@ docker run --rm -it --network=host \ --mount type=bind,source=/etc/ssh,target=/etc/ssh,readonly \ --mount type=bind,source=/etc/os-release,target=/etc/os-release,readonly \ --mount type=bind,source=/var/lib/NetworkManager,target=/var/lib/NetworkManager,readonly \ + --mount type=bind,source=/var/lib/sztp,target=/var/lib/sztp \ + --mount type=bind,source=/run/sztp,target=/run/sztp \ + --privileged \ ${DOCKER_SZTP_IMAGE} \ /opi-sztp-agent daemon \ --dhcp-lease-file /var/lib/NetworkManager/dhclient-eth0.lease \ diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index 92d49ef..7cc4c22 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -32,13 +32,16 @@ func Daemon() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "daemon", Short: "Run the daemon command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -60,7 +63,7 @@ func Daemon() *cobra.Command { } } client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandDaemon() }, } @@ -75,6 +78,9 @@ func Daemon() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index 157776d..732fb1c 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -28,6 +28,9 @@ func Disable() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Disable() *cobra.Command { Short: "Run the disable command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandDisable() }, } @@ -50,5 +53,9 @@ func Disable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") + return cmd } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 49f53b4..2a75031 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -28,6 +28,9 @@ func Enable() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Enable() *cobra.Command { Short: "Run the enable command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandEnable() }, } @@ -50,6 +53,9 @@ func Enable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index 99a9b1f..0178e62 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -32,13 +32,16 @@ func Run() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "run", Short: "Exec the run command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -60,7 +63,7 @@ func Run() *cobra.Command { } } client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommand() }, } @@ -75,6 +78,9 @@ func Run() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index dbc80df..b8b1e1f 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -28,6 +28,9 @@ func Status() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Status() *cobra.Command { Short: "Run the status command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandStatus() }, } @@ -50,6 +53,9 @@ func Status() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index 6797046..5c43e7d 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -93,9 +93,12 @@ type Agent struct { BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure HttpClient HttpClient + StatusFilePath string // Path to the status file + ResultFilePath string // Path to the result file + SymLinkDir string // Path to the symlink directory for the status file } -func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent { +func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir string, httpClient HttpClient) *Agent { return &Agent{ InputBootstrapURL: bootstrapURL, BootstrapURL: "", @@ -111,6 +114,9 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, HttpClient: httpClient, + StatusFilePath: statusFilePath, + ResultFilePath: resultFilePath, + SymLinkDir: symLinkDir, } } @@ -150,6 +156,18 @@ func (a *Agent) GetProgressJSON() ProgressJSON { return a.ProgressJSON } +func (a *Agent) GetStatusFilePath() string { + return a.StatusFilePath +} + +func (a *Agent) GetResultFilePath() string { + return a.ResultFilePath +} + +func (a *Agent) GetSymLinkDir() string { + return a.SymLinkDir +} + func (a *Agent) SetBootstrapURL(url string) { a.BootstrapURL = url } @@ -181,3 +199,15 @@ func (a *Agent) SetContentTypeReq(ct string) { func (a *Agent) SetProgressJSON(p ProgressJSON) { a.ProgressJSON = p } + +func (a *Agent) SetStatusFilePath(path string) { + a.StatusFilePath = path +} + +func (a *Agent) SetResultFilePath(path string) { + a.ResultFilePath = path +} + +func (a *Agent) SetSymLinkDir(path string) { + a.SymLinkDir = path +} diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index e8234a8..fec69b9 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -829,6 +829,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string } client := http.Client{} tests := []struct { @@ -846,6 +849,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey: "TestDevicePrivateKey", deviceEndEntityCert: "TestDeviceEndEntityCert", bootstrapTrustAnchorCert: "TestBootstrapTrustCert", + statusFilePath: "TestStatusFilePath", + resultFilePath: "TestResultFilePath", + symLinkDir: "TestSymLinkDir", }, want: &Agent{ InputBootstrapURL: "TestBootstrap", @@ -858,13 +864,16 @@ func TestNewAgent(t *testing.T) { ContentTypeReq: "application/yang-data+json", InputJSONContent: generateInputJSONContent(), DhcpLeaseFile: "TestDhcpLeaseFile", + StatusFilePath: "TestStatusFilePath", + ResultFilePath: "TestResultFilePath", + SymLinkDir: "TestSymLinkDir", HttpClient: &client, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) { + if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, tt.args.statusFilePath, tt.args.resultFilePath, tt.args.symLinkDir, &client); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgent() = %v, want %v", got, tt.want) } }) diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 1a6f013..f6594de 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,6 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") + _ = a.updateAndSaveStatus(StageTypeConfig, true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -36,6 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") + _ = a.updateAndSaveStatus(StageTypeConfig, false, "") return nil } @@ -43,19 +45,24 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { var script, scriptName string var reportStart, reportEnd ProgressType switch typeOf { - case "post": + case POST: script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PostConfigurationScript - scriptName = "post" + scriptName = POST reportStart = ProgressTypePostScriptInitiated reportEnd = ProgressTypePostScriptComplete default: // pre or default script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PreConfigurationScript - scriptName = "pre" + scriptName = PRE reportStart = ProgressTypePreScriptInitiated reportEnd = ProgressTypePreScriptComplete } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, true, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, true, "") + } // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -89,6 +96,11 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, false, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, false, "") + } log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index 2064d90..a8abd7b 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -33,14 +33,21 @@ const ( // RunCommandDaemon runs the command in the background func (a *Agent) RunCommandDaemon() error { + if err := a.prepareStatus(); err != nil { + log.Println("failed to prepare status: ", err) + return err + } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, true, "") for { err := a.performBootstrapSequence() if err != nil { log.Println("[ERROR] Failed to perform the bootstrap sequence: ", err.Error()) log.Println("[INFO] Retrying in 5 seconds") time.Sleep(5 * time.Second) + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, err.Error()) continue } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, "") return nil } } @@ -49,33 +56,41 @@ func (a *Agent) performBootstrapSequence() error { var err error err = a.discoverBootstrapURLs() if err != nil { + _ = a.updateAndSaveStatus(StageTypeParsing, false, err.Error()) return err } err = a.doRequestBootstrapServerOnboardingInfo() if err != nil { + _ = a.updateAndSaveStatus(StageTypeOnboarding, false, err.Error()) return err } err = a.doHandleBootstrapRedirect() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.downloadAndValidateImage() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.copyConfigurationFile() if err != nil { + _ = a.updateAndSaveStatus(StageTypeConfig, false, err.Error()) return err } err = a.launchScriptsConfiguration(PRE) if err != nil { + _ = a.updateAndSaveStatus(StageTypePreScript, false, err.Error()) return err } err = a.launchScriptsConfiguration(POST) if err != nil { + _ = a.updateAndSaveStatus(StageTypePostScript, false, err.Error()) return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") + _ = a.updateAndSaveStatus(StageTypeBootstrap, false, "") return nil } @@ -142,6 +157,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") + _ = a.updateAndSaveStatus(StageTypeBootstrap, true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index 9af486c..7362d77 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,6 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") + _ = a.updateAndSaveStatus(StageTypeBootImage, true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -78,6 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") + _ = a.updateAndSaveStatus(StageTypeBootImage, false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/run.go b/sztp-agent/pkg/secureagent/run.go index b3b2e4c..09dbf8c 100644 --- a/sztp-agent/pkg/secureagent/run.go +++ b/sztp-agent/pkg/secureagent/run.go @@ -13,6 +13,10 @@ import "log" // RunCommand runs the command in the background func (a *Agent) RunCommand() error { log.Println("runCommand started") + if err := a.prepareStatus(); err != nil { + log.Println("failed to prepare status: ", err) + return err + } err := a.performBootstrapSequence() if err != nil { log.Println("Error in performBootstrapSequence inside runCommand: ", err) diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index e5341fd..4922a45 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -4,23 +4,274 @@ Copyright (C) 2022-2023 Intel Corporation Copyright (c) 2022 Dell Inc, or its subsidiaries. Copyright (C) 2022 Red Hat. */ - +// nolint // Package secureagent implements the secure agent package secureagent -import "log" +import ( + "fmt" + "log" + "path/filepath" + "time" +) + +type StageType int64 + +const ( + StageTypeInit StageType = iota + StageTypeDownloadingFile + StageTypePendingReboot + StageTypeParsing + StageTypeOnboarding + StageTypeRedirect + StageTypeBootImage + StageTypePreScript + StageTypeConfig + StageTypePostScript + StageTypeBootstrap + StageTypeIsCompleted +) + +func (s StageType) String() string { + switch s { + case StageTypeInit: + return "init" + case StageTypeDownloadingFile: + return "downloading-file" + case StageTypePendingReboot: + return "pending-reboot" + case StageTypeParsing: + return "parsing" + case StageTypeOnboarding: + return "onboarding" + case StageTypeRedirect: + return "redirect" + case StageTypeBootImage: + return "boot-image" + case StageTypePreScript: + return "pre-script" + case StageTypeConfig: + return "config" + case StageTypePostScript: + return "post-script" + case StageTypeBootstrap: + return "bootstrap" + case StageTypeIsCompleted: + return "is-completed" + default: + return "unknown" + } +} + +// Status represents the status of the provisioning process. +type Status struct { + Init StageStatus `json:"init"` + DownloadingFile StageStatus `json:"downloading-file"` + PendingReboot StageStatus `json:"pending-reboot"` + Parsing StageStatus `json:"parsing"` + Onboarding StageStatus `json:"onboarding"` + Redirect StageStatus `json:"redirect"` + BootImage StageStatus `json:"boot-image"` + PreScript StageStatus `json:"pre-script"` + Config StageStatus `json:"config"` + PostScript StageStatus `json:"post-script"` + Bootstrap StageStatus `json:"bootstrap"` + IsCompleted StageStatus `json:"is-completed"` + Informational string `json:"informational"` + Stage string `json:"stage"` +} + +// Result represents the result of the provisioning process. +type Result struct { + Errors []string `json:"errors"` +} + +// StageStatus represents the status of a specific stage. +type StageStatus struct { + Errors []string `json:"errors"` + Start float64 `json:"start"` + End float64 `json:"end"` +} + +func (a *Agent) getCurrStatus() (*Status, error) { + var status Status + err := loadFile(a.GetStatusFilePath(), &status) + if err != nil { + return nil, err + } + return &status, nil +} + +func (a *Agent) getCurrResult() (*Result, error) { + var result Result + err := loadFile(a.GetResultFilePath(), &result) + if err != nil { + return nil, err + } + return &result, nil +} + +func (a *Agent) createNewStatus() *Status { + return &Status{ + Stage: "", + IsCompleted: StageStatus{}, + } +} + +// updateAndSaveStatus updates the status object for a specific stage and saves it to the status.json file. +func (a *Agent) updateAndSaveStatus(s StageType, isStart bool, errMsg string) error { + status, err := a.getCurrStatus() + if err != nil { + fmt.Println("Creating a new status file.") + status = a.createNewStatus() + } + + err = a.updateStageStatus(status, s, isStart, errMsg) + if err != nil { + return err + } + + return a.saveStatus(status) +} + +// updateStageStatus updates the status object for a specific stage. +func (a *Agent) updateStageStatus(status *Status, stageType StageType, isStart bool, errMsg string) error { + now := float64(time.Now().Unix()) + stage := stageType.String() + + switch stageType { + case StageTypeInit: + a.updateStage(&status.Init, isStart, now, errMsg) + case StageTypeDownloadingFile: + a.updateStage(&status.DownloadingFile, isStart, now, errMsg) + case StageTypePendingReboot: + a.updateStage(&status.PendingReboot, isStart, now, errMsg) + case StageTypeIsCompleted: + a.updateStage(&status.IsCompleted, isStart, now, errMsg) + case StageTypeParsing: + a.updateStage(&status.Parsing, isStart, now, errMsg) + case StageTypeOnboarding: + a.updateStage(&status.Onboarding, isStart, now, errMsg) + case StageTypeRedirect: + a.updateStage(&status.Redirect, isStart, now, errMsg) + case StageTypeBootImage: + a.updateStage(&status.BootImage, isStart, now, errMsg) + case StageTypePreScript: + a.updateStage(&status.PreScript, isStart, now, errMsg) + case StageTypeConfig: + a.updateStage(&status.Config, isStart, now, errMsg) + case StageTypePostScript: + a.updateStage(&status.PostScript, isStart, now, errMsg) + case StageTypeBootstrap: + a.updateStage(&status.Bootstrap, isStart, now, errMsg) + + default: + return fmt.Errorf("unknown stage: %s", stage) + } + + if isStart { + status.Stage = stage + "-in-progress" + } else { + status.Stage = stage + "-completed" + } + + return nil +} + +func (a *Agent) updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { + if isStart { + stageStatus.Start = now + stageStatus.End = 0 + } else { + stageStatus.End = now + if errMsg != "" { + stageStatus.Errors = append(stageStatus.Errors, errMsg) + err := a.updateAndSaveResult(errMsg) + if err != nil { + fmt.Printf("Failed to update and save result: %v\n", err) + } + } + } +} + +func (a *Agent) saveStatus(status *Status) error { + return saveToFile(status, a.GetStatusFilePath()) +} + +func (a *Agent) saveResult(result *Result) error { + return saveToFile(result, a.GetResultFilePath()) +} + +func (a *Agent) updateAndSaveResult(errMsg string) error { + result, err := a.getCurrResult() + if err != nil { + fmt.Println("Creating a new result file.") + result = &Result{ + Errors: []string{}, + } + } + + if errMsg != "" { + result.Errors = append(result.Errors, errMsg) + } + + return a.saveResult(result) +} // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") + status, err := a.getCurrStatus() + if err != nil { + log.Println("failed to load status file: ", err) + return err + } + fmt.Printf("Current status: %+v\n", status) return nil } -/* -func (a *Agent) prepareEnvStatus() error { - log.Println("prepareEnvStatus") +func (a *Agent) prepareStatus() error { + log.Println("prepareStatus") + + // Ensure /run/sztp directory exists + if err := ensureDirExists(a.GetSymLinkDir()); err != nil { + fmt.Printf("Failed to create directory %s: %v\n", a.GetSymLinkDir(), err) + return err + } + + fmt.Println("Status File Path", a.GetStatusFilePath()) + fmt.Println("Result File Path", a.GetResultFilePath()) + + if err := ensureFileExists(a.GetStatusFilePath()); err != nil { + return err + } + if err := ensureFileExists(a.GetResultFilePath()); err != nil { + return err + } + + statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") + resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") + + // Create symlinks for status.json and result.json + if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for status.json: %v\n", err) + return err + } + if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for result.json: %v\n", err) + return err + } + + fmt.Println("Symlinks created successfully.") + + if err := a.updateAndSaveStatus(StageTypeInit, true, ""); err != nil { + return err + } + return nil } + +/* func (a *Agent) configureStatus() error { log.Println("configureStatus") return nil diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index 2aa1133..75c9470 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -4,9 +4,36 @@ // Package secureagent implements the secure agent package secureagent -import "testing" +import ( + "testing" +) + +const StatusTestContent = `{ + "init": {"errors": [], "start": 1729891263, "end": 0}, + "downloading-file": {"errors": [], "start": 0, "end": 0}, + "pending-reboot": {"errors": [], "start": 0, "end": 0}, + "parsing": {"errors": [], "start": 0, "end": 0}, + "onboarding": {"errors": [], "start": 0, "end": 0}, + "redirect": {"errors": [], "start": 0, "end": 0}, + "boot-image": {"errors": [], "start": 1729891263, "end": 1729891263}, + "pre-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "config": {"errors": [], "start": 1729891264, "end": 1729891264}, + "post-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "bootstrap": {"errors": [], "start": 1729891263, "end": 1729891264}, + "is-completed": {"errors": [], "start": 1729891263, "end": 1729891264}, + "informational": "", + "stage": "is-completed-completed" +}` + +const ResultTestContent = `{ + "errors": ["error1", "error2"], +}` func TestAgent_RunCommandStatus(t *testing.T) { + testStatusFile := "/tmp/sztp/status.json" + testResultFile := "/tmp/sztp/result.json" + testSymLinkDir := "/tmp/symlink" + type fields struct { BootstrapURL string SerialNumber string @@ -20,6 +47,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON ProgressJSON BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo BootstrapServerRedirectInfo BootstrapServerRedirectInfo + StatusFilePath string + ResultFilePath string + SymLinkDir string } tests := []struct { name string @@ -41,6 +71,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, + StatusFilePath: testStatusFile, + ResultFilePath: testResultFile, + SymLinkDir: testSymLinkDir, }, }, } @@ -59,6 +92,12 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + StatusFilePath: tt.fields.StatusFilePath, + ResultFilePath: tt.fields.ResultFilePath, + SymLinkDir: tt.fields.SymLinkDir, + } + if err := a.prepareStatus(); err != nil { + t.Errorf("prepareStatus() error = %v", err) } if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index e9cc7a6..21aa3c7 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "crypto/rand" "crypto/sha256" "encoding/json" "fmt" @@ -88,3 +89,103 @@ func calculateSHA256File(filePath string) (string, error) { checkSum := fmt.Sprintf("%x", h.Sum(nil)) return checkSum, nil } + +func saveToFile(data interface{}, filePath string) error { + filePath = filepath.Clean(filePath) + random, _ := rand.Prime(rand.Reader, 64) + tempPath := fmt.Sprintf("%s.%d.tmp", filePath, random) // rand number to avoid conflicts when multiple agents are running + tempPath = filepath.Clean(tempPath) + file, err := os.Create(tempPath) + if err != nil { + return err + } + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + + encoder := json.NewEncoder(file) + if err := encoder.Encode(data); err != nil { + return err + } + + // Atomic move of temp file to replace the original. + if err := os.Rename(tempPath, filePath); err != nil { + return fmt.Errorf("failed to rename %s to %s: %v", tempPath, filePath, err) + } + + return nil +} + +func ensureDirExists(dir string) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0750) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil +} + +func ensureFileExists(filePath string) error { + dir := filepath.Dir(filePath) + if err := ensureDirExists(dir); err != nil { + return err + } + + fmt.Printf("Checking if file %s exists...\n", filePath) + + if _, err := os.Stat(filePath); os.IsNotExist(err) { + filePath = filepath.Clean(filePath) + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("[ERROR] failed to create file %s: %v", filePath, err) + } + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil +} + +func createSymlink(targetFile, linkFile string) error { + targetFile = filepath.Clean(targetFile) + linkFile = filepath.Clean(linkFile) + + linkDir := filepath.Dir(linkFile) + if err := ensureDirExists(linkDir); err != nil { + return err + } + + // Check if linkFile exists and is a symlink to targetFile + if existingTarget, err := os.Readlink(linkFile); err == nil { + if existingTarget == targetFile { + return nil // Symlink already points to the target -> skip creation + } + // Remove the existing file (even if it's a wrong symlink or regular file) + if err := os.Remove(linkFile); err != nil { + return err + } + } + + return os.Symlink(targetFile, linkFile) +} + +func loadFile(filePath string, v interface{}) error { + filePath = filepath.Clean(filePath) + file, err := os.ReadFile(filePath) + if err != nil { + return err + } + err = json.Unmarshal(file, v) + if err != nil { + return err + } + return nil +} diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index bf32fec..a3217bf 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -5,7 +5,9 @@ package secureagent import ( + "encoding/json" "os" + "path/filepath" "testing" ) @@ -76,3 +78,171 @@ func Test_calculateSHA256File(t *testing.T) { t.Errorf("Checksum did not match %s %s", checksum, expected) } } + +func Test_saveToFile(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_save_to_file") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() + + filePath := filepath.Join(tempDir, "test.json") + data := map[string]string{"key": "value"} + + err = saveToFile(data, filePath) + if err != nil { + t.Fatalf("saveToFile returned an error: %v", err) + } + + _, err = os.Stat(filePath) + if os.IsNotExist(err) { + t.Fatalf("file %s was not created", filePath) + } + + filePath = filepath.Clean(filePath) + file, err := os.Open(filePath) + if err != nil { + t.Fatalf("failed to open the file: %v", err) + } + defer func() { + if err := file.Close(); err != nil { + t.Fatalf("failed to close the file: %v", err) + } + }() + + var readData map[string]string + decoder := json.NewDecoder(file) + err = decoder.Decode(&readData) + if err != nil { + t.Fatalf("failed to decode JSON data: %v", err) + } + + if readData["key"] != "value" { + t.Errorf("expected 'key' to be 'value', got %s", readData["key"]) + } +} + +func Test_ensureDirExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_dir_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() + + newDir := filepath.Join(tempDir, "newdir") + + if _, err := os.Stat(newDir); !os.IsNotExist(err) { + t.Fatalf("expected directory %s to not exist", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error: %v", err) + } + + if _, err := os.Stat(newDir); os.IsNotExist(err) { + t.Fatalf("expected directory %s to be created", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error when directory already exists: %v", err) + } +} + +func Test_ensureFileExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_file_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() + + newFilePath := filepath.Join(tempDir, "newdir", "testfile.txt") + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error: %v", err) + } + + if _, err := os.Stat(newFilePath); os.IsNotExist(err) { + t.Fatalf("expected file %s to be created", newFilePath) + } + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error when file already exists: %v", err) + } +} + +func Test_createSymlink(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_create_symlink") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() + + targetFile := filepath.Join(tempDir, "target.txt") + linkFile := filepath.Join(tempDir, "link.txt") + + err = os.WriteFile(targetFile, []byte("test data"), 0600) + if err != nil { + t.Fatalf("failed to create target file: %v", err) + } + + err = createSymlink(targetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error: %v", err) + } + + linkInfo, err := os.Lstat(linkFile) + if err != nil { + t.Fatalf("failed to stat symlink: %v", err) + } + if linkInfo.Mode()&os.ModeSymlink == 0 { + t.Errorf("expected %s to be a symlink", linkFile) + } + + target, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read symlink: %v", err) + } + if target != targetFile { + t.Errorf("expected symlink to point to %s, got %s", targetFile, target) + } + + newTargetFile := filepath.Join(tempDir, "new_target.txt") + err = os.WriteFile(newTargetFile, []byte("new data"), 0600) + if err != nil { + t.Fatalf("failed to create new target file: %v", err) + } + + err = createSymlink(newTargetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error when replacing symlink: %v", err) + } + + newTarget, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read new symlink: %v", err) + } + + if newTarget != newTargetFile { + t.Errorf("expected symlink to point to %s, got %s", newTargetFile, newTarget) + } +}