Skip to content

Commit

Permalink
Prover: fix compare initial and final values for the rolling hash in …
Browse files Browse the repository at this point in the history
…the pi-interconnection (#219)

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
  • Loading branch information
Tabaie and Tabaie authored Oct 22, 2024
1 parent ab8a897 commit 7a4d510
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 88 deletions.
11 changes: 11 additions & 0 deletions prover/circuits/internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,14 @@ func InnerProd(api frontend.API, x, y []frontend.Variable) frontend.Variable {
}
return res
}

func SelectMany(api frontend.API, c frontend.Variable, ifSo, ifNot []frontend.Variable) []frontend.Variable {
if len(ifSo) != len(ifNot) {
panic("incompatible lengths")
}
res := make([]frontend.Variable, len(ifSo))
for i := range res {
res[i] = api.Select(c, ifSo[i], ifNot[i])
}
return res
}
28 changes: 17 additions & 11 deletions prover/circuits/pi-interconnection/assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves
l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes)

finalRollingHashNum, finalRollingHash := aggregationFPI.InitialRollingHashNumber, aggregationFPI.InitialRollingHash

// Execution FPI
executionFPI := execution.FunctionalPublicInput{
FinalStateRootHash: aggregationFPI.InitialStateRootHash,
Expand Down Expand Up @@ -211,6 +213,11 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
executionFPI.L2MessageHashes = r.Executions[i].L2MsgHashes

l2MessageHashes = append(l2MessageHashes, r.Executions[i].L2MsgHashes...)

if r.Executions[i].FinalRollingHashNumber != 0 { // if the rolling hash is being updated, record the change
finalRollingHash = r.Executions[i].FinalRollingHash
finalRollingHashNum = r.Executions[i].FinalRollingHashNumber
}
}

a.ExecutionPublicInput[i] = executionFPI.Sum()
Expand All @@ -232,18 +239,17 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
executionFPI.FinalBlockNumber, aggregationFPI.FinalBlockNumber)
return
}
if executionFPI.FinalRollingHash != [32]byte{} {
if executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash {
err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x",
executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash)
return
}

if executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber {
err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v",
executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber)
return
}
if finalRollingHash != aggregationFPI.FinalRollingHash {
err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x",
executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash)
return
}

if finalRollingHashNum != aggregationFPI.FinalRollingHashNumber {
err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v",
executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber)
return
}

if len(l2MessageHashes) > maxNbL2MessageHashes {
Expand Down
59 changes: 26 additions & 33 deletions prover/circuits/pi-interconnection/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pi_interconnection
import (
"errors"
"fmt"
"math"
"math/big"
"slices"

Expand Down Expand Up @@ -112,8 +113,11 @@ func (c *Circuit) Define(api frontend.API) error {

shnarfs := ComputeShnarfs(&hshK, c.ParentShnarf, shnarfParams)

rExecution := internal.NewRange(api, nbExecution, maxNbExecution)

initBlockNum, initHashNum, initHash := c.InitialBlockNumber, c.InitialRollingHashNumber, c.InitialRollingHash
initBlockTime, initState := c.InitialBlockTimestamp, c.InitialStateRootHash
finalRollingHash, finalRollingHashNum := c.InitialRollingHash, c.InitialRollingHashNumber
var l2MessagesByByte [32][]internal.VarSlice

execMaxNbL2Msg := len(c.ExecutionFPIQ[0].L2MessageHashes.Values)
Expand All @@ -130,9 +134,17 @@ func (c *Circuit) Define(api frontend.API) error {
for i, piq := range c.ExecutionFPIQ {
piq.RangeCheck(api)

comparator.IsLess(initBlockTime, piq.FinalBlockTimestamp)
comparator.IsLess(initBlockNum, piq.FinalBlockNumber)
comparator.IsLess(initHashNum, piq.FinalRollingHashNumber)
inRange := rExecution.InRange[i]
rollingHashNotUpdated := api.Select(inRange, api.IsZero(piq.FinalRollingHashNumber), 1) // padded input past nbExecutions is not required to be 0. So we multiply by inRange

newFinalRollingHashNum := api.Select(rollingHashNotUpdated, finalRollingHashNum, piq.FinalRollingHashNumber)

api.AssertIsEqual(comparator.IsLess(initBlockTime, api.Select(inRange, piq.FinalBlockTimestamp, uint64(math.MaxUint64))), 1) // don't compare if not updating
api.AssertIsEqual(comparator.IsLess(initBlockNum, api.Select(inRange, piq.FinalBlockNumber, uint64(math.MaxUint64))), 1)
api.AssertIsEqual(comparator.IsLess(finalRollingHashNum, api.Add(newFinalRollingHashNum, rollingHashNotUpdated)), 1) // if the rolling hash is updated, check that it has increased

finalRollingHashNum = newFinalRollingHashNum
copy(finalRollingHash[:], internal.SelectMany(api, rollingHashNotUpdated, finalRollingHash[:], piq.FinalRollingHash[:]))

pi := execution.FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: piq,
Expand Down Expand Up @@ -162,51 +174,32 @@ func (c *Circuit) Define(api frontend.API) error {
}
}

merkleLeavesConcat := internal.Slice[[32]frontend.Variable]{Values: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle*merkleNbLeaves)}
merkleLeavesConcat := internal.Var32Slice{Values: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle*merkleNbLeaves)}
for i := 0; i < 32; i++ {
ithBytes := internal.Concat(api, len(merkleLeavesConcat.Values), l2MessagesByByte[i]...)
for j := range merkleLeavesConcat.Values {
merkleLeavesConcat.Values[j][i] = ithBytes.Values[j]
}
merkleLeavesConcat.Length = ithBytes.Length // same value regardless of i
}
rExecution := internal.NewRange(api, nbExecution, maxNbExecution)

twoPow8 := big.NewInt(256)
hi16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[:16], twoPow8)
}
lo16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[16:], twoPow8)
}

