From 55f25e92e1f3a6dc783435c6561fed01e92ea502 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Wed, 7 Aug 2024 14:05:21 -0700 Subject: [PATCH] Pass `[]Request` through `CreateReplies()` Replace the sequence of nonces with the sequence of `Request` structures that encapsulate them. --- cmd/testserver/main.go | 2 +- protocol/internal/cmd/gen_test_vectors.go | 16 +++++---- protocol/protocol.go | 24 +++++++------- protocol/protocol_test.go | 40 ++++++++++++++--------- 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/cmd/testserver/main.go b/cmd/testserver/main.go index 2a8243f..7205cb6 100644 --- a/cmd/testserver/main.go +++ b/cmd/testserver/main.go @@ -115,7 +115,7 @@ func handleRequest(requestBytes []byte, cert *protocol.Certificate, onlineSK ed2 } // Parse the request and create the response. - replies, err := protocol.CreateReplies(responseVer, [][]byte{req.Nonce}, time.Now(), radius, cert) + replies, err := protocol.CreateReplies(responseVer, []protocol.Request{*req}, time.Now(), radius, cert) if err != nil { return nil, err } diff --git a/protocol/internal/cmd/gen_test_vectors.go b/protocol/internal/cmd/gen_test_vectors.go index 39e201b..6267aa9 100644 --- a/protocol/internal/cmd/gen_test_vectors.go +++ b/protocol/internal/cmd/gen_test_vectors.go @@ -112,23 +112,27 @@ func main() { testVec.OnlineKey = ONLINE_KEY_HEX // Set the requests and replies. - nonces := make([][]byte, 0, numRequestsPerBatch) + requests := make([]protocol.Request, 0, numRequestsPerBatch) for i := 0; i < numRequestsPerBatch; i++ { - nonce, _, request, err := protocol.CreateRequest(clientVersionPref, r, nil, rootPublicKey) + _, _, reqBytes, err := protocol.CreateRequest(clientVersionPref, r, nil, rootPublicKey) if err != nil { panic(err) } - testVec.Requests = append(testVec.Requests, hex.EncodeToString(request)) - nonces = append(nonces, nonce[:]) + req, err := protocol.ParseRequest(reqBytes) + if err != nil { + panic(err) + } + testVec.Requests = append(testVec.Requests, hex.EncodeToString(reqBytes)) + requests = append(requests, *req) } - replies, err := protocol.CreateReplies(ver, nonces, testMidpoint, testRadius, onlineCert) + replies, err := protocol.CreateReplies(ver, requests, testMidpoint, testRadius, onlineCert) if err != nil { panic(err) } for i := 0; i < numRequestsPerBatch; i++ { - _, _, err = protocol.VerifyReply(clientVersionPref, replies[i], rootPublicKey, nonces[i]) + _, _, err = protocol.VerifyReply(clientVersionPref, replies[i], rootPublicKey, requests[i].Nonce) if err != nil { panic(err) } diff --git a/protocol/protocol.go b/protocol/protocol.go index cd4163c..c0da510 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -400,13 +400,13 @@ func hashNode(out *[maxNonceSize]byte, left, right []byte) { } // newTree creates a Merkle tree given one or more nonces. -func newTree(nonceSize int, nonces [][]byte) *tree { - if len(nonces) == 0 { +func newTree(nonceSize int, requests []Request) *tree { + if len(requests) == 0 { panic("newTree: passed empty slice") } levels := 1 - width := len(nonces) + width := len(requests) for width > 1 { width = (width + 1) / 2 levels++ @@ -416,15 +416,15 @@ func newTree(nonceSize int, nonces [][]byte) *tree { values: make([][][maxNonceSize]byte, 0, levels), } - leaves := make([][maxNonceSize]byte, ((len(nonces)+1)/2)*2) - for i, nonce := range nonces { + leaves := make([][maxNonceSize]byte, ((len(requests)+1)/2)*2) + for i, req := range requests { var leaf [maxNonceSize]byte - hashLeaf(&leaf, nonce) + hashLeaf(&leaf, req.Nonce) leaves[i] = leaf } // Fill any extra leaves with an existing leaf, to simplify analysis // that we are not inadvertently signing other messages. - for i := len(nonces); i < len(leaves); i++ { + for i := len(requests); i < len(leaves); i++ { leaves[i] = leaves[0] } ret.values = append(ret.values, leaves) @@ -558,15 +558,15 @@ func ParseRequest(bytes []byte) (req *Request, err error) { // // The same version is indicated in each reply. It's the callers responsibility // to ensure that each client supports this version. -func CreateReplies(ver Version, nonces [][]byte, midpoint time.Time, radius time.Duration, cert *Certificate) ([][]byte, error) { +func CreateReplies(ver Version, requests []Request, midpoint time.Time, radius time.Duration, cert *Certificate) ([][]byte, error) { versionIETF := ver != VersionGoogle nonceSize := nonceSize(versionIETF) - if len(nonces) == 0 { + if len(requests) == 0 { return nil, nil } - tree := newTree(nonceSize, nonces) + tree := newTree(nonceSize, requests) // Convert the midpoint and radius to their Roughtime representation. var midPointUint64 uint64 @@ -610,9 +610,9 @@ func CreateReplies(ver Version, nonces [][]byte, midpoint time.Time, radius time reply[tagVER] = encoded } - replies := make([][]byte, 0, len(nonces)) + replies := make([][]byte, 0, len(requests)) - for i := range nonces { + for i := range requests { var indexBytes [4]byte binary.LittleEndian.PutUint32(indexBytes[:], uint32(i)) reply[tagINDX] = indexBytes[:] diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index fbc5de4..e98cfdc 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -157,7 +157,7 @@ func TestRunTestVectors(t *testing.T) { panic(err) } - nonces := make([][]byte, 0) + requests := make([]Request, 0) advertisedVersions := make(map[Version]uint) for _, ver := range allVersions { advertisedVersions[ver] = 0 @@ -177,7 +177,7 @@ func TestRunTestVectors(t *testing.T) { advertisedVersions[ver] += 1 } - nonces = append(nonces, req.Nonce) + requests = append(requests, *req) } supportedVersions := make([]Version, 0, len(allVersions)) @@ -191,7 +191,7 @@ func TestRunTestVectors(t *testing.T) { t.Fatal(err) } - replies, err := CreateReplies(responseVer, nonces, testMidpoint, testRadius, onlineCert) + replies, err := CreateReplies(responseVer, requests, testMidpoint, testRadius, onlineCert) if err != nil { t.Fatal(err) } @@ -209,7 +209,7 @@ func TestRunTestVectors(t *testing.T) { } // Make sure the responses verify properly. - _, _, err = VerifyReply([]Version{responseVer}, replies[i], rootPublicKey, nonces[i]) + _, _, err = VerifyReply([]Version{responseVer}, replies[i], rootPublicKey, requests[i].Nonce) if err != nil { t.Error(err) } @@ -229,7 +229,7 @@ func TestRoundtrip(t *testing.T) { advertisedVersions[ver] = 0 } - nonces := make([][]byte, 0, numRequests) + requests := make([]Request, 0, numRequests) for i := 0; i < numRequests; i++ { nonceSent, _, request, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKey) if err != nil { @@ -249,7 +249,7 @@ func TestRoundtrip(t *testing.T) { advertisedVersions[ver] += 1 } - nonces = append(nonces, req.Nonce) + requests = append(requests, *req) } supportedVersions := make([]Version, 0, len(allVersions)) @@ -263,17 +263,17 @@ func TestRoundtrip(t *testing.T) { t.Fatal(err) } - replies, err := CreateReplies(responseVer, nonces, testMidpoint, testRadius, cert) + replies, err := CreateReplies(responseVer, requests, testMidpoint, testRadius, cert) if err != nil { t.Fatal(err) } - if len(replies) != len(nonces) { - t.Fatalf("received %d replies for %d nonces", len(replies), len(nonces)) + if len(replies) != len(requests) { + t.Fatalf("received %d replies for %d nonces", len(replies), len(requests)) } for i, reply := range replies { - midpoint, radius, err := VerifyReply([]Version{responseVer}, reply, rootPublicKey, nonces[i]) + midpoint, radius, err := VerifyReply([]Version{responseVer}, reply, rootPublicKey, requests[i].Nonce) if err != nil { t.Errorf("error parsing reply #%d: %s", i, err) continue @@ -301,22 +301,32 @@ func TestChaining(t *testing.T) { for _, ver := range allVersions { t.Run(ver.String(), func(t *testing.T) { - nonce1, _, _, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKeyA) + _, _, req1Bytes, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKeyA) if err != nil { t.Fatal(err) } - replies1, err := CreateReplies(ver, [][]byte{nonce1[:]}, testMidpoint, testRadius, certA) + req1, err := ParseRequest(req1Bytes) if err != nil { t.Fatal(err) } - nonce2, blind2, _, err := CreateRequest([]Version{ver}, rand.Reader, replies1[0], rootPublicKeyB) + replies1, err := CreateReplies(ver, []Request{*req1}, testMidpoint, testRadius, certA) if err != nil { t.Fatal(err) } - replies2, err := CreateReplies(ver, [][]byte{nonce2[:]}, testMidpoint.Add(time.Duration(-10)*time.Second), testRadius, certB) + _, blind2, req2Bytes, err := CreateRequest([]Version{ver}, rand.Reader, replies1[0], rootPublicKeyB) + if err != nil { + t.Fatal(err) + } + + req2, err := ParseRequest(req2Bytes) + if err != nil { + t.Fatal(err) + } + + replies2, err := CreateReplies(ver, []Request{*req2}, testMidpoint.Add(time.Duration(-10)*time.Second), testRadius, certB) if err != nil { t.Fatal(err) } @@ -333,7 +343,7 @@ func TestChaining(t *testing.T) { } claim := []claimStep{ - {rootPublicKeyA, nonce1, replies1[0]}, + {rootPublicKeyA, req1.Nonce, replies1[0]}, {rootPublicKeyB, blind2, replies2[0]}, }