Skip to content

Commit

Permalink
refactor: move tls to a separate file
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Glimcher <Boris.Glimcher@emc.com>
  • Loading branch information
glimchb committed Jun 25, 2024
1 parent 856a04e commit 1bb6e2c
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 122 deletions.
94 changes: 94 additions & 0 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
SPDX-License-Identifier: Apache-2.0
Copyright (C) 2022-2023 Intel Corporation
Copyright (c) 2022 Dell Inc, or its subsidiaries.
Copyright (C) 2022 Red Hat.
*/

// Package secureagent implements the secure agent
package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
)

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput

log.Println("[DEBUG] Sending to: " + url)
log.Println("[DEBUG] Sending input: " + input)

body := strings.NewReader(input)
r, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}

r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
}
defer func() {
if err := res.Body.Close(); err != nil {
log.Println("Error when closing:", err)
}
}()

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
log.Println("Error reading the request", err.Error())
return nil, err
}

decoder := json.NewDecoder(bytes.NewReader(bodyBytes))
decoder.DisallowUnknownFields()
if !empty {
derr := decoder.Decode(&postResponse)
if derr != nil {
errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes))
errdecoder.DisallowUnknownFields()
eerr := errdecoder.Decode(&errorResponse)
if eerr != nil {
log.Println("Received unknown response", string(bodyBytes))
return nil, derr
}
return nil, errors.New("[ERROR] Expected conveyed-information" +
", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType +
", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag +
", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage)
}
log.Println(postResponse)
}
if res.StatusCode != http.StatusOK {
return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK))
}
return &postResponse, nil
}
53 changes: 53 additions & 0 deletions sztp-agent/pkg/secureagent/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2022-2023 Red Hat.

// Package secureagent implements the secure agent
package secureagent

import (
"reflect"
"testing"
)

func TestAgent_doTLSRequest(t *testing.T) {
type fields struct {
BootstrapURL string
SerialNumber string
DevicePassword string
DevicePrivateKey string
DeviceEndEntityCert string
BootstrapTrustAnchorCert string
ContentTypeReq string
InputJSONContent string
DhcpLeaseFile string
}
var tests []struct {
name string
fields fields
want *BootstrapServerPostOutput
wantErr bool
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
SerialNumber: tt.fields.SerialNumber,
DevicePassword: tt.fields.DevicePassword,
DevicePrivateKey: tt.fields.DevicePrivateKey,
DeviceEndEntityCert: tt.fields.DeviceEndEntityCert,
BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert,
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
}
got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false)
if (err != nil) != tt.wantErr {
t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want)
}
})
}
}
78 changes: 0 additions & 78 deletions sztp-agent/pkg/secureagent/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,10 @@ package secureagent

import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"

"github.com/jaypipes/ghw"
Expand Down Expand Up @@ -49,77 +42,6 @@ func extractfromLine(line, regex string, index int) string {
return re.FindAllString(line, -1)[index]
}

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput

log.Println("[DEBUG] Sending to: " + url)
log.Println("[DEBUG] Sending input: " + input)

body := strings.NewReader(input)
r, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}

r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
}
defer func() {
if err := res.Body.Close(); err != nil {
log.Println("Error when closing:", err)
}
}()

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
log.Println("Error reading the request", err.Error())
return nil, err
}

decoder := json.NewDecoder(bytes.NewReader(bodyBytes))
decoder.DisallowUnknownFields()
if !empty {
derr := decoder.Decode(&postResponse)
if derr != nil {
errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes))
errdecoder.DisallowUnknownFields()
eerr := errdecoder.Decode(&errorResponse)
if eerr != nil {
log.Println("Received unknown response", string(bodyBytes))
return nil, derr
}
return nil, errors.New("[ERROR] Expected conveyed-information" +
", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType +
", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag +
", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage)
}
log.Println(postResponse)
}
if res.StatusCode != http.StatusOK {
return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK))
}
return &postResponse, nil
}

// GetSerialNumber returns the serial number of the device
func GetSerialNumber(givenSerialNumber string) string {
if givenSerialNumber != "" {
Expand Down
44 changes: 0 additions & 44 deletions sztp-agent/pkg/secureagent/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,10 @@
package secureagent

import (
"reflect"
"strings"
"testing"
)

func TestAgent_doTLSRequest(t *testing.T) {
type fields struct {
BootstrapURL string
SerialNumber string
DevicePassword string
DevicePrivateKey string
DeviceEndEntityCert string
BootstrapTrustAnchorCert string
ContentTypeReq string
InputJSONContent string
DhcpLeaseFile string
}
var tests []struct {
name string
fields fields
want *BootstrapServerPostOutput
wantErr bool
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
SerialNumber: tt.fields.SerialNumber,
DevicePassword: tt.fields.DevicePassword,
DevicePrivateKey: tt.fields.DevicePrivateKey,
DeviceEndEntityCert: tt.fields.DeviceEndEntityCert,
BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert,
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
}
got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false)
if (err != nil) != tt.wantErr {
t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want)
}
})
}
}

func Test_extractfromLine(t *testing.T) {
type args struct {
line string
Expand Down

0 comments on commit 1bb6e2c

Please sign in to comment.