diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 00a569a..95b35b6 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,7 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") - _ = a.UpdateAndSaveStatus("config", true, "") + _ = a.updateAndSaveStatus("config", true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -37,7 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") - _ = a.UpdateAndSaveStatus("config", false, "") + _ = a.updateAndSaveStatus("config", false, "") return nil } @@ -58,7 +58,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") - _ = a.UpdateAndSaveStatus(scriptName+"-script", true, "") + _ = a.updateAndSaveStatus(scriptName+"-script", true, "") // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -92,7 +92,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") - _ = a.UpdateAndSaveStatus(scriptName+"-script", false, "") + _ = a.updateAndSaveStatus(scriptName+"-script", false, "") log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index a829a32..066e571 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -33,7 +33,7 @@ const ( // RunCommandDaemon runs the command in the background func (a *Agent) RunCommandDaemon() error { - if err := a.PrepareStatus(); err != nil { + if err := a.prepareStatus(); err != nil { log.Println("failed to prepare status: ", err) return err } @@ -50,7 +50,7 @@ func (a *Agent) RunCommandDaemon() error { } func (a *Agent) performBootstrapSequence() error { - _ = a.UpdateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus("bootstrap", true, "") var err error err = a.discoverBootstrapURLs() if err != nil { @@ -81,7 +81,7 @@ func (a *Agent) performBootstrapSequence() error { return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") - _ = a.UpdateAndSaveStatus("bootstrap", false, "") + _ = a.updateAndSaveStatus("bootstrap", false, "") return nil } @@ -148,7 +148,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") - _ = a.UpdateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus("bootstrap", true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index a918bd5..4466698 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,7 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") - _ = a.UpdateAndSaveStatus("boot-image", true, "") + _ = a.updateAndSaveStatus("boot-image", true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -79,7 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") - _ = a.UpdateAndSaveStatus("boot-image", false, "") + _ = a.updateAndSaveStatus("boot-image", false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/run.go b/sztp-agent/pkg/secureagent/run.go index 978fc03..09dbf8c 100644 --- a/sztp-agent/pkg/secureagent/run.go +++ b/sztp-agent/pkg/secureagent/run.go @@ -13,7 +13,7 @@ import "log" // RunCommand runs the command in the background func (a *Agent) RunCommand() error { log.Println("runCommand started") - if err := a.PrepareStatus(); err != nil { + if err := a.prepareStatus(); err != nil { log.Println("failed to prepare status: ", err) return err } diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 9dcc15d..656471c 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -59,7 +59,7 @@ func (a *Agent) loadStatusFile() (*Status, error) { return &status, nil } -func (a *Agent) UpdateAndSaveStatus(stage string, isStart bool, errMsg string) error { +func (a *Agent) updateAndSaveStatus(stage string, isStart bool, errMsg string) error { status, err := a.loadStatusFile() if err != nil { fmt.Println("Creating a new status file.") @@ -143,57 +143,6 @@ func (a *Agent) saveResult(result *Result) error { return saveToFile(result, a.GetResultFilePath()) } -// EnsureDirExists checks if a directory exists, and creates it if it doesn't. -func EnsureDirExists(dir string) error { - if _, err := os.Stat(dir); os.IsNotExist(err) { - err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions - if err != nil { - return fmt.Errorf("failed to create directory %s: %v", dir, err) - } - } - return nil -} - -// EnsureFile ensures that a file exists; creates it if it does not. -func EnsureFileExists(filePath string) error { - // Ensure the directory exists - dir := filepath.Dir(filePath) - if err := EnsureDirExists(dir); err != nil { - return err - } - - // Check if the file already exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - // File does not exist, create it - file, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) - } - defer file.Close() - fmt.Printf("File %s created successfully.\n", filePath) - } else { - fmt.Printf("File %s already exists.\n", filePath) - } - return nil -} - -// CreateSymlink creates a symlink for a file from target to link location. -func CreateSymlink(targetFile, linkFile string) error { - // Ensure the directory for the symlink exists - linkDir := filepath.Dir(linkFile) - if err := EnsureDirExists(linkDir); err != nil { - return err - } - - // Remove any existing symlink - if _, err := os.Lstat(linkFile); err == nil { - os.Remove(linkFile) - } - - // Create a new symlink - return os.Symlink(targetFile, linkFile) -} - // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") @@ -207,19 +156,19 @@ func (a *Agent) RunCommandStatus() error { return nil } -func (a *Agent) PrepareStatus() error { +func (a *Agent) prepareStatus() error { log.Println("prepareStatus") // Ensure /run/sztp directory exists - if err := EnsureDirExists(a.GetSymLinkDir()); err != nil { + if err := ensureDirExists(a.GetSymLinkDir()); err != nil { fmt.Printf("Failed to create directory %s: %v\n", a.GetSymLinkDir(), err) return err } - if err := EnsureFileExists(a.GetStatusFilePath()); err != nil { + if err := ensureFileExists(a.GetStatusFilePath()); err != nil { return err } - if err := EnsureFileExists(a.GetResultFilePath()); err != nil { + if err := ensureFileExists(a.GetResultFilePath()); err != nil { return err } @@ -227,18 +176,18 @@ func (a *Agent) PrepareStatus() error { resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") // Create symlinks for status.json and result.json - if err := CreateSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { fmt.Printf("Failed to create symlink for status.json: %v\n", err) return err } - if err := CreateSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { fmt.Printf("Failed to create symlink for result.json: %v\n", err) return err } fmt.Println("Symlinks created successfully.") - if err := a.UpdateAndSaveStatus("init", true, ""); err != nil { + if err := a.updateAndSaveStatus("init", true, ""); err != nil { return err } diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index ece4925..d16d14c 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -69,7 +69,7 @@ func TestAgent_RunCommandStatus(t *testing.T) { ResultFilePath: tt.fields.ResultFilePath, SymLinkDir: tt.fields.SymLinkDir, } - a.PrepareStatus() + a.prepareStatus() if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 8ed6506..0eddb2f 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -106,3 +106,54 @@ func saveToFile(data interface{}, filePath string) error { // Atomic move of temp file to replace the original. return os.Rename(tempPath, filePath) } + +// EnsureDirExists checks if a directory exists, and creates it if it doesn't. +func ensureDirExists(dir string) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil +} + +// EnsureFile ensures that a file exists; creates it if it does not. +func ensureFileExists(filePath string) error { + // Ensure the directory exists + dir := filepath.Dir(filePath) + if err := ensureDirExists(dir); err != nil { + return err + } + + // Check if the file already exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + // File does not exist, create it + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %s: %v", filePath, err) + } + defer file.Close() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil +} + +// CreateSymlink creates a symlink for a file from target to link location. +func createSymlink(targetFile, linkFile string) error { + // Ensure the directory for the symlink exists + linkDir := filepath.Dir(linkFile) + if err := ensureDirExists(linkDir); err != nil { + return err + } + + // Remove any existing symlink + if _, err := os.Lstat(linkFile); err == nil { + os.Remove(linkFile) + } + + // Create a new symlink + return os.Symlink(targetFile, linkFile) +} diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index bf32fec..42e134f 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -5,7 +5,9 @@ package secureagent import ( + "encoding/json" "os" + "path/filepath" "testing" ) @@ -76,3 +78,151 @@ func Test_calculateSHA256File(t *testing.T) { t.Errorf("Checksum did not match %s %s", checksum, expected) } } + +func Test_saveToFile(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_save_to_file") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + filePath := filepath.Join(tempDir, "test.json") + data := map[string]string{"key": "value"} + + err = saveToFile(data, filePath) + if err != nil { + t.Fatalf("saveToFile returned an error: %v", err) + } + + _, err = os.Stat(filePath) + if os.IsNotExist(err) { + t.Fatalf("file %s was not created", filePath) + } + + file, err := os.Open(filePath) + if err != nil { + t.Fatalf("failed to open the file: %v", err) + } + defer file.Close() + + var readData map[string]string + decoder := json.NewDecoder(file) + err = decoder.Decode(&readData) + if err != nil { + t.Fatalf("failed to decode JSON data: %v", err) + } + + if readData["key"] != "value" { + t.Errorf("expected 'key' to be 'value', got %s", readData["key"]) + } +} + +func TestEnsureDirExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_dir_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + newDir := filepath.Join(tempDir, "newdir") + + if _, err := os.Stat(newDir); !os.IsNotExist(err) { + t.Fatalf("expected directory %s to not exist", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error: %v", err) + } + + if _, err := os.Stat(newDir); os.IsNotExist(err) { + t.Fatalf("expected directory %s to be created", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error when directory already exists: %v", err) + } +} + +func TestEnsureFileExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_file_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + newFilePath := filepath.Join(tempDir, "newdir", "testfile.txt") + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error: %v", err) + } + + if _, err := os.Stat(newFilePath); os.IsNotExist(err) { + t.Fatalf("expected file %s to be created", newFilePath) + } + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error when file already exists: %v", err) + } +} + +func TestCreateSymlink(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_create_symlink") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + targetFile := filepath.Join(tempDir, "target.txt") + linkFile := filepath.Join(tempDir, "link.txt") + + err = os.WriteFile(targetFile, []byte("test data"), 0644) + if err != nil { + t.Fatalf("failed to create target file: %v", err) + } + + err = createSymlink(targetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error: %v", err) + } + + linkInfo, err := os.Lstat(linkFile) + t.Logf("linkInfo: %v", linkInfo) /// + if err != nil { + t.Fatalf("failed to stat symlink: %v", err) + } + if linkInfo.Mode()&os.ModeSymlink == 0 { + t.Errorf("expected %s to be a symlink", linkFile) + } + + target, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read symlink: %v", err) + } + if target != targetFile { + t.Errorf("expected symlink to point to %s, got %s", targetFile, target) + } + + newTargetFile := filepath.Join(tempDir, "new_target.txt") + err = os.WriteFile(newTargetFile, []byte("new data"), 0644) + if err != nil { + t.Fatalf("failed to create new target file: %v", err) + } + + err = createSymlink(newTargetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error when replacing symlink: %v", err) + } + + newTarget, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read new symlink: %v", err) + } + + if newTarget != newTargetFile { + t.Errorf("expected symlink to point to %s, got %s", newTargetFile, newTarget) + } +}