{ // if rolling hash values are present in the last execution, they must match those of aggregation
finalRollingHashFromExec := rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash })
finalRollingHashNumFromExec := rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber })

h, l := hi16B(finalRollingHashFromExec), lo16B(finalRollingHashFromExec)

finalRollingHashPresent := api.Sub(1, api.Mul(api.IsZero(h), api.IsZero(l)))

internal.AssertEqualIf(api, finalRollingHashPresent, h, hi16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, l, lo16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, finalRollingHashNumFromExec, c.FinalRollingHashNumber)
}

pi := public_input.AggregationFPISnark{
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
FinalRollingHashNumber: finalRollingHashNum,
FinalRollingHash: finalRollingHash,
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
}

for i := range pi.L2MsgMerkleTreeRoots {
pi.L2MsgMerkleTreeRoots[i] = MerkleRootSnark(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves])
}

twoPow8 := big.NewInt(256)
// "open" aggregation public input
aggregationPIBytes := pi.Sum(api, &hshK)
api.AssertIsEqual(c.AggregationPublicInput[0], compress.ReadNum(api, aggregationPIBytes[:16], twoPow8))
Expand Down
31 changes: 5 additions & 26 deletions prover/circuits/pi-interconnection/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (

// some of the execution data are faked
func TestSingleBlockBlob(t *testing.T) {
testPI(t, pitesting.AssignSingleBlockBlob(t), withSlack(0, 1, 2))
testPI(t, pitesting.AssignSingleBlockBlob(t), withSlack(0, 2))
}

func TestSingleBlobBlobE2E(t *testing.T) {
func TestSingleBlockBlobE2E(t *testing.T) {
req := pitesting.AssignSingleBlockBlob(t)
cfg := config.PublicInput{
MaxNbDecompression: len(req.Decompressions),
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestTinyTwoBatchBlob(t *testing.T) {
},
}

testPI(t, req, withSlack(0, 1, 2))
testPI(t, req, withSlack(0, 2))
}

func TestTwoTwoBatchBlobs(t *testing.T) {
Expand Down Expand Up @@ -204,29 +204,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
},
}

testPI(t, req, withSlack(0, 1, 2))
}

func TestEmpty(t *testing.T) {
const hexZeroBlock = "0x0000000000000000000000000000000000000000000000000000000000000000"

testPI(t, pi_interconnection.Request{
Aggregation: public_input.Aggregation{
FinalShnarf: hexZeroBlock,
ParentAggregationFinalShnarf: hexZeroBlock,
ParentStateRootHash: hexZeroBlock,
ParentAggregationLastBlockTimestamp: 0,
FinalTimestamp: 0,
LastFinalizedBlockNumber: 0,
FinalBlockNumber: 0,
LastFinalizedL1RollingHash: hexZeroBlock,
L1RollingHash: hexZeroBlock,
LastFinalizedL1RollingHashMessageNumber: 0,
L1RollingHashMessageNumber: 0,
L2MsgRootHashes: []string{},
L2MsgMerkleTreeDepth: 1,
},
})
testPI(t, req, withSlack(0, 2))
}

type testPIConfig struct {
Expand Down Expand Up @@ -269,6 +247,7 @@ func testPI(t *testing.T, req pi_interconnection.Request, options ...testPIOptio
ExecutionMaxNbMsg: 1 + slack[2],
L2MsgMerkleDepth: 5,
L2MsgMaxNbMerkle: 1 + slack[3],
MockKeccakWizard: true,
}

t.Run(fmt.Sprintf("slack profile %v", slack), func(t *testing.T) {
Expand Down
34 changes: 16 additions & 18 deletions prover/public-input/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark {
InitialRollingHashNumber: pi.InitialRollingHashNumber,
InitialStateRootHash: pi.InitialStateRootHash[:],

NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
FinalRollingHashNumber: pi.FinalRollingHashNumber,
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
},
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
FinalRollingHashNumber: pi.FinalRollingHashNumber,
}

utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:])
Expand All @@ -154,23 +154,21 @@ type AggregationFPIQSnark struct {
InitialBlockTimestamp frontend.Variable
InitialRollingHash [32]frontend.Variable
InitialRollingHashNumber frontend.Variable
// Ideally, FinalRollingHash and FinalRollingHashNumber would be inferred from the executions
// but sometimes executions are missing those values
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
}

type AggregationFPISnark struct {
AggregationFPIQSnark
NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]frontend.Variable
// FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalShnarf [32]frontend.Variable
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
L2MsgMerkleTreeDepth int
}

// NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation
Expand Down

0 comments on commit 7a4d510

Please sign in to comment.