Skip to content

Commit

Permalink
Use httptest package for self-signed cert tests
Browse files Browse the repository at this point in the history
  • Loading branch information
doodlesbykumbi authored and szh committed Apr 24, 2023
1 parent 51565a7 commit 5b9382d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 82 deletions.
45 changes: 38 additions & 7 deletions pkg/cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package cmd

import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
Expand All @@ -20,7 +23,7 @@ var initCmdTestCases = []struct {
// https://github.com/go-survey/survey/issues/394
// This flag is used to enable Pipe-based, and not PTY-based, tests.
pipe bool
beforeTest func(t *testing.T, conjurrcInTmpDir string)
beforeTest func(t *testing.T, conjurrcInTmpDir string) func()
assert func(t *testing.T, conjurrcInTmpDir string, stdout string)
}{
{
Expand Down Expand Up @@ -94,8 +97,9 @@ appliance_url: http://conjur
},
},
pipe: true,
beforeTest: func(t *testing.T, conjurrcInTmpDir string) {
beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() {
os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644)
return nil
},
assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) {
// Assert that file is not overwritten
Expand All @@ -113,8 +117,9 @@ appliance_url: http://conjur
},
},
pipe: true,
beforeTest: func(t *testing.T, conjurrcInTmpDir string) {
beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() {
os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644)
return nil
},
assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) {
// Assert that file is overwritten
Expand Down Expand Up @@ -142,8 +147,9 @@ credential_storage: file
{
name: "force overwrite",
args: []string{"init", "-u=http://host", "-a=yet-another-test-account", "--force", "-i"},
beforeTest: func(t *testing.T, conjurrcInTmpDir string) {
beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() {
os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644)
return nil
},
assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) {
// Assert that file is overwritten
Expand Down Expand Up @@ -205,23 +211,30 @@ appliance_url: http://host
},
{
name: "fails for self-signed certificate",
args: []string{"init", "-u=https://self-signed.badssl.com", "-a=test-account"},
args: []string{"init", "-u=https://localhost:8080", "-a=test-account"},
beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() {
return startSelfSignedServer(t, 8080)
},
assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) {
assert.Contains(t, stdout, "Unable to retrieve and validate certificate")
assert.Contains(t, stdout, "x509")
assert.Contains(t, stdout, "If you're attempting to use a self-signed certificate, re-run the init command with the `--self-signed` flag")
assertFetchCertFailed(t, conjurrcInTmpDir)
},
},
{
name: "succeeds for self-signed certificate with --self-signed flag",
args: []string{"init", "-u=https://self-signed.badssl.com", "-a=test-account", "--self-signed"},
args: []string{"init", "-u=https://localhost:8080", "-a=test-account", "--self-signed"},
promptResponses: []promptResponse{
{
prompt: "Trust this certificate?",
response: "y",
},
},
pipe: true,
beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() {
return startSelfSignedServer(t, 8080)
},
assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) {
assert.Contains(t, stdout, "Warning: Using self-signed certificates is not recommended and could lead to exposure of sensitive data")
assertCertWritten(t, conjurrcInTmpDir, stdout)
Expand Down Expand Up @@ -307,7 +320,10 @@ func TestInitCmd(t *testing.T) {
conjurrcInTmpDir := tempDir + "/.conjurrc"

if tc.beforeTest != nil {
tc.beforeTest(t, conjurrcInTmpDir)
cleanup := tc.beforeTest(t, conjurrcInTmpDir)
if cleanup != nil {
defer cleanup()
}
}

// --file default to conjurrcInTmpDir. It can always be overwritten in each test case
Expand Down Expand Up @@ -389,3 +405,18 @@ func assertCertWritten(t *testing.T, conjurrcInTmpDir string, stdout string) {
data, _ = os.ReadFile(expectedCertPath)
assert.Contains(t, string(data), "-----BEGIN CERTIFICATE-----")
}

func startSelfSignedServer(t *testing.T, port int) func() {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
l, err := net.Listen("tcp", "localhost:8080")
if err != nil {
assert.NoError(t, err, "unabled to start test server")
}

server.Listener = l
server.StartTLS()

return func() { server.Close() }
}
12 changes: 2 additions & 10 deletions pkg/utils/tls.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package utils

import (
"bytes"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -65,13 +64,6 @@ func GetServerCert(host string, allowSelfSigned bool) (ServerCert, error) {
}

func getSha1Fingerprint(cert []byte) string {
sum := sha1.Sum(cert)
var buf bytes.Buffer
for i, f := range sum {
if i > 0 {
buf.WriteString(":")
}
fmt.Fprintf(&buf, "%02X", f)
}
return buf.String()
sha1sum := sha1.Sum(cert)
return strings.ToUpper(fmt.Sprintf("%x", sha1sum))
}
94 changes: 29 additions & 65 deletions pkg/utils/tls_test.go
Original file line number Diff line number Diff line change
@@ -1,87 +1,36 @@
package utils

import (
"strings"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetServerCert(t *testing.T) {
// Note: this will change every few years when GitHub renews their TLS certificate
githubFingerprint := "A3:B5:9E:5F:E8:84:EE:1F:34:D9:8E:EF:85:8E:3F:B6:62:AC:10:4A"
githubCert := `-----BEGIN CERTIFICATE-----
MIIFajCCBPGgAwIBAgIQDNCovsYyz+ZF7KCpsIT7HDAKBggqhkjOPQQDAzBWMQsw
CQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMTAwLgYDVQQDEydEaWdp
Q2VydCBUTFMgSHlicmlkIEVDQyBTSEEzODQgMjAyMCBDQTEwHhcNMjMwMjE0MDAw
MDAwWhcNMjQwMzE0MjM1OTU5WjBmMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEVMBMGA1UEChMMR2l0SHVi
LCBJbmMuMRMwEQYDVQQDEwpnaXRodWIuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0D
AQcDQgAEo6QDRgPfRlFWy8k5qyLN52xZlnqToPu5QByQMog2xgl2nFD1Vfd2Xmgg
nO4i7YMMFTAQQUReMqyQodWq8uVDs6OCA48wggOLMB8GA1UdIwQYMBaAFAq8CCkX
jKU5bXoOzjPHLrPt+8N6MB0GA1UdDgQWBBTHByd4hfKdM8lMXlZ9XNaOcmfr3jAl
BgNVHREEHjAcggpnaXRodWIuY29tgg53d3cuZ2l0aHViLmNvbTAOBgNVHQ8BAf8E
BAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMIGbBgNVHR8EgZMw
gZAwRqBEoEKGQGh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9EaWdpQ2VydFRMU0h5
YnJpZEVDQ1NIQTM4NDIwMjBDQTEtMS5jcmwwRqBEoEKGQGh0dHA6Ly9jcmw0LmRp
Z2ljZXJ0LmNvbS9EaWdpQ2VydFRMU0h5YnJpZEVDQ1NIQTM4NDIwMjBDQTEtMS5j
cmwwPgYDVR0gBDcwNTAzBgZngQwBAgIwKTAnBggrBgEFBQcCARYbaHR0cDovL3d3
dy5kaWdpY2VydC5jb20vQ1BTMIGFBggrBgEFBQcBAQR5MHcwJAYIKwYBBQUHMAGG
GGh0dHA6Ly9vY3NwLmRpZ2ljZXJ0LmNvbTBPBggrBgEFBQcwAoZDaHR0cDovL2Nh
Y2VydHMuZGlnaWNlcnQuY29tL0RpZ2lDZXJ0VExTSHlicmlkRUNDU0hBMzg0MjAy
MENBMS0xLmNydDAJBgNVHRMEAjAAMIIBgAYKKwYBBAHWeQIEAgSCAXAEggFsAWoA
dwDuzdBk1dsazsVct520zROiModGfLzs3sNRSFlGcR+1mwAAAYZQ3Rv6AAAEAwBI
MEYCIQDkFq7T4iy6gp+pefJLxpRS7U3gh8xQymmxtI8FdzqU6wIhALWfw/nLD63Q
YPIwG3EFchINvWUfB6mcU0t2lRIEpr8uAHYASLDja9qmRzQP5WoC+p0w6xxSActW
3SyB2bu/qznYhHMAAAGGUN0cKwAABAMARzBFAiAePGAyfiBR9dbhr31N9ZfESC5G
V2uGBTcyTyUENrH3twIhAPwJfsB8A4MmNr2nW+sdE1n2YiCObW+3DTHr2/UR7lvU
AHcAO1N3dT4tuYBOizBbBv5AO2fYT8P0x70ADS1yb+H61BcAAAGGUN0cOgAABAMA
SDBGAiEAzOBr9OZ0+6OSZyFTiywN64PysN0FLeLRyL5jmEsYrDYCIQDu0jtgWiMI
KU6CM0dKcqUWLkaFE23c2iWAhYAHqrFRRzAKBggqhkjOPQQDAwNnADBkAjAE3A3U
3jSZCpwfqOHBdlxi9ASgKTU+wg0qw3FqtfQ31OwLYFdxh0MlNk/HwkjRSWgCMFbQ
vMkXEPvNvv4t30K6xtpG26qmZ+6OiISBIIXMljWnsiYR1gyZnTzIg3AQSw4Vmw==
-----END CERTIFICATE-----`

// Note: this will change whenever certs are renewed for self-signed.badssl.com
selfSignedFingerprint := "42:B0:D7:0D:41:C3:7C:E7:09:9F:55:97:56:BC:51:E5:D0:34:24:51"

t.Run("Returns the right certificate from github.com", func(t *testing.T) {
cert, err := GetServerCert("github.com", false)
assert.NoError(t, err)

assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert))
assert.Equal(t, githubFingerprint, cert.Fingerprint)
})

t.Run("Returns the right certificate from github.com:443", func(t *testing.T) {
cert, err := GetServerCert("github.com:443", false)
assert.NoError(t, err)

assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert))
assert.Equal(t, githubFingerprint, cert.Fingerprint)
})

t.Run("Returns the right certificate from github.com when self-signed is allowed", func(t *testing.T) {
// Ensure that allowing self-signed certs doesn't break support for normal certs
cert, err := GetServerCert("github.com", true)
assert.NoError(t, err)

assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert))
assert.Equal(t, githubFingerprint, cert.Fingerprint)
})

