Skip to content

Commit

Permalink
Adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dvob committed Jul 27, 2024
1 parent 4deac76 commit 3013b95
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 116 deletions.
7 changes: 2 additions & 5 deletions cmd/pcert/cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@ package main
import (
"errors"
"fmt"
"os"
"strings"

"github.com/spf13/cobra"

Check failure on line 8 in cmd/pcert/cobra.go

View workflow job for this annotation

GitHub Actions / test

github.com/spf13/cobra@v1.8.1: replacement directory ../cobra does not exist
"github.com/spf13/pflag"
)

func WithEnv(c *cobra.Command) *cobra.Command {
func WithEnv(c *cobra.Command, args []string, getEnv func(name string) (string, bool)) *cobra.Command {
if c.HasParent() {
c = c.Root()
}

args := os.Args[1:]

var (
cmd *cobra.Command
err error
Expand All @@ -41,7 +38,7 @@ func WithEnv(c *cobra.Command) *cobra.Command {
optName := strings.ToUpper(f.Name)
optName = strings.ReplaceAll(optName, "-", "_")
varName := envVarPrefix + optName
if val, ok := os.LookupEnv(varName); ok {
if val, ok := getEnv(varName); ok {
err := f.Value.Set(val)
if err != nil {
errs = append(errs, fmt.Errorf("invalid environment variable '%s': %w", varName, err))
Expand Down
13 changes: 8 additions & 5 deletions cmd/pcert/create2.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

type createCommand struct {
Out io.Writer
In io.Writer
In io.Reader

CertificateOutputLocation string
KeyOutputLocation string
Expand All @@ -41,8 +41,6 @@ func getKeyRelativeToCert(certPath string) string {

func newCreate2Cmd() *cobra.Command {
createCommand := &createCommand{
Out: os.Stdout,
In: os.Stdin,
CertificateOutputLocation: "",
KeyOutputLocation: "",
SignCertificateLocation: "",
Expand All @@ -63,6 +61,8 @@ pcert create tls.crt
`,
Args: cobra.MaximumNArgs(2),
RunE: func(cmd *cobra.Command, args []string) error {
createCommand.In = cmd.InOrStdin()
createCommand.Out = cmd.OutOrStdout()
// default key output file relative to certificate
if len(args) == 1 && args[0] != "-" {
createCommand.CertificateOutputLocation = args[0]
Expand Down Expand Up @@ -103,7 +103,7 @@ pcert create tls.crt
if createCommand.SignCertificateLocation != "" {
slog.Info("process signer")
if createCommand.SignCertificateLocation == "-" {
stdin, err = io.ReadAll(os.Stdin)
stdin, err = io.ReadAll(createCommand.In)
if err != nil {
return err
}
Expand Down Expand Up @@ -157,7 +157,10 @@ pcert create tls.crt
}

if createCommand.CertificateOutputLocation == "" || createCommand.CertificateOutputLocation == "-" {
createCommand.Out.Write(certPEM)
_, err := createCommand.Out.Write(certPEM)
if err != nil {
return err
}
} else {
err := os.WriteFile(createCommand.CertificateOutputLocation, certPEM, 0664)
if err != nil {
Expand Down
125 changes: 62 additions & 63 deletions cmd/pcert/create_test.go
Original file line number Diff line number Diff line change
@@ -1,65 +1,66 @@
package main

import (
"bytes"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"os"
"testing"
"time"

"github.com/dvob/pcert"
)

func runCmd(args []string, env map[string]string) error {
os.Clearenv()
for k, v := range env {
os.Setenv(k, v)
}
func runCmd(args []string, env map[string]string) (io.WriteCloser, *bytes.Buffer, *bytes.Buffer, error) {
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
stdinReader, stdinWriter := io.Pipe()
cmd := newRootCmd()
cmd.SetArgs(args)
return cmd.Execute()
cmd.SetIn(stdinReader)
cmd.SetOut(stdout)
cmd.SetErr(stderr)
cmd = WithEnv(cmd, args, func(name string) (string, bool) {
if env == nil {
return "", false
}
val, ok := env[name]
return val, ok
})

return stdinWriter, stdout, stderr, cmd.Execute()
}

func runCreateAndLoad(name string, args []string, env map[string]string) (*x509.Certificate, error) {
defer os.Remove(name + ".crt")
defer os.Remove(name + ".key")
fullArgs := []string{"create", name}
fullArgs = append(fullArgs, args...)
err := runCmd(fullArgs, env)
func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error) {
_, stdout, stderr, err := runCmd(args, env)
if err != nil {
return nil, err
}

cert, err := pcert.Load(name + ".crt")
return cert, err
}
if stderr.Len() != 0 {
return nil, fmt.Errorf("stderr not empty '%s'", stderr.String())
}

func Test_create(t *testing.T) {
name := "foo1"
cert, err := runCreateAndLoad("foo1", []string{}, nil)
cert, err := pcert.Parse(stdout.Bytes())
if err != nil {
t.Error(err)
return
return nil, fmt.Errorf("could not read certificate from standard output: %s", err)
}

if cert.Subject.CommonName != name {
t.Errorf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, name)
}
return cert, err
}

func Test_create_subject(t *testing.T) {
cn := "myCommonName"
cert, err := runCreateAndLoad("foo2", []string{
"--subject",
"CN=" + cn,
}, nil)
func Test_create(t *testing.T) {
name := "foo1"
cert, err := runAndLoad([]string{"create", "--subject", "/CN=" + name}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

if cert.Subject.CommonName != cn {
t.Errorf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, cn)
if cert.Subject.CommonName != name {
t.Fatalf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, name)
}
}

Expand All @@ -71,7 +72,8 @@ func Test_create_subject_multiple(t *testing.T) {
Organization: []string{"Snakeoil Ltd."},
OrganizationalUnit: []string{"Group 1", "Group 2"},
}
cert, err := runCreateAndLoad("subject2", []string{
cert, err := runAndLoad([]string{
"create",
"--subject",
"CN=Bla bla bla/C=CH/L=Bern",
"--subject",
Expand All @@ -80,12 +82,12 @@ func Test_create_subject_multiple(t *testing.T) {
"OU=Group 1/OU=Group 2",
}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

if subject.String() != cert.Subject.String() {
t.Errorf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject)
t.Fatalf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject)
}
}

Expand All @@ -100,54 +102,57 @@ func Test_create_subject_combined_with_environment(t *testing.T) {
Organization: []string{"Snakeoil Ltd."},
OrganizationalUnit: []string{"Group 1", "Group 2"},
}
cert, err := runCreateAndLoad("subject3", []string{
cert, err := runAndLoad([]string{
"create",
"--subject",
"CN=Bla bla bla",
"--subject",
"OU=Group 1/OU=Group 2",
}, env)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

if subject.String() != cert.Subject.String() {
t.Errorf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject)
t.Fatalf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject)
}
}

func Test_create_not_before(t *testing.T) {
notBefore := time.Date(2020, 10, 27, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60))
cert, err := runCreateAndLoad("foo3", []string{
cert, err := runAndLoad([]string{
"create",
"--not-before",
"2020-10-27T12:00:00+01:00",
}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

if !cert.NotBefore.Equal(notBefore) {
t.Errorf("not before not set correctly: got: %s, want: %s", cert.NotBefore, notBefore)
t.Fatalf("not before not set correctly: got: %s, want: %s", cert.NotBefore, notBefore)
}

notAfter := notBefore.Add(pcert.DefaultValidityPeriod)
if !cert.NotAfter.Equal(notAfter) {
t.Errorf("not after not set correctly: got: %s, want: %s", cert.NotAfter, notAfter)
t.Fatalf("not after not set correctly: got: %s, want: %s", cert.NotAfter, notAfter)
}
}

func Test_create_not_before_and_not_after(t *testing.T) {
notBefore := time.Date(2020, 12, 30, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60))
notAfter := time.Date(2022, 12, 30, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60))
cert, err := runCreateAndLoad("foo4", []string{
cert, err := runAndLoad([]string{
"create",
"--not-before",
"2020-12-30T12:00:00+01:00",
"--not-after",
"2022-12-30T12:00:00+01:00",
}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

Expand All @@ -162,13 +167,13 @@ func Test_create_not_before_and_not_after(t *testing.T) {

func Test_create_with_expiry(t *testing.T) {
now := time.Now().Round(time.Minute)
cert, err := runCreateAndLoad("foo4", []string{
cert, err := runAndLoad([]string{
"create",
"--expiry",
"3y",
}, nil)
if err != nil {
t.Error(err)
return
t.Fatal(err)
}

actualNotBefore := cert.NotBefore.Round(time.Minute)
Expand All @@ -185,14 +190,15 @@ func Test_create_with_expiry(t *testing.T) {

func Test_create_not_before_with_expiry(t *testing.T) {
notBefore := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
cert, err := runCreateAndLoad("foo4", []string{
cert, err := runAndLoad([]string{
"create",
"--not-before",
"2020-01-01T00:00:00Z",
"--expiry",
"90d",
}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

Expand All @@ -207,30 +213,23 @@ func Test_create_not_before_with_expiry(t *testing.T) {
}

func Test_create_output_parameter(t *testing.T) {
name := "foo2"
certFile := "mycert_foo2"
keyFile := "mykey_foo2"
defer os.Remove(certFile)
defer os.Remove(keyFile)
err := runCmd([]string{
defer os.Remove("tls.crt")
defer os.Remove("tls.key")
_, _, _, err := runCmd([]string{
"create",
name,
"--cert",
certFile,
"--key",
keyFile,
"tls.crt",
}, nil)
if err != nil {
t.Error(err)
t.Fatal(err)
return
}

_, err = pcert.Load(certFile)
_, err = pcert.Load("tls.crt")
if err != nil {
t.Errorf("could not load certificate: %s", err)
}

_, err = pcert.LoadKey(keyFile)
_, err = pcert.LoadKey("tls.key")
if err != nil {
t.Errorf("could not load key: %s", err)
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/pcert/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ func BindCertificateOptionsFlags(fs *pflag.FlagSet, co *pcert.CertificateOptions
fs.IPSliceVar(&co.IPAddresses, "ip", []net.IP{}, "IP subject alternative name.")
fs.Var(newURISliceValue(&co.URIs), "uri", "URI subject alternative name.")
fs.Var(newSignAlgValue(&co.SignatureAlgorithm), "sign-alg", "Signature Algorithm. See 'pcert list' for available algorithms.")

fs.Var(newTimeValue(&co.NotBefore), "not-before", fmt.Sprintf("Not valid before time in RFC3339 format (e.g. '%s').", time.Now().UTC().Format(time.RFC3339)))
fs.Var(newTimeValue(&co.NotAfter), "not-after", fmt.Sprintf("Not valid after time in RFC3339 format (e.g. '%s').", time.Now().Add(time.Hour*24*60).UTC().Format(time.RFC3339)))
fs.Var(newDurationValue(&co.Expiry), "expiry", "Validity period of the certificate. If --not-after is set this option has no effect.")

fs.Var(newSubjectValue(&co.Subject), "subject", "Subject in the form '/C=CH/O=My Org/OU=My Team'.")

//fs.BoolVar(&co.BasicConstraintsValid, "basic-constraints", cert.BasicConstraintsValid, "Add basic constraints extension.")
Expand Down
2 changes: 1 addition & 1 deletion cmd/pcert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ var (
)

func main() {
err := WithEnv(newRootCmd()).Execute()
err := WithEnv(newRootCmd(), os.Args[1:], os.LookupEnv).Execute()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
Expand Down
Loading

0 comments on commit 3013b95

Please sign in to comment.