Skip to content

Commit

Permalink
Pass []Request through CreateReplies()
Browse files Browse the repository at this point in the history
Replace the sequence of nonces with the sequence of `Request` structures
that encapsulate them.
  • Loading branch information
cjpatton committed Aug 8, 2024
1 parent c68cef3 commit 55f25e9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 34 deletions.
2 changes: 1 addition & 1 deletion cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func handleRequest(requestBytes []byte, cert *protocol.Certificate, onlineSK ed2
}

// Parse the request and create the response.
replies, err := protocol.CreateReplies(responseVer, [][]byte{req.Nonce}, time.Now(), radius, cert)
replies, err := protocol.CreateReplies(responseVer, []protocol.Request{*req}, time.Now(), radius, cert)
if err != nil {
return nil, err
}
Expand Down
16 changes: 10 additions & 6 deletions protocol/internal/cmd/gen_test_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,27 @@ func main() {
testVec.OnlineKey = ONLINE_KEY_HEX

// Set the requests and replies.
nonces := make([][]byte, 0, numRequestsPerBatch)
requests := make([]protocol.Request, 0, numRequestsPerBatch)
for i := 0; i < numRequestsPerBatch; i++ {
nonce, _, request, err := protocol.CreateRequest(clientVersionPref, r, nil, rootPublicKey)
_, _, reqBytes, err := protocol.CreateRequest(clientVersionPref, r, nil, rootPublicKey)
if err != nil {
panic(err)
}
testVec.Requests = append(testVec.Requests, hex.EncodeToString(request))
nonces = append(nonces, nonce[:])
req, err := protocol.ParseRequest(reqBytes)
if err != nil {
panic(err)
}
testVec.Requests = append(testVec.Requests, hex.EncodeToString(reqBytes))
requests = append(requests, *req)
}

replies, err := protocol.CreateReplies(ver, nonces, testMidpoint, testRadius, onlineCert)
replies, err := protocol.CreateReplies(ver, requests, testMidpoint, testRadius, onlineCert)
if err != nil {
panic(err)
}

for i := 0; i < numRequestsPerBatch; i++ {
_, _, err = protocol.VerifyReply(clientVersionPref, replies[i], rootPublicKey, nonces[i])
_, _, err = protocol.VerifyReply(clientVersionPref, replies[i], rootPublicKey, requests[i].Nonce)
if err != nil {
panic(err)
}
Expand Down
24 changes: 12 additions & 12 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,13 @@ func hashNode(out *[maxNonceSize]byte, left, right []byte) {
}

// newTree creates a Merkle tree given one or more nonces.
func newTree(nonceSize int, nonces [][]byte) *tree {
if len(nonces) == 0 {
func newTree(nonceSize int, requests []Request) *tree {
if len(requests) == 0 {
panic("newTree: passed empty slice")
}

levels := 1
width := len(nonces)
width := len(requests)
for width > 1 {
width = (width + 1) / 2
levels++
Expand All @@ -416,15 +416,15 @@ func newTree(nonceSize int, nonces [][]byte) *tree {
values: make([][][maxNonceSize]byte, 0, levels),
}

leaves := make([][maxNonceSize]byte, ((len(nonces)+1)/2)*2)
for i, nonce := range nonces {
leaves := make([][maxNonceSize]byte, ((len(requests)+1)/2)*2)
for i, req := range requests {
var leaf [maxNonceSize]byte
hashLeaf(&leaf, nonce)
hashLeaf(&leaf, req.Nonce)
leaves[i] = leaf
}
// Fill any extra leaves with an existing leaf, to simplify analysis
// that we are not inadvertently signing other messages.
for i := len(nonces); i < len(leaves); i++ {
for i := len(requests); i < len(leaves); i++ {
leaves[i] = leaves[0]
}
ret.values = append(ret.values, leaves)
Expand Down Expand Up @@ -558,15 +558,15 @@ func ParseRequest(bytes []byte) (req *Request, err error) {
//
// The same version is indicated in each reply. It's the callers responsibility
// to ensure that each client supports this version.
func CreateReplies(ver Version, nonces [][]byte, midpoint time.Time, radius time.Duration, cert *Certificate) ([][]byte, error) {
func CreateReplies(ver Version, requests []Request, midpoint time.Time, radius time.Duration, cert *Certificate) ([][]byte, error) {
versionIETF := ver != VersionGoogle
nonceSize := nonceSize(versionIETF)

if len(nonces) == 0 {
if len(requests) == 0 {
return nil, nil
}

tree := newTree(nonceSize, nonces)
tree := newTree(nonceSize, requests)

// Convert the midpoint and radius to their Roughtime representation.
var midPointUint64 uint64
Expand Down Expand Up @@ -610,9 +610,9 @@ func CreateReplies(ver Version, nonces [][]byte, midpoint time.Time, radius time
reply[tagVER] = encoded
}

replies := make([][]byte, 0, len(nonces))
replies := make([][]byte, 0, len(requests))

for i := range nonces {
for i := range requests {
var indexBytes [4]byte
binary.LittleEndian.PutUint32(indexBytes[:], uint32(i))
reply[tagINDX] = indexBytes[:]
Expand Down
40 changes: 25 additions & 15 deletions protocol/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func TestRunTestVectors(t *testing.T) {
panic(err)
}

nonces := make([][]byte, 0)
requests := make([]Request, 0)
advertisedVersions := make(map[Version]uint)
for _, ver := range allVersions {
advertisedVersions[ver] = 0
Expand All @@ -177,7 +177,7 @@ func TestRunTestVectors(t *testing.T) {
advertisedVersions[ver] += 1
}

nonces = append(nonces, req.Nonce)
requests = append(requests, *req)
}

supportedVersions := make([]Version, 0, len(allVersions))
Expand All @@ -191,7 +191,7 @@ func TestRunTestVectors(t *testing.T) {
t.Fatal(err)
}

replies, err := CreateReplies(responseVer, nonces, testMidpoint, testRadius, onlineCert)
replies, err := CreateReplies(responseVer, requests, testMidpoint, testRadius, onlineCert)
if err != nil {
t.Fatal(err)
}
Expand All @@ -209,7 +209,7 @@ func TestRunTestVectors(t *testing.T) {
}

// Make sure the responses verify properly.
_, _, err = VerifyReply([]Version{responseVer}, replies[i], rootPublicKey, nonces[i])
_, _, err = VerifyReply([]Version{responseVer}, replies[i], rootPublicKey, requests[i].Nonce)
if err != nil {
t.Error(err)
}
Expand All @@ -229,7 +229,7 @@ func TestRoundtrip(t *testing.T) {
advertisedVersions[ver] = 0
}

nonces := make([][]byte, 0, numRequests)
requests := make([]Request, 0, numRequests)
for i := 0; i < numRequests; i++ {
nonceSent, _, request, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKey)
if err != nil {
Expand All @@ -249,7 +249,7 @@ func TestRoundtrip(t *testing.T) {
advertisedVersions[ver] += 1
}

nonces = append(nonces, req.Nonce)
requests = append(requests, *req)
}

supportedVersions := make([]Version, 0, len(allVersions))
Expand All @@ -263,17 +263,17 @@ func TestRoundtrip(t *testing.T) {
t.Fatal(err)
}

replies, err := CreateReplies(responseVer, nonces, testMidpoint, testRadius, cert)
replies, err := CreateReplies(responseVer, requests, testMidpoint, testRadius, cert)
if err != nil {
t.Fatal(err)
}

if len(replies) != len(nonces) {
t.Fatalf("received %d replies for %d nonces", len(replies), len(nonces))
if len(replies) != len(requests) {
t.Fatalf("received %d replies for %d nonces", len(replies), len(requests))
}

for i, reply := range replies {
midpoint, radius, err := VerifyReply([]Version{responseVer}, reply, rootPublicKey, nonces[i])
midpoint, radius, err := VerifyReply([]Version{responseVer}, reply, rootPublicKey, requests[i].Nonce)
if err != nil {
t.Errorf("error parsing reply #%d: %s", i, err)
continue
Expand Down Expand Up @@ -301,22 +301,32 @@ func TestChaining(t *testing.T) {

for _, ver := range allVersions {
t.Run(ver.String(), func(t *testing.T) {
nonce1, _, _, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKeyA)
_, _, req1Bytes, err := CreateRequest([]Version{ver}, rand.Reader, nil, rootPublicKeyA)
if err != nil {
t.Fatal(err)
}

replies1, err := CreateReplies(ver, [][]byte{nonce1[:]}, testMidpoint, testRadius, certA)
req1, err := ParseRequest(req1Bytes)
if err != nil {
t.Fatal(err)
}

nonce2, blind2, _, err := CreateRequest([]Version{ver}, rand.Reader, replies1[0], rootPublicKeyB)
replies1, err := CreateReplies(ver, []Request{*req1}, testMidpoint, testRadius, certA)
if err != nil {
t.Fatal(err)
}

replies2, err := CreateReplies(ver, [][]byte{nonce2[:]}, testMidpoint.Add(time.Duration(-10)*time.Second), testRadius, certB)
_, blind2, req2Bytes, err := CreateRequest([]Version{ver}, rand.Reader, replies1[0], rootPublicKeyB)
if err != nil {
t.Fatal(err)
}

req2, err := ParseRequest(req2Bytes)
if err != nil {
t.Fatal(err)
}

replies2, err := CreateReplies(ver, []Request{*req2}, testMidpoint.Add(time.Duration(-10)*time.Second), testRadius, certB)
if err != nil {
t.Fatal(err)
}
Expand All @@ -333,7 +343,7 @@ func TestChaining(t *testing.T) {
}

claim := []claimStep{
{rootPublicKeyA, nonce1, replies1[0]},
{rootPublicKeyA, req1.Nonce, replies1[0]},
{rootPublicKeyB, blind2, replies2[0]},
}

Expand Down

0 comments on commit 55f25e9

Please sign in to comment.