t.Run("Returns an error when the server doesn't exist", func(t *testing.T) {
_, err := GetServerCert("example.com:444", false)
assert.Error(t, err)
})

t.Run("Returns an error for self-signed certificates", func(t *testing.T) {
_, err := GetServerCert("self-signed.badssl.com", false)
server := startSelfSignedServer(t, 8080)
defer server.Close()

_, err := GetServerCert("localhost:8080", false)
assert.Error(t, err)
})

t.Run("Returns the right certificate for self-signed certificates when allowed", func(t *testing.T) {
cert, err := GetServerCert("self-signed.badssl.com", true)
server := startSelfSignedServer(t, 8080)
defer server.Close()

selfSignedFingerprint := getSha1Fingerprint(server.Certificate().Raw)

cert, err := GetServerCert("localhost:8080", true)
assert.NoError(t, err)

assert.Equal(t, selfSignedFingerprint, cert.Fingerprint)
Expand All @@ -97,3 +46,18 @@ vMkXEPvNvv4t30K6xtpG26qmZ+6OiISBIIXMljWnsiYR1gyZnTzIg3AQSw4Vmw==
assert.Error(t, err)
})
}

func startSelfSignedServer(t *testing.T, port int) *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
assert.NoError(t, err, "unabled to start test server")
}

server.Listener = l
server.StartTLS()

return server
}

0 comments on commit 5b9382d

Please sign in to comment.