Skip to content

Commit

Permalink
Get tests passing
Browse files Browse the repository at this point in the history
Code reorgnization allows dnapitest to import the marshal/unmarshal
functions from the keys package.
  • Loading branch information
johnmaguire committed Dec 4, 2024
1 parent fdf9490 commit ebd0b25
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 82 deletions.
78 changes: 44 additions & 34 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"fmt"
"io"
Expand All @@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"

"github.com/DefinedNet/dnapi/keys"
"github.com/DefinedNet/dnapi/message"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -93,22 +95,22 @@ type EnrollMeta struct {
// On success it returns the Nebula config generated by the server, a Nebula private key PEM to be inserted into the
// config (see api.InsertConfigPrivateKey), credentials to be used in DNClient API requests, and a meta object
// containing organization info.
func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code string) ([]byte, []byte, *Credentials, *EnrollMeta, error) {
func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code string) ([]byte, []byte, *keys.Credentials, *EnrollMeta, error) {
logger.WithFields(logrus.Fields{"server": c.dnServer}).Debug("Making enrollment request to API")

// Generate keys for the enrollment request
keys, err := newKeys()
// Generate newKeys for the enrollment request
newKeys, err := keys.New()
if err != nil {
return nil, nil, nil, nil, err
}

// Make a request to the API with the enrollment code
jv, err := json.Marshal(message.EnrollRequest{
Code: code,
NebulaPubkeyX25519: keys.nebulaX25519PublicKeyPEM,
HostPubkeyEd25519: keys.hostEd25519PublicKeyPEM,
NebulaPubkeyP256: keys.nebulaP256PublicKeyPEM,
HostPubkeyP256: keys.hostP256PublicKeyPEM,
NebulaPubkeyX25519: newKeys.NebulaX25519PublicKeyPEM,
HostPubkeyEd25519: newKeys.HostEd25519PublicKeyPEM,
NebulaPubkeyP256: newKeys.NebulaP256PublicKeyPEM,
HostPubkeyP256: newKeys.HostP256PublicKeyPEM,
Timestamp: time.Now(),
})
if err != nil {
Expand Down Expand Up @@ -163,24 +165,24 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str

// Determine the private keys to save based on the network curve type
var privkeyPEM []byte
var privkey PrivateKey
var privkey keys.PrivateKey
switch r.Data.Network.Curve {
case message.NetworkCurve25519:
privkeyPEM = keys.nebulaX25519PrivateKeyPEM
privkey = Ed25519PrivateKey{keys.hostEd25519PrivateKey}
privkeyPEM = newKeys.NebulaX25519PrivateKeyPEM
privkey = keys.Ed25519PrivateKey{newKeys.HostEd25519PrivateKey}
case message.NetworkCurveP256:
privkeyPEM = keys.nebulaP256PrivateKeyPEM
privkey = P256PrivateKey{keys.hostP256PrivateKey}
privkeyPEM = newKeys.NebulaP256PrivateKeyPEM
privkey = keys.P256PrivateKey{newKeys.HostP256PrivateKey}
default:
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Data.Network.Curve), ReqID: reqID}
}

trustedKeys, err := TrustedKeysFromPEM(r.Data.TrustedKeys)
trustedKeys, err := keys.TrustedKeysFromPEM(r.Data.TrustedKeys)
if err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("failed to load trusted keys from bundle: %s", err), ReqID: reqID}
}

