Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify http client creation #448

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sztp-agent/cmd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Daemon() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDaemon()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/disable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Disable() *cobra.Command {
Use: "disable",
Short: "Run the disable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDisable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Enable() *cobra.Command {
Use: "enable",
Short: "Run the enable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandEnable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Run() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommand()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Status() *cobra.Command {
Use: "status",
Short: "Run the status command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandStatus()
},
}
Expand Down
14 changes: 12 additions & 2 deletions sztp-agent/pkg/secureagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Copyright (C) 2022 Red Hat.
// Package secureagent implements the secure agent
package secureagent

import (
"net/http"
)

const (
CONTENT_TYPE_YANG = "application/yang-data+json"
OS_RELEASE_FILE = "/etc/os-release"
Expand Down Expand Up @@ -68,6 +72,11 @@ type BootstrapServerErrorOutput struct {
} `json:"ietf-restconf:errors"`
}

type HttpClient interface {
glimchb marked this conversation as resolved.
Show resolved Hide resolved
Get(uri string) (*http.Response, error)
Do(req *http.Request) (*http.Response, error)
}

// Agent is the basic structure to define an agent instance
type Agent struct {
InputBootstrapURL string // Bootstrap complete URL given by USER
Expand All @@ -83,10 +92,10 @@ type Agent struct {
ProgressJSON ProgressJSON // ProgressJson structure
BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure
BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure

HttpClient HttpClient
}

func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent {
func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent {
return &Agent{
InputBootstrapURL: bootstrapURL,
BootstrapURL: "",
Expand All @@ -101,6 +110,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP
ProgressJSON: ProgressJSON{},
BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{},
BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{},
HttpClient: httpClient,
}
}

Expand Down
5 changes: 4 additions & 1 deletion sztp-agent/pkg/secureagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"net/http"
"reflect"
"testing"
)
Expand Down Expand Up @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) {
deviceEndEntityCert string
bootstrapTrustAnchorCert string
}
client := http.Client{}
tests := []struct {
name string
args args
Expand Down Expand Up @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) {
ContentTypeReq: "application/yang-data+json",
InputJSONContent: generateInputJSONContent(),
DhcpLeaseFile: "TestDhcpLeaseFile",
HttpClient: &client,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewAgent() = %v, want %v", got, tt.want)
}
})
Expand Down
3 changes: 3 additions & 0 deletions sztp-agent/pkg/secureagent/configuration_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secureagent

import (
"net/http"
"testing"
)

Expand Down Expand Up @@ -151,6 +152,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr {
t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr)
Expand Down Expand Up @@ -368,6 +370,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr {
t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ func TestAgent_doReqBootstrap(t *testing.T) {
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
HttpClient: &http.Client{},
}
if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr {
t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
25 changes: 1 addition & 24 deletions sztp-agent/pkg/secureagent/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
Expand All @@ -37,27 +34,7 @@ func (a *Agent) downloadAndValidateImage() error {
return err
}

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

check := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}

response, err := check.Get(item)
response, err := a.HttpClient.Get(item)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/pkg/secureagent/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
//nolint:funlen
func TestAgent_downloadAndValidateImage(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/imageOK" {
if r.URL.Path == "/imageOK" || r.URL.Path == "/report-progress" {
w.WriteHeader(200)
} else {
w.WriteHeader(400)
Expand Down Expand Up @@ -309,6 +309,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr {
t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) {
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
ProgressJSON: tt.fields.ProgressJSON,
HttpClient: &http.Client{},
}
if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr {
t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
40 changes: 26 additions & 14 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,35 @@
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
)

// NewHTTPClient instantiate a new HTTP Client
func NewHTTPClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client {
certPath := filepath.Clean(bootstrapTrustAnchorCert)
caCert, _ := os.ReadFile(certPath)
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey)
client := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
Dismissed Show dismissed Hide dismissed
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
return client
}

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput
Expand All @@ -38,20 +63,7 @@
r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
res, err := a.HttpClient.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
Expand Down
Loading