From b7b38c4d02d6ecb3779e42c5f78ceaf7fd83ebab Mon Sep 17 00:00:00 2001 From: Jeet <113221510+JEETDESAI25@users.noreply.github.com> Date: Mon, 29 Dec 2025 03:02:57 -0800 Subject: [PATCH] Adding ECMR key exchange protocol. Closes #455 --- README.md | 1 + ecmr/client.go | 137 +++++++++++ ecmr/compat.go | 26 +++ ecmr/doc.go | 45 ++++ ecmr/ecmr.go | 42 ++++ ecmr/ecmr_test.go | 562 ++++++++++++++++++++++++++++++++++++++++++++++ ecmr/keys.go | 84 +++++++ ecmr/server.go | 44 ++++ 8 files changed, 941 insertions(+) create mode 100644 ecmr/client.go create mode 100644 ecmr/compat.go create mode 100644 ecmr/doc.go create mode 100644 ecmr/ecmr.go create mode 100644 ecmr/ecmr_test.go create mode 100644 ecmr/keys.go create mode 100644 ecmr/server.go diff --git a/README.md b/README.md index 897d6c42..2ff12d89 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ Alternatively, look at the [Cloudflare Go](https://github.com/cloudflare/go/tree - [OT](./ot/simot): Simplest Oblivious Transfer ([ia.cr/2015/267]). - [Threshold RSA](./tss/rsa) Signatures ([Shoup Eurocrypt 2000](https://www.iacr.org/archive/eurocrypt2000/1807/18070209-new.pdf)). - [Prio3](./vdaf/prio3) Verifiable Distributed Aggregation Function ([draft-irtf-cfrg-vdaf](https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/)). + - [ECMR](./ecmr): McCallum-Relyea key exchange for Tang/Clevis. ### Post-Quantum Cryptography diff --git a/ecmr/client.go b/ecmr/client.go new file mode 100644 index 00000000..3274bc57 --- /dev/null +++ b/ecmr/client.go @@ -0,0 +1,137 @@ +package ecmr + +import ( + "io" + + "github.com/cloudflare/circl/group" +) + +type Client struct{} + +func NewClient() *Client { + return &Client{} +} + +// Provision generates a new client key pair and computes the shared point +// with the server's public key. +func (c *Client) Provision(serverPub *PublicKey, rnd io.Reader) (*ProvisionResult, error) { + if serverPub == nil || serverPub.element == nil { + return nil, ErrNilKey + } + if rnd == nil { + return nil, ErrNilReader + } + + clientScalar := group.P521.RandomNonZeroScalar(rnd) + + clientPub := group.P521.NewElement().MulGen(clientScalar) + clientPubBytes, err := clientPub.MarshalBinary() + if err != nil { + return nil, ErrMalformedPoint + } + + sharedPoint := group.P521.NewElement().Mul(serverPub.element, clientScalar) + sharedPointBytes, err := sharedPoint.MarshalBinary() + if err != nil { + return nil, ErrMalformedPoint + } + + zeroScalar(clientScalar) + + return &ProvisionResult{ + ClientPublic: clientPubBytes, + SharedPoint: sharedPointBytes, + }, nil +} + +// CreateRecoveryRequest creates a blinded recovery request using the stored +// client public key and a fresh ephemeral scalar. +func (c *Client) CreateRecoveryRequest( + clientPublicBytes []byte, + serverPub *PublicKey, + rnd io.Reader, +) (*RecoveryRequest, *RecoveryState, error) { + if serverPub == nil || serverPub.element == nil { + return nil, nil, ErrNilKey + } + if rnd == nil { + return nil, nil, ErrNilReader + } + + if len(clientPublicBytes) != PublicKeySize { + return nil, nil, ErrMalformedPoint + } + + clientPub := group.P521.NewElement() + if err := clientPub.UnmarshalBinary(clientPublicBytes); err != nil { + return nil, nil, ErrMalformedPoint + } + if clientPub.IsIdentity() { + return nil, nil, ErrIdentityPoint + } + + ephemeral := group.P521.RandomNonZeroScalar(rnd) + + ephemeralPub := group.P521.NewElement().MulGen(ephemeral) + + blindedPoint := group.P521.NewElement().Add(clientPub, ephemeralPub) + blindedPointBytes, err := blindedPoint.MarshalBinary() + if err != nil { + zeroScalar(ephemeral) + return nil, nil, ErrMalformedPoint + } + + state := &RecoveryState{ + ephemeral: ephemeral, + serverPub: group.P521.NewElement().Set(serverPub.element), + } + + return &RecoveryRequest{BlindedPoint: blindedPointBytes}, state, nil +} + +// RecoverKey completes key recovery using the server's response. +// After calling this function, the RecoveryState is invalidated. +func (c *Client) RecoverKey( + state *RecoveryState, + response *RecoveryResponse, +) ([]byte, error) { + if state == nil || state.ephemeral == nil || state.serverPub == nil { + return nil, ErrNilKey + } + + defer func() { + zeroScalar(state.ephemeral) + state.ephemeral = nil + state.serverPub = nil + }() + + if response == nil || len(response.ProcessedPoint) != SharedPointSize { + return nil, ErrMalformedPoint + } + + serverResponse := group.P521.NewElement() + if err := serverResponse.UnmarshalBinary(response.ProcessedPoint); err != nil { + return nil, ErrMalformedPoint + } + if serverResponse.IsIdentity() { + return nil, ErrIdentityPoint + } + + blindingFactor := group.P521.NewElement().Mul(state.serverPub, state.ephemeral) + + negBlindingFactor := group.P521.NewElement().Neg(blindingFactor) + sharedPoint := group.P521.NewElement().Add(serverResponse, negBlindingFactor) + + sharedPointBytes, err := sharedPoint.MarshalBinary() + if err != nil { + return nil, ErrMalformedPoint + } + + return sharedPointBytes, nil +} + +func zeroScalar(s group.Scalar) { + if s != nil { + s.SetUint64(0) + } +} diff --git a/ecmr/compat.go b/ecmr/compat.go new file mode 100644 index 00000000..6a96d766 --- /dev/null +++ b/ecmr/compat.go @@ -0,0 +1,26 @@ +package ecmr + +import ( + "github.com/cloudflare/circl/group" +) + +// ExtractX extracts the x-coordinate from an uncompressed P-521 point. +// It validates the point is on-curve and not the identity before extracting. +func ExtractX(uncompressedPoint []byte) ([]byte, error) { + element := group.P521.NewElement() + if err := element.UnmarshalBinary(uncompressedPoint); err != nil { + return nil, ErrMalformedPoint + } + if element.IsIdentity() { + return nil, ErrIdentityPoint + } + + canonical, err := element.MarshalBinary() + if err != nil { + return nil, ErrMalformedPoint + } + + x := make([]byte, XCoordinateSize) + copy(x, canonical[1:1+XCoordinateSize]) + return x, nil +} diff --git a/ecmr/doc.go b/ecmr/doc.go new file mode 100644 index 00000000..d51ac245 --- /dev/null +++ b/ecmr/doc.go @@ -0,0 +1,45 @@ +// Package ecmr implements the McCallum-Relyea key exchange protocol for P-521. +// +// This protocol is used by Tang/Clevis for network-bound disk encryption (NBDE). +// It allows a client to derive a shared secret with a server's help, without +// the server ever learning the secret. +// +// # Timing Properties +// +// The scalar operations in this package (multiplication, addition) use CIRCL's +// group.P521, which delegates to Go's crypto/ecdh for constant-time scalar +// multiplication. +// +// IMPORTANT: Point serialization and validation are NOT constant-time due to +// limitations in the underlying group package: +// - MarshalBinary: calls big.Int.Mod and ecdsa.PublicKey.ECDH() +// - UnmarshalBinary: uses big.Int for coordinate parsing and curve checks +// +// Both operations may leak timing information about point coordinates. For +// Tang/Clevis deployments where the threat model is network-based key escrow, +// this is typically acceptable. Evaluate whether this meets your requirements. +// +// # Subgroup Membership +// +// P-521 is a prime-order curve (cofactor = 1). Every point validated as on-curve +// is automatically in the prime-order subgroup. No additional cofactor clearing +// or subgroup checks are needed. +// +// # Tang/Clevis Interoperability +// +// For Tang compatibility: +// 1. Call Provision or RecoverKey to get SharedPoint (133 bytes, uncompressed) +// 2. Extract x-coordinate: x, err := ecmr.ExtractX(sharedPoint) +// 3. Apply Concat KDF (RFC 7518 ยง4.6) with x as the shared secret +// +// ExtractX validates the point is on-curve before extracting, preventing +// corrupted stored state from producing invalid keys. Note that this validation +// uses variable-time operations (see Timing Properties above). +// +// # Supported Curves +// +// Only P-521 is supported. The API uses concrete types with no curve parameters. +// All key construction goes through GenerateKey or UnmarshalBinary, which +// exclusively use group.P521. Zero-value structs (e.g., &PublicKey{}) will fail +// at runtime with ErrNilKey. +package ecmr diff --git a/ecmr/ecmr.go b/ecmr/ecmr.go new file mode 100644 index 00000000..a8375125 --- /dev/null +++ b/ecmr/ecmr.go @@ -0,0 +1,42 @@ +package ecmr + +import ( + "errors" + + "github.com/cloudflare/circl/group" +) + +const ( + PublicKeySize = 133 + PrivateKeySize = 66 + SharedPointSize = 133 + UncompressedPointSize = 133 + XCoordinateSize = 66 +) + +var ( + ErrMalformedPoint = errors.New("ecmr: malformed point encoding") + ErrIdentityPoint = errors.New("ecmr: identity point not allowed") + ErrMalformedScalar = errors.New("ecmr: malformed scalar encoding") + ErrZeroScalar = errors.New("ecmr: zero scalar not allowed") + ErrNilReader = errors.New("ecmr: nil random reader") + ErrNilKey = errors.New("ecmr: nil or uninitialized key") +) + +type ProvisionResult struct { + ClientPublic []byte + SharedPoint []byte +} + +type RecoveryRequest struct { + BlindedPoint []byte +} + +type RecoveryResponse struct { + ProcessedPoint []byte +} + +type RecoveryState struct { + ephemeral group.Scalar + serverPub group.Element +} diff --git a/ecmr/ecmr_test.go b/ecmr/ecmr_test.go new file mode 100644 index 00000000..3965f90a --- /dev/null +++ b/ecmr/ecmr_test.go @@ -0,0 +1,562 @@ +package ecmr + +import ( + "bytes" + "crypto/rand" + "errors" + "testing" + + "github.com/cloudflare/circl/group" +) + +func TestProvisionAndRecover(t *testing.T) { + serverKey, err := GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + server, err := NewServer(serverKey) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + client := NewClient() + + provisionResult, err := client.Provision(server.PublicKey(), rand.Reader) + if err != nil { + t.Fatalf("Provision failed: %v", err) + } + + if len(provisionResult.ClientPublic) != PublicKeySize { + t.Errorf("ClientPublic size = %d, want %d", len(provisionResult.ClientPublic), PublicKeySize) + } + if len(provisionResult.SharedPoint) != SharedPointSize { + t.Errorf("SharedPoint size = %d, want %d", len(provisionResult.SharedPoint), SharedPointSize) + } + + request, state, err := client.CreateRecoveryRequest( + provisionResult.ClientPublic, + server.PublicKey(), + rand.Reader, + ) + if err != nil { + t.Fatalf("CreateRecoveryRequest failed: %v", err) + } + + if len(request.BlindedPoint) != PublicKeySize { + t.Errorf("BlindedPoint size = %d, want %d", len(request.BlindedPoint), PublicKeySize) + } + + response, err := server.ProcessRecoveryRequest(request) + if err != nil { + t.Fatalf("ProcessRecoveryRequest failed: %v", err) + } + + if len(response.ProcessedPoint) != SharedPointSize { + t.Errorf("ProcessedPoint size = %d, want %d", len(response.ProcessedPoint), SharedPointSize) + } + + recoveredPoint, err := client.RecoverKey(state, response) + if err != nil { + t.Fatalf("RecoverKey failed: %v", err) + } + + if !bytes.Equal(recoveredPoint, provisionResult.SharedPoint) { + t.Error("Recovered point does not match original shared point") + } +} + +func TestGenerateKey(t *testing.T) { + key, err := GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + if key.scalar == nil { + t.Error("Generated key has nil scalar") + } + + pub := key.Public() + if pub == nil || pub.element == nil { + t.Error("Public key is nil or has nil element") + } + + keyBytes, err := key.MarshalBinary() + if err != nil { + t.Fatalf("PrivateKey.MarshalBinary failed: %v", err) + } + if len(keyBytes) != PrivateKeySize { + t.Errorf("PrivateKey size = %d, want %d", len(keyBytes), PrivateKeySize) + } + + pubBytes, err := pub.MarshalBinary() + if err != nil { + t.Fatalf("PublicKey.MarshalBinary failed: %v", err) + } + if len(pubBytes) != PublicKeySize { + t.Errorf("PublicKey size = %d, want %d", len(pubBytes), PublicKeySize) + } +} + +func TestMarshalUnmarshal(t *testing.T) { + key, err := GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + keyBytes, err := key.MarshalBinary() + if err != nil { + t.Fatalf("PrivateKey.MarshalBinary failed: %v", err) + } + + var key2 PrivateKey + err = key2.UnmarshalBinary(keyBytes) + if err != nil { + t.Fatalf("PrivateKey.UnmarshalBinary failed: %v", err) + } + + keyBytes2, err := key2.MarshalBinary() + if err != nil { + t.Fatalf("PrivateKey.MarshalBinary (2) failed: %v", err) + } + if !bytes.Equal(keyBytes, keyBytes2) { + t.Error("Private key round-trip failed") + } + + pub := key.Public() + pubBytes, err := pub.MarshalBinary() + if err != nil { + t.Fatalf("PublicKey.MarshalBinary failed: %v", err) + } + + var pub2 PublicKey + err = pub2.UnmarshalBinary(pubBytes) + if err != nil { + t.Fatalf("PublicKey.UnmarshalBinary failed: %v", err) + } + + pubBytes2, err := pub2.MarshalBinary() + if err != nil { + t.Fatalf("PublicKey.MarshalBinary (2) failed: %v", err) + } + if !bytes.Equal(pubBytes, pubBytes2) { + t.Error("Public key round-trip failed") + } +} + +func TestExtractX(t *testing.T) { + key, err := GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + pubBytes, err := key.Public().MarshalBinary() + if err != nil { + t.Fatalf("PublicKey.MarshalBinary failed: %v", err) + } + + x, err := ExtractX(pubBytes) + if err != nil { + t.Fatalf("ExtractX failed: %v", err) + } + + if len(x) != XCoordinateSize { + t.Errorf("x-coordinate size = %d, want %d", len(x), XCoordinateSize) + } + + expectedX := pubBytes[1 : 1+XCoordinateSize] + if !bytes.Equal(x, expectedX) { + t.Error("ExtractX returned incorrect x-coordinate") + } +} + +func TestExtractXWrongLength(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"empty", []byte{}}, + {"too short", make([]byte, 100)}, + {"too long", make([]byte, 200)}, + {"one byte", []byte{0x04}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := ExtractX(tc.data) + if !errors.Is(err, ErrMalformedPoint) { + t.Errorf("ExtractX(%s) error = %v, want ErrMalformedPoint", tc.name, err) + } + }) + } +} + +func TestExtractXWrongPrefix(t *testing.T) { + data := make([]byte, UncompressedPointSize) + data[0] = 0x02 + _, err := ExtractX(data) + if !errors.Is(err, ErrMalformedPoint) { + t.Errorf("ExtractX(wrong prefix) error = %v, want ErrMalformedPoint", err) + } +} + +func TestExtractXIdentity(t *testing.T) { + identity := []byte{0x00} + + _, err := ExtractX(identity) + if !errors.Is(err, ErrIdentityPoint) { + t.Errorf("ExtractX(identity) error = %v, want ErrIdentityPoint", err) + } +} + +func TestExtractXOffCurve(t *testing.T) { + data := make([]byte, UncompressedPointSize) + data[0] = 0x04 + for i := 1; i < len(data); i++ { + data[i] = 0xFF + } + + _, err := ExtractX(data) + if !errors.Is(err, ErrMalformedPoint) { + t.Errorf("ExtractX(off-curve) error = %v, want ErrMalformedPoint", err) + } +} + +func TestNilKeyErrors(t *testing.T) { + validKey, _ := GenerateKey(rand.Reader) + validPub := validKey.Public() + validPubBytes, _ := validPub.MarshalBinary() + + t.Run("Provision nil serverPub", func(t *testing.T) { + _, err := NewClient().Provision(nil, rand.Reader) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("Provision zero serverPub", func(t *testing.T) { + _, err := NewClient().Provision(&PublicKey{}, rand.Reader) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("Provision nil reader", func(t *testing.T) { + _, err := NewClient().Provision(validPub, nil) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("CreateRecoveryRequest nil serverPub", func(t *testing.T) { + _, _, err := NewClient().CreateRecoveryRequest(validPubBytes, nil, rand.Reader) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("CreateRecoveryRequest zero serverPub", func(t *testing.T) { + _, _, err := NewClient().CreateRecoveryRequest(validPubBytes, &PublicKey{}, rand.Reader) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("CreateRecoveryRequest nil reader", func(t *testing.T) { + _, _, err := NewClient().CreateRecoveryRequest(validPubBytes, validPub, nil) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("RecoverKey nil state", func(t *testing.T) { + _, err := NewClient().RecoverKey(nil, &RecoveryResponse{ProcessedPoint: validPubBytes}) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("RecoverKey zero state", func(t *testing.T) { + _, err := NewClient().RecoverKey(&RecoveryState{}, &RecoveryResponse{ProcessedPoint: validPubBytes}) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("NewServer nil key", func(t *testing.T) { + _, err := NewServer(nil) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("NewServer zero key", func(t *testing.T) { + _, err := NewServer(&PrivateKey{}) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("ProcessRecoveryRequest nil request", func(t *testing.T) { + server, _ := NewServer(validKey) + _, err := server.ProcessRecoveryRequest(nil) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("GenerateKey nil reader", func(t *testing.T) { + _, err := GenerateKey(nil) + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("PrivateKey.MarshalBinary zero value", func(t *testing.T) { + k := &PrivateKey{} + _, err := k.MarshalBinary() + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) + t.Run("PublicKey.MarshalBinary zero value", func(t *testing.T) { + k := &PublicKey{} + _, err := k.MarshalBinary() + assertError(t, err, ErrNilKey, ErrNilReader, ErrMalformedPoint) + }) +} + +func assertError(t *testing.T, err error, expected ...error) { + t.Helper() + if err == nil { + t.Error("expected error, got nil") + return + } + for _, e := range expected { + if errors.Is(err, e) { + return + } + } + t.Errorf("unexpected error type: %v", err) +} + +func TestZeroScalar(t *testing.T) { + zeroBytes := make([]byte, PrivateKeySize) + + var key PrivateKey + err := key.UnmarshalBinary(zeroBytes) + if !errors.Is(err, ErrZeroScalar) { + t.Errorf("UnmarshalBinary(zero) error = %v, want ErrZeroScalar", err) + } +} + +func TestMalformedBytes(t *testing.T) { + tests := []struct { + name string + fn func() error + }{ + { + name: "PrivateKey.UnmarshalBinary too short", + fn: func() error { + var k PrivateKey + return k.UnmarshalBinary(make([]byte, 10)) + }, + }, + { + name: "PrivateKey.UnmarshalBinary too long", + fn: func() error { + var k PrivateKey + return k.UnmarshalBinary(make([]byte, 100)) + }, + }, + { + name: "PublicKey.UnmarshalBinary too short", + fn: func() error { + var k PublicKey + return k.UnmarshalBinary(make([]byte, 10)) + }, + }, + { + name: "PublicKey.UnmarshalBinary too long", + fn: func() error { + var k PublicKey + return k.UnmarshalBinary(make([]byte, 200)) + }, + }, + { + name: "PublicKey.UnmarshalBinary wrong prefix", + fn: func() error { + var k PublicKey + data := make([]byte, PublicKeySize) + data[0] = 0x02 + return k.UnmarshalBinary(data) + }, + }, + { + name: "CreateRecoveryRequest wrong clientPublic size", + fn: func() error { + key, _ := GenerateKey(rand.Reader) + _, _, err := NewClient().CreateRecoveryRequest( + make([]byte, 10), + key.Public(), + rand.Reader, + ) + return err + }, + }, + { + name: "CreateRecoveryRequest wrong clientPublic prefix", + fn: func() error { + key, _ := GenerateKey(rand.Reader) + data := make([]byte, PublicKeySize) + data[0] = 0x02 + _, _, err := NewClient().CreateRecoveryRequest( + data, + key.Public(), + rand.Reader, + ) + return err + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.fn() + if err == nil { + t.Errorf("%s: expected error, got nil", tc.name) + } + }) + } +} + +func TestMultipleRecoveries(t *testing.T) { + serverKey, _ := GenerateKey(rand.Reader) + server, _ := NewServer(serverKey) + client := NewClient() + + provisionResult, err := client.Provision(server.PublicKey(), rand.Reader) + if err != nil { + t.Fatalf("Provision failed: %v", err) + } + + for i := 0; i < 3; i++ { + request, state, err := client.CreateRecoveryRequest( + provisionResult.ClientPublic, + server.PublicKey(), + rand.Reader, + ) + if err != nil { + t.Fatalf("CreateRecoveryRequest %d failed: %v", i, err) + } + + response, err := server.ProcessRecoveryRequest(request) + if err != nil { + t.Fatalf("ProcessRecoveryRequest %d failed: %v", i, err) + } + + recoveredPoint, err := client.RecoverKey(state, response) + if err != nil { + t.Fatalf("RecoverKey %d failed: %v", i, err) + } + + if !bytes.Equal(recoveredPoint, provisionResult.SharedPoint) { + t.Errorf("Recovery %d: point mismatch", i) + } + } +} + +func TestRecoveryStateInvalidatedAfterUse(t *testing.T) { + serverKey, _ := GenerateKey(rand.Reader) + server, _ := NewServer(serverKey) + client := NewClient() + + provisionResult, _ := client.Provision(server.PublicKey(), rand.Reader) + + request, state, _ := client.CreateRecoveryRequest( + provisionResult.ClientPublic, + server.PublicKey(), + rand.Reader, + ) + response, _ := server.ProcessRecoveryRequest(request) + + _, err := client.RecoverKey(state, response) + if err != nil { + t.Fatalf("First RecoverKey failed: %v", err) + } + + _, err = client.RecoverKey(state, response) + if !errors.Is(err, ErrNilKey) { + t.Errorf("Reusing state: got %v, want ErrNilKey", err) + } +} + +func TestExtractXConsistency(t *testing.T) { + for i := 0; i < 10; i++ { + key, _ := GenerateKey(rand.Reader) + pubBytes, _ := key.Public().MarshalBinary() + + x, err := ExtractX(pubBytes) + if err != nil { + t.Fatalf("ExtractX failed: %v", err) + } + + expectedX := pubBytes[1 : 1+XCoordinateSize] + if !bytes.Equal(x, expectedX) { + t.Errorf("Iteration %d: x-coordinate mismatch", i) + } + } +} + +func BenchmarkProvision(b *testing.B) { + serverKey, _ := GenerateKey(rand.Reader) + server, _ := NewServer(serverKey) + client := NewClient() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := client.Provision(server.PublicKey(), rand.Reader) + if err != nil { + b.Fatalf("Provision failed: %v", err) + } + } +} + +func BenchmarkRecover(b *testing.B) { + serverKey, _ := GenerateKey(rand.Reader) + server, _ := NewServer(serverKey) + client := NewClient() + + provisionResult, _ := client.Provision(server.PublicKey(), rand.Reader) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + request, state, err := client.CreateRecoveryRequest( + provisionResult.ClientPublic, + server.PublicKey(), + rand.Reader, + ) + if err != nil { + b.Fatalf("CreateRecoveryRequest failed: %v", err) + } + + response, err := server.ProcessRecoveryRequest(request) + if err != nil { + b.Fatalf("ProcessRecoveryRequest failed: %v", err) + } + + _, err = client.RecoverKey(state, response) + if err != nil { + b.Fatalf("RecoverKey failed: %v", err) + } + } +} + +func BenchmarkExtractX(b *testing.B) { + key, _ := GenerateKey(rand.Reader) + pubBytes, _ := key.Public().MarshalBinary() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ExtractX(pubBytes) + if err != nil { + b.Fatalf("ExtractX failed: %v", err) + } + } +} + +func BenchmarkGenerateKey(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := GenerateKey(rand.Reader) + if err != nil { + b.Fatalf("GenerateKey failed: %v", err) + } + } +} + +func TestIdentityPointRejection(t *testing.T) { + identity := group.P521.Identity() + identityBytes, _ := identity.MarshalBinary() + + t.Run("PublicKey.UnmarshalBinary rejects identity", func(t *testing.T) { + var k PublicKey + err := k.UnmarshalBinary(identityBytes) + if err == nil { + t.Error("Expected error for identity point") + } + }) + + t.Run("Server.ProcessRecoveryRequest validates input", func(t *testing.T) { + serverKey, _ := GenerateKey(rand.Reader) + server, _ := NewServer(serverKey) + + _, err := server.ProcessRecoveryRequest(&RecoveryRequest{ + BlindedPoint: identityBytes, + }) + if err == nil { + t.Error("Expected error for identity point in request") + } + }) +} diff --git a/ecmr/keys.go b/ecmr/keys.go new file mode 100644 index 00000000..1bf76c11 --- /dev/null +++ b/ecmr/keys.go @@ -0,0 +1,84 @@ +package ecmr + +import ( + "io" + + "github.com/cloudflare/circl/group" +) + +type PrivateKey struct { + scalar group.Scalar + pub *PublicKey +} + +type PublicKey struct { + element group.Element +} + +func GenerateKey(rnd io.Reader) (*PrivateKey, error) { + if rnd == nil { + return nil, ErrNilReader + } + + scalar := group.P521.RandomNonZeroScalar(rnd) + return &PrivateKey{scalar: scalar}, nil +} + +func (k *PrivateKey) Public() *PublicKey { + if k.pub == nil { + element := group.P521.NewElement().MulGen(k.scalar) + k.pub = &PublicKey{element: element} + } + return k.pub +} + +func (k *PrivateKey) MarshalBinary() ([]byte, error) { + if k.scalar == nil { + return nil, ErrNilKey + } + return k.scalar.MarshalBinary() +} + +func (k *PrivateKey) UnmarshalBinary(data []byte) error { + if len(data) != PrivateKeySize { + return ErrMalformedScalar + } + + scalar := group.P521.NewScalar() + if err := scalar.UnmarshalBinary(data); err != nil { + return ErrMalformedScalar + } + + if scalar.IsZero() { + return ErrZeroScalar + } + + k.scalar = scalar + k.pub = nil + return nil +} + +func (k *PublicKey) MarshalBinary() ([]byte, error) { + if k.element == nil { + return nil, ErrNilKey + } + return k.element.MarshalBinary() +} + +func (k *PublicKey) UnmarshalBinary(data []byte) error { + if len(data) != PublicKeySize { + return ErrMalformedPoint + } + + element := group.P521.NewElement() + if err := element.UnmarshalBinary(data); err != nil { + return ErrMalformedPoint + } + + if element.IsIdentity() { + return ErrIdentityPoint + } + + k.element = element + return nil +} diff --git a/ecmr/server.go b/ecmr/server.go new file mode 100644 index 00000000..0b17dbba --- /dev/null +++ b/ecmr/server.go @@ -0,0 +1,44 @@ +package ecmr + +import ( + "github.com/cloudflare/circl/group" +) + +type Server struct { + key *PrivateKey +} + +func NewServer(key *PrivateKey) (*Server, error) { + if key == nil || key.scalar == nil { + return nil, ErrNilKey + } + return &Server{key: key}, nil +} + +func (s *Server) PublicKey() *PublicKey { + return s.key.Public() +} + +// ProcessRecoveryRequest processes a client's recovery request and returns +// the server's response. +func (s *Server) ProcessRecoveryRequest(req *RecoveryRequest) (*RecoveryResponse, error) { + if req == nil || len(req.BlindedPoint) != PublicKeySize { + return nil, ErrMalformedPoint + } + + blindedPoint := group.P521.NewElement() + if err := blindedPoint.UnmarshalBinary(req.BlindedPoint); err != nil { + return nil, ErrMalformedPoint + } + if blindedPoint.IsIdentity() { + return nil, ErrIdentityPoint + } + + response := group.P521.NewElement().Mul(blindedPoint, s.key.scalar) + responseBytes, err := response.MarshalBinary() + if err != nil { + return nil, ErrMalformedPoint + } + + return &RecoveryResponse{ProcessedPoint: responseBytes}, nil +}