creds := &Credentials{
creds := &keys.Credentials{
HostID: r.Data.HostID,
PrivateKey: privkey,
Counter: r.Data.Counter,
Expand All @@ -190,7 +192,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
}

// CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available.
func (c *Client) CheckForUpdate(ctx context.Context, creds Credentials) (bool, error) {
func (c *Client) CheckForUpdate(ctx context.Context, creds keys.Credentials) (bool, error) {
respBody, err := c.postDNClient(ctx, message.CheckForUpdate, nil, creds.HostID, creds.Counter, creds.PrivateKey)
if err != nil {
return false, fmt.Errorf("failed to post message to dnclient api: %w", err)
Expand All @@ -205,7 +207,7 @@ func (c *Client) CheckForUpdate(ctx context.Context, creds Credentials) (bool, e

// LongPollWait sends a signed message to a DNClient API endpoint that will block, returning only
// if there is an action the client should take before the timeout (config updates, debug commands)
func (c *Client) LongPollWait(ctx context.Context, creds Credentials, supportedActions []string) (*message.LongPollWaitResponse, error) {
func (c *Client) LongPollWait(ctx context.Context, creds keys.Credentials, supportedActions []string) (*message.LongPollWaitResponse, error) {
value, err := json.Marshal(message.LongPollWaitRequest{
SupportedActions: supportedActions,
})
Expand All @@ -230,12 +232,12 @@ func (c *Client) LongPollWait(ctx context.Context, creds Credentials, supportedA
// is returned along with the new Nebula private key PEM and new DNClient API credentials.
//
// See dnapi.InsertConfigPrivateKey for how to insert the new Nebula private key into the configuration.
func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byte, *Credentials, error) {
func (c *Client) DoUpdate(ctx context.Context, creds keys.Credentials) ([]byte, []byte, *keys.Credentials, error) {
// Rotate keys
var nebulaPrivkeyPEM []byte // ECDH
var hostPrivkey PrivateKey // ECDSA
var nebulaPrivkeyPEM []byte // ECDH
var hostPrivkey keys.PrivateKey // ECDSA

newKeys, err := newKeys()
newKeys, err := keys.New()
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to generate new keys: %s", err)
}
Expand All @@ -247,15 +249,15 @@ func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byt
// Set the correct keypair based on the current private key type
switch creds.PrivateKey.Unwrap().(type) {
case ed25519.PrivateKey:
hostPrivkey = Ed25519PrivateKey{newKeys.hostEd25519PrivateKey}
nebulaPrivkeyPEM = newKeys.nebulaX25519PrivateKeyPEM
msg.EdPubkeyPEM = newKeys.hostEd25519PublicKeyPEM
msg.DHPubkeyPEM = newKeys.nebulaX25519PublicKeyPEM
hostPrivkey = keys.Ed25519PrivateKey{newKeys.HostEd25519PrivateKey}
nebulaPrivkeyPEM = newKeys.NebulaX25519PrivateKeyPEM
msg.EdPubkeyPEM = newKeys.HostEd25519PublicKeyPEM
msg.DHPubkeyPEM = newKeys.NebulaX25519PublicKeyPEM
case *ecdsa.PrivateKey:
hostPrivkey = P256PrivateKey{newKeys.hostP256PrivateKey}
nebulaPrivkeyPEM = newKeys.nebulaP256PrivateKeyPEM
msg.P256HostPubkeyPEM = newKeys.hostP256PublicKeyPEM
msg.P256NebulaPubkeyPEM = newKeys.nebulaP256PublicKeyPEM
hostPrivkey = keys.P256PrivateKey{newKeys.HostP256PrivateKey}
nebulaPrivkeyPEM = newKeys.NebulaP256PrivateKeyPEM
msg.P256HostPubkeyPEM = newKeys.HostP256PublicKeyPEM
msg.P256NebulaPubkeyPEM = newKeys.NebulaP256PublicKeyPEM
}

blob, err := json.Marshal(msg)
Expand Down Expand Up @@ -303,12 +305,12 @@ func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byt
return nil, nil, nil, fmt.Errorf("counter in request (%d) should be less than counter in response (%d)", creds.Counter, result.Counter)
}

trustedKeys, err := TrustedKeysFromPEM(result.TrustedKeys)
trustedKeys, err := keys.TrustedKeysFromPEM(result.TrustedKeys)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to load trusted keys from bundle: %s", err)
}

newCreds := &Credentials{
newCreds := &keys.Credentials{
HostID: creds.HostID,
Counter: result.Counter,
PrivateKey: hostPrivkey,
Expand All @@ -318,7 +320,7 @@ func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byt
return result.Config, nebulaPrivkeyPEM, newCreds, nil
}

func (c *Client) CommandResponse(ctx context.Context, creds Credentials, responseToken string, response any) error {
func (c *Client) CommandResponse(ctx context.Context, creds keys.Credentials, responseToken string, response any) error {
value, err := json.Marshal(message.CommandResponseRequest{
ResponseToken: responseToken,
Response: response,
Expand All @@ -331,7 +333,7 @@ func (c *Client) CommandResponse(ctx context.Context, creds Credentials, respons
return err
}

func (c *Client) StreamCommandResponse(ctx context.Context, creds Credentials, responseToken string) (*StreamController, error) {
func (c *Client) StreamCommandResponse(ctx context.Context, creds keys.Credentials, responseToken string) (*StreamController, error) {
value, err := json.Marshal(message.CommandResponseRequest{
ResponseToken: responseToken,
})
Expand All @@ -344,7 +346,7 @@ func (c *Client) StreamCommandResponse(ctx context.Context, creds Credentials, r

// streamingPostDNClient wraps and signs the given dnclientRequestWrapper message, and makes a streaming API call.
// On success, it returns a StreamController to interact with the request. On error, the error is returned.
func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey PrivateKey) (*StreamController, error) {
func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey keys.PrivateKey) (*StreamController, error) {
pr, pw := io.Pipe()

postBody, err := SignRequestV1(reqType, value, hostID, counter, privkey)
Expand Down Expand Up @@ -402,7 +404,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu

// postDNClient wraps and signs the given dnclientRequestWrapper message, and makes the API call.
// On success, it returns the response message body. On error, the error is returned.
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey PrivateKey) ([]byte, error) {
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey keys.PrivateKey) ([]byte, error) {
postBody, err := SignRequestV1(reqType, value, hostID, counter, privkey)
if err != nil {
return nil, err
Expand Down Expand Up @@ -506,3 +508,11 @@ func (t *uaTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("User-Agent", t.useragent)
return t.T.RoundTrip(req)
}

func nonce() []byte {
nonce := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
panic(err)
}
return nonce
}
16 changes: 9 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/DefinedNet/dnapi/dnapitest"
"github.com/DefinedNet/dnapi/internal/testutil"
"github.com/DefinedNet/dnapi/keys"
"github.com/DefinedNet/dnapi/message"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
Expand Down Expand Up @@ -89,7 +90,7 @@ func TestEnroll(t *testing.T) {

assert.Equal(t, hostID, creds.HostID)
assert.Equal(t, counter, creds.Counter)
assert.Equal(t, []ed25519.PublicKey{ca.Details.PublicKey}, creds.TrustedKeys)
assert.Equal(t, []keys.TrustedKey{keys.Ed25519TrustedKey{ca.Details.PublicKey}}, creds.TrustedKeys)
assert.NotEmpty(t, creds.PrivateKey)
assert.NotEmpty(t, pkey)

Expand Down Expand Up @@ -196,7 +197,8 @@ func TestDoUpdate(t *testing.T) {
require.NoError(t, err)

// convert privkey to private key
pubkey := cert.MarshalEd25519PublicKey(creds.PrivateKey.Unwrap().(ed25519.PrivateKey).Public().(ed25519.PublicKey))
pubkey, err := keys.MarshalEd25519HostPublicKey(creds.PrivateKey.Unwrap().(ed25519.PrivateKey).Public().(ed25519.PublicKey))
require.NoError(t, err)

// make sure all credential values were set
assert.NotEmpty(t, creds.HostID)
Expand All @@ -214,11 +216,11 @@ func TestDoUpdate(t *testing.T) {
})

// Create a new, invalid requesting authentication key
_, invalidPrivKey, err := newEd25519Keypair()
nk, err := keys.New()
require.NoError(t, err)
invalidCreds := Credentials{
invalidCreds := keys.Credentials{
HostID: creds.HostID,
PrivateKey: Ed25519PrivateKey{invalidPrivKey},
PrivateKey: keys.Ed25519PrivateKey{nk.HostEd25519PrivateKey},
Counter: creds.Counter,
TrustedKeys: creds.TrustedKeys,
}
Expand All @@ -241,7 +243,7 @@ func TestDoUpdate(t *testing.T) {
}
rawRes := jsonMarshal(newConfigResponse)

_, newPrivkey, err := newEd25519Keypair()
nk, err := keys.New()
require.NoError(t, err)

// XXX the mock server will update the ed pubkey for us, but this is problematic because
Expand All @@ -253,7 +255,7 @@ func TestDoUpdate(t *testing.T) {
Data: message.SignedResponse{
Version: 1,
Message: rawRes,
Signature: ed25519.Sign(newPrivkey, rawRes),
Signature: ed25519.Sign(nk.HostEd25519PrivateKey, rawRes),
},
})
})
Expand Down
3 changes: 2 additions & 1 deletion dnapitest/dnapitest.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptest"
"time"

"github.com/DefinedNet/dnapi/keys"
"github.com/DefinedNet/dnapi/message"
"github.com/slackhq/nebula/cert"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -95,7 +96,7 @@ func (s *Server) handlerEnroll(w http.ResponseWriter, r *http.Request) {

func (s *Server) SetEdPubkey(edPubkeyPEM []byte) error {
// hard failure, return
edPubkey, rest, err := cert.UnmarshalEd25519PublicKey(edPubkeyPEM)
edPubkey, rest, err := keys.UnmarshalEd25519HostPublicKey(edPubkeyPEM)
if err != nil {
return fmt.Errorf("failed to unmarshal ed pubkey: %w", err)
}
Expand Down
9 changes: 9 additions & 0 deletions keys/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package keys

// Credentials contains information necessary to make requests against the DNClient API.
type Credentials struct {
HostID string
PrivateKey PrivateKey
Counter uint
TrustedKeys []TrustedKey
}
50 changes: 21 additions & 29 deletions crypto.go → keys/crypto.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dnapi
package keys

import (
"crypto/ecdsa"
Expand Down Expand Up @@ -70,7 +70,7 @@ func (k P256PrivateKey) MarshalPEM() ([]byte, error) {
return MarshalP256HostPrivateKey(k.PrivateKey)
}

// keys contains a set of P256 and X25519/Ed25519 keys. Only one set is used,
// Keys contains a set of P256 and X25519/Ed25519 keys. Only one set is used,
// depending on the network the host is enrolled in. At the time of enrollment
// clients do not know which curve the network uses, so both keys must be
// generated.
Expand All @@ -79,21 +79,21 @@ func (k P256PrivateKey) MarshalPEM() ([]byte, error) {
// DN API and the private Nebula key is written to disk and parsed by the
// Nebula library. The host private key is not marshalled to PEM here because
// we will need it to sign requests.
type keys struct {
type Keys struct {
// 25519 Curve
nebulaX25519PublicKeyPEM []byte // ECDH (Nebula)
nebulaX25519PrivateKeyPEM []byte // ECDH (Nebula)
hostEd25519PublicKeyPEM []byte // EdDSA (DN API)
hostEd25519PrivateKey ed25519.PrivateKey // EdDSA (DN API)
NebulaX25519PublicKeyPEM []byte // ECDH (Nebula)
NebulaX25519PrivateKeyPEM []byte // ECDH (Nebula)
HostEd25519PublicKeyPEM []byte // EdDSA (DN API)
HostEd25519PrivateKey ed25519.PrivateKey // EdDSA (DN API)

// P256 Curve
nebulaP256PublicKeyPEM []byte // ECDH (Nebula)
nebulaP256PrivateKeyPEM []byte // ECDH (Nebula)
hostP256PublicKeyPEM []byte // ECDSA (DN API)
hostP256PrivateKey *ecdsa.PrivateKey // ECDSA (DN API)
NebulaP256PublicKeyPEM []byte // ECDH (Nebula)
NebulaP256PrivateKeyPEM []byte // ECDH (Nebula)
HostP256PublicKeyPEM []byte // ECDSA (DN API)
HostP256PrivateKey *ecdsa.PrivateKey // ECDSA (DN API)
}

func newKeys() (*keys, error) {
func New() (*Keys, error) {
x25519PublicKeyPEM, x25519PrivateKeyPEM, ed25519PublicKey, ed25519PrivateKey, err := newKeys25519()
if err != nil {
return nil, err
Expand All @@ -114,15 +114,15 @@ func newKeys() (*keys, error) {
return nil, err
}

return &keys{
nebulaX25519PublicKeyPEM: x25519PublicKeyPEM,
nebulaX25519PrivateKeyPEM: x25519PrivateKeyPEM,
hostEd25519PublicKeyPEM: ed25519PublicKeyPEM,
hostEd25519PrivateKey: ed25519PrivateKey,
nebulaP256PublicKeyPEM: ecdhP256PublicKeyPEM,
nebulaP256PrivateKeyPEM: ecdhP256PrivateKeyPEM,
hostP256PublicKeyPEM: ecdsaP256PublicKeyPEM,
hostP256PrivateKey: ecdsaP256PrivateKey,
return &Keys{
NebulaX25519PublicKeyPEM: x25519PublicKeyPEM,
NebulaX25519PrivateKeyPEM: x25519PrivateKeyPEM,
HostEd25519PublicKeyPEM: ed25519PublicKeyPEM,
HostEd25519PrivateKey: ed25519PrivateKey,
NebulaP256PublicKeyPEM: ecdhP256PublicKeyPEM,
NebulaP256PrivateKeyPEM: ecdhP256PrivateKeyPEM,
HostP256PublicKeyPEM: ecdsaP256PublicKeyPEM,
HostP256PrivateKey: ecdsaP256PrivateKey,
}, nil
}

Expand Down Expand Up @@ -218,11 +218,3 @@ func newNebulaP256KeypairPEM() ([]byte, []byte, error) {

return pubkey, privkey, nil
}

func nonce() []byte {
nonce := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
panic(err)
}
return nonce
}
14 changes: 14 additions & 0 deletions keys/crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package keys

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestCrypto(t *testing.T) {
keys, err := New()
require.NoError(t, err)

t.Logf("ed25519 host pubkey: %s", keys.HostEd25519PublicKeyPEM)
}
2 changes: 1 addition & 1 deletion pem.go → keys/pem.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dnapi
package keys

import (
"crypto/ecdsa"
Expand Down
Loading

0 comments on commit ebd0b25

Please sign in to comment.