From 1bb6e2c5a8c61e780cfea2d936caa15115ab64bb Mon Sep 17 00:00:00 2001 From: Boris Glimcher Date: Tue, 25 Jun 2024 19:03:10 +0300 Subject: [PATCH] refactor: move tls to a separate file Signed-off-by: Boris Glimcher --- sztp-agent/pkg/secureagent/tls.go | 94 ++++++++++++++++++++++++ sztp-agent/pkg/secureagent/tls_test.go | 53 +++++++++++++ sztp-agent/pkg/secureagent/utils.go | 78 -------------------- sztp-agent/pkg/secureagent/utils_test.go | 44 ----------- 4 files changed, 147 insertions(+), 122 deletions(-) create mode 100644 sztp-agent/pkg/secureagent/tls.go create mode 100644 sztp-agent/pkg/secureagent/tls_test.go diff --git a/sztp-agent/pkg/secureagent/tls.go b/sztp-agent/pkg/secureagent/tls.go new file mode 100644 index 00000000..44bdaf68 --- /dev/null +++ b/sztp-agent/pkg/secureagent/tls.go @@ -0,0 +1,94 @@ +/* +SPDX-License-Identifier: Apache-2.0 +Copyright (C) 2022-2023 Intel Corporation +Copyright (c) 2022 Dell Inc, or its subsidiaries. +Copyright (C) 2022 Red Hat. +*/ + +// Package secureagent implements the secure agent +package secureagent + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "io" + "log" + "net/http" + "os" + "strconv" + "strings" +) + +func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) { + var postResponse BootstrapServerPostOutput + var errorResponse BootstrapServerErrorOutput + + log.Println("[DEBUG] Sending to: " + url) + log.Println("[DEBUG] Sending input: " + input) + + body := strings.NewReader(input) + r, err := http.NewRequest(http.MethodPost, url, body) + if err != nil { + return nil, err + } + + r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword()) + r.Header.Add("Content-Type", a.GetContentTypeReq()) + + caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ //nolint:gosec + RootCAs: caCertPool, + Certificates: []tls.Certificate{cert}, + }, + }, + } + res, err := client.Do(r) + if err != nil { + log.Println("Error doing the request", err.Error()) + return nil, err + } + defer func() { + if err := res.Body.Close(); err != nil { + log.Println("Error when closing:", err) + } + }() + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + log.Println("Error reading the request", err.Error()) + return nil, err + } + + decoder := json.NewDecoder(bytes.NewReader(bodyBytes)) + decoder.DisallowUnknownFields() + if !empty { + derr := decoder.Decode(&postResponse) + if derr != nil { + errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes)) + errdecoder.DisallowUnknownFields() + eerr := errdecoder.Decode(&errorResponse) + if eerr != nil { + log.Println("Received unknown response", string(bodyBytes)) + return nil, derr + } + return nil, errors.New("[ERROR] Expected conveyed-information" + + ", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType + + ", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag + + ", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage) + } + log.Println(postResponse) + } + if res.StatusCode != http.StatusOK { + return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK)) + } + return &postResponse, nil +} diff --git a/sztp-agent/pkg/secureagent/tls_test.go b/sztp-agent/pkg/secureagent/tls_test.go new file mode 100644 index 00000000..b0c4c9ab --- /dev/null +++ b/sztp-agent/pkg/secureagent/tls_test.go @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2022-2023 Red Hat. + +// Package secureagent implements the secure agent +package secureagent + +import ( + "reflect" + "testing" +) + +func TestAgent_doTLSRequest(t *testing.T) { + type fields struct { + BootstrapURL string + SerialNumber string + DevicePassword string + DevicePrivateKey string + DeviceEndEntityCert string + BootstrapTrustAnchorCert string + ContentTypeReq string + InputJSONContent string + DhcpLeaseFile string + } + var tests []struct { + name string + fields fields + want *BootstrapServerPostOutput + wantErr bool + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Agent{ + BootstrapURL: tt.fields.BootstrapURL, + SerialNumber: tt.fields.SerialNumber, + DevicePassword: tt.fields.DevicePassword, + DevicePrivateKey: tt.fields.DevicePrivateKey, + DeviceEndEntityCert: tt.fields.DeviceEndEntityCert, + BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert, + ContentTypeReq: tt.fields.ContentTypeReq, + InputJSONContent: tt.fields.InputJSONContent, + DhcpLeaseFile: tt.fields.DhcpLeaseFile, + } + got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false) + if (err != nil) != tt.wantErr { + t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 2ef18347..07a2a1ca 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -10,17 +10,10 @@ package secureagent import ( "bufio" - "bytes" - "crypto/tls" - "crypto/x509" "encoding/json" - "errors" - "io" "log" - "net/http" "os" "regexp" - "strconv" "strings" "github.com/jaypipes/ghw" @@ -49,77 +42,6 @@ func extractfromLine(line, regex string, index int) string { return re.FindAllString(line, -1)[index] } -func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) { - var postResponse BootstrapServerPostOutput - var errorResponse BootstrapServerErrorOutput - - log.Println("[DEBUG] Sending to: " + url) - log.Println("[DEBUG] Sending input: " + input) - - body := strings.NewReader(input) - r, err := http.NewRequest(http.MethodPost, url, body) - if err != nil { - return nil, err - } - - r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword()) - r.Header.Add("Content-Type", a.GetContentTypeReq()) - - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - res, err := client.Do(r) - if err != nil { - log.Println("Error doing the request", err.Error()) - return nil, err - } - defer func() { - if err := res.Body.Close(); err != nil { - log.Println("Error when closing:", err) - } - }() - - bodyBytes, err := io.ReadAll(res.Body) - if err != nil { - log.Println("Error reading the request", err.Error()) - return nil, err - } - - decoder := json.NewDecoder(bytes.NewReader(bodyBytes)) - decoder.DisallowUnknownFields() - if !empty { - derr := decoder.Decode(&postResponse) - if derr != nil { - errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes)) - errdecoder.DisallowUnknownFields() - eerr := errdecoder.Decode(&errorResponse) - if eerr != nil { - log.Println("Received unknown response", string(bodyBytes)) - return nil, derr - } - return nil, errors.New("[ERROR] Expected conveyed-information" + - ", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType + - ", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag + - ", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage) - } - log.Println(postResponse) - } - if res.StatusCode != http.StatusOK { - return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK)) - } - return &postResponse, nil -} - // GetSerialNumber returns the serial number of the device func GetSerialNumber(givenSerialNumber string) string { if givenSerialNumber != "" { diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index 8be75a78..c484b6aa 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -5,54 +5,10 @@ package secureagent import ( - "reflect" "strings" "testing" ) -func TestAgent_doTLSRequest(t *testing.T) { - type fields struct { - BootstrapURL string - SerialNumber string - DevicePassword string - DevicePrivateKey string - DeviceEndEntityCert string - BootstrapTrustAnchorCert string - ContentTypeReq string - InputJSONContent string - DhcpLeaseFile string - } - var tests []struct { - name string - fields fields - want *BootstrapServerPostOutput - wantErr bool - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &Agent{ - BootstrapURL: tt.fields.BootstrapURL, - SerialNumber: tt.fields.SerialNumber, - DevicePassword: tt.fields.DevicePassword, - DevicePrivateKey: tt.fields.DevicePrivateKey, - DeviceEndEntityCert: tt.fields.DeviceEndEntityCert, - BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert, - ContentTypeReq: tt.fields.ContentTypeReq, - InputJSONContent: tt.fields.InputJSONContent, - DhcpLeaseFile: tt.fields.DhcpLeaseFile, - } - got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false) - if (err != nil) != tt.wantErr { - t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want) - } - }) - } -} - func Test_extractfromLine(t *testing.T) { type args struct { line string