diff --git a/pkg/cmd/init_test.go b/pkg/cmd/init_test.go index 97d870af..4dd26cae 100644 --- a/pkg/cmd/init_test.go +++ b/pkg/cmd/init_test.go @@ -2,6 +2,9 @@ package cmd import ( "fmt" + "net" + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -20,7 +23,7 @@ var initCmdTestCases = []struct { // https://github.com/go-survey/survey/issues/394 // This flag is used to enable Pipe-based, and not PTY-based, tests. pipe bool - beforeTest func(t *testing.T, conjurrcInTmpDir string) + beforeTest func(t *testing.T, conjurrcInTmpDir string) func() assert func(t *testing.T, conjurrcInTmpDir string, stdout string) }{ { @@ -94,8 +97,9 @@ appliance_url: http://conjur }, }, pipe: true, - beforeTest: func(t *testing.T, conjurrcInTmpDir string) { + beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() { os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644) + return nil }, assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) { // Assert that file is not overwritten @@ -113,8 +117,9 @@ appliance_url: http://conjur }, }, pipe: true, - beforeTest: func(t *testing.T, conjurrcInTmpDir string) { + beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() { os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644) + return nil }, assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) { // Assert that file is overwritten @@ -142,8 +147,9 @@ credential_storage: file { name: "force overwrite", args: []string{"init", "-u=http://host", "-a=yet-another-test-account", "--force", "-i"}, - beforeTest: func(t *testing.T, conjurrcInTmpDir string) { + beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() { os.WriteFile(conjurrcInTmpDir, []byte("something"), 0644) + return nil }, assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) { // Assert that file is overwritten @@ -205,16 +211,20 @@ appliance_url: http://host }, { name: "fails for self-signed certificate", - args: []string{"init", "-u=https://self-signed.badssl.com", "-a=test-account"}, + args: []string{"init", "-u=https://localhost:8080", "-a=test-account"}, + beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() { + return startSelfSignedServer(t, 8080) + }, assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) { assert.Contains(t, stdout, "Unable to retrieve and validate certificate") + assert.Contains(t, stdout, "x509") assert.Contains(t, stdout, "If you're attempting to use a self-signed certificate, re-run the init command with the `--self-signed` flag") assertFetchCertFailed(t, conjurrcInTmpDir) }, }, { name: "succeeds for self-signed certificate with --self-signed flag", - args: []string{"init", "-u=https://self-signed.badssl.com", "-a=test-account", "--self-signed"}, + args: []string{"init", "-u=https://localhost:8080", "-a=test-account", "--self-signed"}, promptResponses: []promptResponse{ { prompt: "Trust this certificate?", @@ -222,6 +232,9 @@ appliance_url: http://host }, }, pipe: true, + beforeTest: func(t *testing.T, conjurrcInTmpDir string) func() { + return startSelfSignedServer(t, 8080) + }, assert: func(t *testing.T, conjurrcInTmpDir string, stdout string) { assert.Contains(t, stdout, "Warning: Using self-signed certificates is not recommended and could lead to exposure of sensitive data") assertCertWritten(t, conjurrcInTmpDir, stdout) @@ -307,7 +320,10 @@ func TestInitCmd(t *testing.T) { conjurrcInTmpDir := tempDir + "/.conjurrc" if tc.beforeTest != nil { - tc.beforeTest(t, conjurrcInTmpDir) + cleanup := tc.beforeTest(t, conjurrcInTmpDir) + if cleanup != nil { + defer cleanup() + } } // --file default to conjurrcInTmpDir. It can always be overwritten in each test case @@ -389,3 +405,18 @@ func assertCertWritten(t *testing.T, conjurrcInTmpDir string, stdout string) { data, _ = os.ReadFile(expectedCertPath) assert.Contains(t, string(data), "-----BEGIN CERTIFICATE-----") } + +func startSelfSignedServer(t *testing.T, port int) func() { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + l, err := net.Listen("tcp", "localhost:8080") + if err != nil { + assert.NoError(t, err, "unabled to start test server") + } + + server.Listener = l + server.StartTLS() + + return func() { server.Close() } +} diff --git a/pkg/utils/tls.go b/pkg/utils/tls.go index 34a8c252..c6c3d5ec 100644 --- a/pkg/utils/tls.go +++ b/pkg/utils/tls.go @@ -1,7 +1,6 @@ package utils import ( - "bytes" "crypto/sha1" "crypto/tls" "crypto/x509" @@ -65,13 +64,6 @@ func GetServerCert(host string, allowSelfSigned bool) (ServerCert, error) { } func getSha1Fingerprint(cert []byte) string { - sum := sha1.Sum(cert) - var buf bytes.Buffer - for i, f := range sum { - if i > 0 { - buf.WriteString(":") - } - fmt.Fprintf(&buf, "%02X", f) - } - return buf.String() + sha1sum := sha1.Sum(cert) + return strings.ToUpper(fmt.Sprintf("%x", sha1sum)) } diff --git a/pkg/utils/tls_test.go b/pkg/utils/tls_test.go index 93db748d..5564b0fc 100644 --- a/pkg/utils/tls_test.go +++ b/pkg/utils/tls_test.go @@ -1,87 +1,36 @@ package utils import ( - "strings" + "fmt" + "net" + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" ) func TestGetServerCert(t *testing.T) { - // Note: this will change every few years when GitHub renews their TLS certificate - githubFingerprint := "A3:B5:9E:5F:E8:84:EE:1F:34:D9:8E:EF:85:8E:3F:B6:62:AC:10:4A" - githubCert := `-----BEGIN CERTIFICATE----- -MIIFajCCBPGgAwIBAgIQDNCovsYyz+ZF7KCpsIT7HDAKBggqhkjOPQQDAzBWMQsw -CQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMTAwLgYDVQQDEydEaWdp -Q2VydCBUTFMgSHlicmlkIEVDQyBTSEEzODQgMjAyMCBDQTEwHhcNMjMwMjE0MDAw -MDAwWhcNMjQwMzE0MjM1OTU5WjBmMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2Fs -aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEVMBMGA1UEChMMR2l0SHVi -LCBJbmMuMRMwEQYDVQQDEwpnaXRodWIuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0D -AQcDQgAEo6QDRgPfRlFWy8k5qyLN52xZlnqToPu5QByQMog2xgl2nFD1Vfd2Xmgg -nO4i7YMMFTAQQUReMqyQodWq8uVDs6OCA48wggOLMB8GA1UdIwQYMBaAFAq8CCkX -jKU5bXoOzjPHLrPt+8N6MB0GA1UdDgQWBBTHByd4hfKdM8lMXlZ9XNaOcmfr3jAl -BgNVHREEHjAcggpnaXRodWIuY29tgg53d3cuZ2l0aHViLmNvbTAOBgNVHQ8BAf8E -BAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMIGbBgNVHR8EgZMw -gZAwRqBEoEKGQGh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9EaWdpQ2VydFRMU0h5 -YnJpZEVDQ1NIQTM4NDIwMjBDQTEtMS5jcmwwRqBEoEKGQGh0dHA6Ly9jcmw0LmRp -Z2ljZXJ0LmNvbS9EaWdpQ2VydFRMU0h5YnJpZEVDQ1NIQTM4NDIwMjBDQTEtMS5j -cmwwPgYDVR0gBDcwNTAzBgZngQwBAgIwKTAnBggrBgEFBQcCARYbaHR0cDovL3d3 -dy5kaWdpY2VydC5jb20vQ1BTMIGFBggrBgEFBQcBAQR5MHcwJAYIKwYBBQUHMAGG -GGh0dHA6Ly9vY3NwLmRpZ2ljZXJ0LmNvbTBPBggrBgEFBQcwAoZDaHR0cDovL2Nh -Y2VydHMuZGlnaWNlcnQuY29tL0RpZ2lDZXJ0VExTSHlicmlkRUNDU0hBMzg0MjAy -MENBMS0xLmNydDAJBgNVHRMEAjAAMIIBgAYKKwYBBAHWeQIEAgSCAXAEggFsAWoA -dwDuzdBk1dsazsVct520zROiModGfLzs3sNRSFlGcR+1mwAAAYZQ3Rv6AAAEAwBI -MEYCIQDkFq7T4iy6gp+pefJLxpRS7U3gh8xQymmxtI8FdzqU6wIhALWfw/nLD63Q -YPIwG3EFchINvWUfB6mcU0t2lRIEpr8uAHYASLDja9qmRzQP5WoC+p0w6xxSActW -3SyB2bu/qznYhHMAAAGGUN0cKwAABAMARzBFAiAePGAyfiBR9dbhr31N9ZfESC5G -V2uGBTcyTyUENrH3twIhAPwJfsB8A4MmNr2nW+sdE1n2YiCObW+3DTHr2/UR7lvU -AHcAO1N3dT4tuYBOizBbBv5AO2fYT8P0x70ADS1yb+H61BcAAAGGUN0cOgAABAMA -SDBGAiEAzOBr9OZ0+6OSZyFTiywN64PysN0FLeLRyL5jmEsYrDYCIQDu0jtgWiMI -KU6CM0dKcqUWLkaFE23c2iWAhYAHqrFRRzAKBggqhkjOPQQDAwNnADBkAjAE3A3U -3jSZCpwfqOHBdlxi9ASgKTU+wg0qw3FqtfQ31OwLYFdxh0MlNk/HwkjRSWgCMFbQ -vMkXEPvNvv4t30K6xtpG26qmZ+6OiISBIIXMljWnsiYR1gyZnTzIg3AQSw4Vmw== ------END CERTIFICATE-----` - - // Note: this will change whenever certs are renewed for self-signed.badssl.com - selfSignedFingerprint := "42:B0:D7:0D:41:C3:7C:E7:09:9F:55:97:56:BC:51:E5:D0:34:24:51" - - t.Run("Returns the right certificate from github.com", func(t *testing.T) { - cert, err := GetServerCert("github.com", false) - assert.NoError(t, err) - - assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert)) - assert.Equal(t, githubFingerprint, cert.Fingerprint) - }) - - t.Run("Returns the right certificate from github.com:443", func(t *testing.T) { - cert, err := GetServerCert("github.com:443", false) - assert.NoError(t, err) - - assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert)) - assert.Equal(t, githubFingerprint, cert.Fingerprint) - }) - - t.Run("Returns the right certificate from github.com when self-signed is allowed", func(t *testing.T) { - // Ensure that allowing self-signed certs doesn't break support for normal certs - cert, err := GetServerCert("github.com", true) - assert.NoError(t, err) - - assert.Equal(t, githubCert, strings.TrimSpace(cert.Cert)) - assert.Equal(t, githubFingerprint, cert.Fingerprint) - }) - t.Run("Returns an error when the server doesn't exist", func(t *testing.T) { _, err := GetServerCert("example.com:444", false) assert.Error(t, err) }) t.Run("Returns an error for self-signed certificates", func(t *testing.T) { - _, err := GetServerCert("self-signed.badssl.com", false) + server := startSelfSignedServer(t, 8080) + defer server.Close() + + _, err := GetServerCert("localhost:8080", false) assert.Error(t, err) }) t.Run("Returns the right certificate for self-signed certificates when allowed", func(t *testing.T) { - cert, err := GetServerCert("self-signed.badssl.com", true) + server := startSelfSignedServer(t, 8080) + defer server.Close() + + selfSignedFingerprint := getSha1Fingerprint(server.Certificate().Raw) + + cert, err := GetServerCert("localhost:8080", true) assert.NoError(t, err) assert.Equal(t, selfSignedFingerprint, cert.Fingerprint) @@ -97,3 +46,18 @@ vMkXEPvNvv4t30K6xtpG26qmZ+6OiISBIIXMljWnsiYR1gyZnTzIg3AQSw4Vmw== assert.Error(t, err) }) } + +func startSelfSignedServer(t *testing.T, port int) *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + assert.NoError(t, err, "unabled to start test server") + } + + server.Listener = l + server.StartTLS() + + return server +}