From 7a4d510111e293756da643223096527816bdaac6 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Tue, 22 Oct 2024 04:17:08 -0500 Subject: [PATCH] Prover: fix compare initial and final values for the rolling hash in the pi-interconnection (#219) Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> --- prover/circuits/internal/utils.go | 11 ++++ prover/circuits/pi-interconnection/assign.go | 28 +++++---- prover/circuits/pi-interconnection/circuit.go | 59 ++++++++----------- .../circuits/pi-interconnection/e2e_test.go | 31 ++-------- prover/public-input/aggregation.go | 34 +++++------ 5 files changed, 75 insertions(+), 88 deletions(-) diff --git a/prover/circuits/internal/utils.go b/prover/circuits/internal/utils.go index 3bd8d8cc4..2360c5a57 100644 --- a/prover/circuits/internal/utils.go +++ b/prover/circuits/internal/utils.go @@ -880,3 +880,14 @@ func InnerProd(api frontend.API, x, y []frontend.Variable) frontend.Variable { } return res } + +func SelectMany(api frontend.API, c frontend.Variable, ifSo, ifNot []frontend.Variable) []frontend.Variable { + if len(ifSo) != len(ifNot) { + panic("incompatible lengths") + } + res := make([]frontend.Variable, len(ifSo)) + for i := range res { + res[i] = api.Select(c, ifSo[i], ifNot[i]) + } + return res +} diff --git a/prover/circuits/pi-interconnection/assign.go b/prover/circuits/pi-interconnection/assign.go index ece51d38c..6af1b4167 100644 --- a/prover/circuits/pi-interconnection/assign.go +++ b/prover/circuits/pi-interconnection/assign.go @@ -179,6 +179,8 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes) + finalRollingHashNum, finalRollingHash := aggregationFPI.InitialRollingHashNumber, aggregationFPI.InitialRollingHash + // Execution FPI executionFPI := execution.FunctionalPublicInput{ FinalStateRootHash: aggregationFPI.InitialStateRootHash, @@ -211,6 +213,11 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { executionFPI.L2MessageHashes = r.Executions[i].L2MsgHashes l2MessageHashes = append(l2MessageHashes, r.Executions[i].L2MsgHashes...) + + if r.Executions[i].FinalRollingHashNumber != 0 { // if the rolling hash is being updated, record the change + finalRollingHash = r.Executions[i].FinalRollingHash + finalRollingHashNum = r.Executions[i].FinalRollingHashNumber + } } a.ExecutionPublicInput[i] = executionFPI.Sum() @@ -232,18 +239,17 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { executionFPI.FinalBlockNumber, aggregationFPI.FinalBlockNumber) return } - if executionFPI.FinalRollingHash != [32]byte{} { - if executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash { - err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x", - executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash) - return - } - if executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber { - err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v", - executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber) - return - } + if finalRollingHash != aggregationFPI.FinalRollingHash { + err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x", + executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash) + return + } + + if finalRollingHashNum != aggregationFPI.FinalRollingHashNumber { + err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v", + executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber) + return } if len(l2MessageHashes) > maxNbL2MessageHashes { diff --git a/prover/circuits/pi-interconnection/circuit.go b/prover/circuits/pi-interconnection/circuit.go index 28ad9007e..cffd57077 100644 --- a/prover/circuits/pi-interconnection/circuit.go +++ b/prover/circuits/pi-interconnection/circuit.go @@ -3,6 +3,7 @@ package pi_interconnection import ( "errors" "fmt" + "math" "math/big" "slices" @@ -112,8 +113,11 @@ func (c *Circuit) Define(api frontend.API) error { shnarfs := ComputeShnarfs(&hshK, c.ParentShnarf, shnarfParams) + rExecution := internal.NewRange(api, nbExecution, maxNbExecution) + initBlockNum, initHashNum, initHash := c.InitialBlockNumber, c.InitialRollingHashNumber, c.InitialRollingHash initBlockTime, initState := c.InitialBlockTimestamp, c.InitialStateRootHash + finalRollingHash, finalRollingHashNum := c.InitialRollingHash, c.InitialRollingHashNumber var l2MessagesByByte [32][]internal.VarSlice execMaxNbL2Msg := len(c.ExecutionFPIQ[0].L2MessageHashes.Values) @@ -130,9 +134,17 @@ func (c *Circuit) Define(api frontend.API) error { for i, piq := range c.ExecutionFPIQ { piq.RangeCheck(api) - comparator.IsLess(initBlockTime, piq.FinalBlockTimestamp) - comparator.IsLess(initBlockNum, piq.FinalBlockNumber) - comparator.IsLess(initHashNum, piq.FinalRollingHashNumber) + inRange := rExecution.InRange[i] + rollingHashNotUpdated := api.Select(inRange, api.IsZero(piq.FinalRollingHashNumber), 1) // padded input past nbExecutions is not required to be 0. So we multiply by inRange + + newFinalRollingHashNum := api.Select(rollingHashNotUpdated, finalRollingHashNum, piq.FinalRollingHashNumber) + + api.AssertIsEqual(comparator.IsLess(initBlockTime, api.Select(inRange, piq.FinalBlockTimestamp, uint64(math.MaxUint64))), 1) // don't compare if not updating + api.AssertIsEqual(comparator.IsLess(initBlockNum, api.Select(inRange, piq.FinalBlockNumber, uint64(math.MaxUint64))), 1) + api.AssertIsEqual(comparator.IsLess(finalRollingHashNum, api.Add(newFinalRollingHashNum, rollingHashNotUpdated)), 1) // if the rolling hash is updated, check that it has increased + + finalRollingHashNum = newFinalRollingHashNum + copy(finalRollingHash[:], internal.SelectMany(api, rollingHashNotUpdated, finalRollingHash[:], piq.FinalRollingHash[:])) pi := execution.FunctionalPublicInputSnark{ FunctionalPublicInputQSnark: piq, @@ -162,7 +174,7 @@ func (c *Circuit) Define(api frontend.API) error { } } - merkleLeavesConcat := internal.Slice[[32]frontend.Variable]{Values: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle*merkleNbLeaves)} + merkleLeavesConcat := internal.Var32Slice{Values: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle*merkleNbLeaves)} for i := 0; i < 32; i++ { ithBytes := internal.Concat(api, len(merkleLeavesConcat.Values), l2MessagesByByte[i]...) for j := range merkleLeavesConcat.Values { @@ -170,43 +182,24 @@ func (c *Circuit) Define(api frontend.API) error { } merkleLeavesConcat.Length = ithBytes.Length // same value regardless of i } - rExecution := internal.NewRange(api, nbExecution, maxNbExecution) - - twoPow8 := big.NewInt(256) - hi16B := func(block [32]frontend.Variable) frontend.Variable { - return compress.ReadNum(api, block[:16], twoPow8) - } - lo16B := func(block [32]frontend.Variable) frontend.Variable { - return compress.ReadNum(api, block[16:], twoPow8) - } - - { // if rolling hash values are present in the last execution, they must match those of aggregation - finalRollingHashFromExec := rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }) - finalRollingHashNumFromExec := rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }) - - h, l := hi16B(finalRollingHashFromExec), lo16B(finalRollingHashFromExec) - - finalRollingHashPresent := api.Sub(1, api.Mul(api.IsZero(h), api.IsZero(l))) - - internal.AssertEqualIf(api, finalRollingHashPresent, h, hi16B(c.FinalRollingHash)) - internal.AssertEqualIf(api, finalRollingHashPresent, l, lo16B(c.FinalRollingHash)) - internal.AssertEqualIf(api, finalRollingHashPresent, finalRollingHashNumFromExec, c.FinalRollingHashNumber) - } pi := public_input.AggregationFPISnark{ - AggregationFPIQSnark: c.AggregationFPIQSnark, - NbL2Messages: merkleLeavesConcat.Length, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), - FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), - FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), - FinalShnarf: rDecompression.LastArray32(shnarfs), - L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, + AggregationFPIQSnark: c.AggregationFPIQSnark, + NbL2Messages: merkleLeavesConcat.Length, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), + FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), + FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), + FinalShnarf: rDecompression.LastArray32(shnarfs), + FinalRollingHashNumber: finalRollingHashNum, + FinalRollingHash: finalRollingHash, + L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, } for i := range pi.L2MsgMerkleTreeRoots { pi.L2MsgMerkleTreeRoots[i] = MerkleRootSnark(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves]) } + twoPow8 := big.NewInt(256) // "open" aggregation public input aggregationPIBytes := pi.Sum(api, &hshK) api.AssertIsEqual(c.AggregationPublicInput[0], compress.ReadNum(api, aggregationPIBytes[:16], twoPow8)) diff --git a/prover/circuits/pi-interconnection/e2e_test.go b/prover/circuits/pi-interconnection/e2e_test.go index 763000c28..7b65bb37c 100644 --- a/prover/circuits/pi-interconnection/e2e_test.go +++ b/prover/circuits/pi-interconnection/e2e_test.go @@ -30,10 +30,10 @@ import ( // some of the execution data are faked func TestSingleBlockBlob(t *testing.T) { - testPI(t, pitesting.AssignSingleBlockBlob(t), withSlack(0, 1, 2)) + testPI(t, pitesting.AssignSingleBlockBlob(t), withSlack(0, 2)) } -func TestSingleBlobBlobE2E(t *testing.T) { +func TestSingleBlockBlobE2E(t *testing.T) { req := pitesting.AssignSingleBlockBlob(t) cfg := config.PublicInput{ MaxNbDecompression: len(req.Decompressions), @@ -123,7 +123,7 @@ func TestTinyTwoBatchBlob(t *testing.T) { }, } - testPI(t, req, withSlack(0, 1, 2)) + testPI(t, req, withSlack(0, 2)) } func TestTwoTwoBatchBlobs(t *testing.T) { @@ -204,29 +204,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) { }, } - testPI(t, req, withSlack(0, 1, 2)) -} - -func TestEmpty(t *testing.T) { - const hexZeroBlock = "0x0000000000000000000000000000000000000000000000000000000000000000" - - testPI(t, pi_interconnection.Request{ - Aggregation: public_input.Aggregation{ - FinalShnarf: hexZeroBlock, - ParentAggregationFinalShnarf: hexZeroBlock, - ParentStateRootHash: hexZeroBlock, - ParentAggregationLastBlockTimestamp: 0, - FinalTimestamp: 0, - LastFinalizedBlockNumber: 0, - FinalBlockNumber: 0, - LastFinalizedL1RollingHash: hexZeroBlock, - L1RollingHash: hexZeroBlock, - LastFinalizedL1RollingHashMessageNumber: 0, - L1RollingHashMessageNumber: 0, - L2MsgRootHashes: []string{}, - L2MsgMerkleTreeDepth: 1, - }, - }) + testPI(t, req, withSlack(0, 2)) } type testPIConfig struct { @@ -269,6 +247,7 @@ func testPI(t *testing.T, req pi_interconnection.Request, options ...testPIOptio ExecutionMaxNbMsg: 1 + slack[2], L2MsgMerkleDepth: 5, L2MsgMaxNbMerkle: 1 + slack[3], + MockKeccakWizard: true, } t.Run(fmt.Sprintf("slack profile %v", slack), func(t *testing.T) { diff --git a/prover/public-input/aggregation.go b/prover/public-input/aggregation.go index 5fc3b61c4..0aa4ee5e2 100644 --- a/prover/public-input/aggregation.go +++ b/prover/public-input/aggregation.go @@ -123,15 +123,15 @@ func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark { InitialRollingHashNumber: pi.InitialRollingHashNumber, InitialStateRootHash: pi.InitialStateRootHash[:], - NbDecompression: pi.NbDecompression, - ChainID: pi.ChainID, - L2MessageServiceAddr: pi.L2MessageServiceAddr[:], - FinalRollingHashNumber: pi.FinalRollingHashNumber, + NbDecompression: pi.NbDecompression, + ChainID: pi.ChainID, + L2MessageServiceAddr: pi.L2MessageServiceAddr[:], }, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), - FinalBlockNumber: pi.FinalBlockNumber, - FinalBlockTimestamp: pi.FinalBlockTimestamp, - L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), + FinalBlockNumber: pi.FinalBlockNumber, + FinalBlockTimestamp: pi.FinalBlockTimestamp, + L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, + FinalRollingHashNumber: pi.FinalRollingHashNumber, } utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:]) @@ -154,12 +154,8 @@ type AggregationFPIQSnark struct { InitialBlockTimestamp frontend.Variable InitialRollingHash [32]frontend.Variable InitialRollingHashNumber frontend.Variable - // Ideally, FinalRollingHash and FinalRollingHashNumber would be inferred from the executions - // but sometimes executions are missing those values - FinalRollingHash [32]frontend.Variable - FinalRollingHashNumber frontend.Variable - ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID - L2MessageServiceAddr frontend.Variable // 20 bytes + ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID + L2MessageServiceAddr frontend.Variable // 20 bytes } type AggregationFPISnark struct { @@ -167,10 +163,12 @@ type AggregationFPISnark struct { NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary L2MsgMerkleTreeRoots [][32]frontend.Variable // FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf - FinalBlockNumber frontend.Variable - FinalBlockTimestamp frontend.Variable - FinalShnarf [32]frontend.Variable - L2MsgMerkleTreeDepth int + FinalBlockNumber frontend.Variable + FinalBlockTimestamp frontend.Variable + FinalShnarf [32]frontend.Variable + FinalRollingHash [32]frontend.Variable + FinalRollingHashNumber frontend.Variable + L2MsgMerkleTreeDepth int } // NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation