Skip to content

Commit

Permalink
Prover/full recursion (#350)
Browse files Browse the repository at this point in the history
* (feat): Implements the full-recursion and test for a simple test
* msg(sticker): better error in the sticker
* test(full-recursion): adds a test for double full-recursion (overflowing memory)
* fix: sort out the packages after rebasing
* fix(pi): renaming of the public inputs
* fix(hasher): adjust the code to using a [hash.StateStorer]
* fix(pairing): pass the new format for fp12 elements
* doc(plonk): adds more doc in plonk.alignment.go
* doc(fs-hook): improves the documentation of the FiatShamirHook field.
* docs(skipping): adds doc on the ByRoundRegister
* feat(pubinp): move the zkevm public inputs to using the new public-input framework
* doc(column-store): adds documentation for the more precise methods regarding the inclusion in the FS transcript.
* clean(self-recursion): remove the self-recursion tuning file
* doc(vortex): explain the separation between the verifier steps
* doc(full-recursion): documents the prover and verifier actions
* doc(columns): improve the documentation on the IncludeInProverFS
  • Loading branch information
AlexandreBelling authored Jan 6, 2025
1 parent 9dc4304 commit 3b875fd
Show file tree
Hide file tree
Showing 44 changed files with 1,559 additions and 326 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ type circuit struct {

func (c *circuit) Define(api frontend.API) error {
hsh := gkrmimc.NewHasherFactory(api).NewHasher()
return v1.CheckBatchesSums(api, &hsh, c.NbBatches, c.BlobPayload[:], c.BatchEnds[:], c.ExpectedSums[:])
return v1.CheckBatchesSums(api, hsh, c.NbBatches, c.BlobPayload[:], c.BatchEnds[:], c.ExpectedSums[:])
}
8 changes: 4 additions & 4 deletions prover/circuits/blobdecompression/v1/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"bytes"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"hash"
"math/big"

"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"

"github.com/consensys/gnark-crypto/ecc"
fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand Down Expand Up @@ -192,8 +193,7 @@ func (i *FunctionalPublicInputSnark) Sum(api frontend.API, hsh snarkHash.FieldHa
func (c Circuit) Define(api frontend.API) error {
var hsh snarkHash.FieldHasher
if c.UseGkrMiMC {
h := gkrmimc.NewHasherFactory(api).NewHasher()
hsh = &h
hsh = gkrmimc.NewHasherFactory(api).NewHasher()
} else {
if h, err := mimc.NewMiMC(api); err != nil {
return err
Expand Down
8 changes: 0 additions & 8 deletions prover/circuits/execution/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/consensys/linea-monorepo/prover/circuits"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/zkevm"
"github.com/consensys/linea-monorepo/prover/zkevm/prover/publicInput"
"github.com/sirupsen/logrus"

"github.com/consensys/gnark/std/hash/mimc"
Expand All @@ -23,11 +22,6 @@ import (
type CircuitExecution struct {
// The wizard verifier circuit
WizardVerifier wizard.WizardVerifierCircuit `gnark:",secret"`
// The extractor is not part of the circuit per se, but hold informations
// that is used to extract the public inputs from the the WizardVerifier.
// The extractor only needs to be provided during the definition of the
// circuit and is omitted during the assignment of the circuit.
extractor publicInput.FunctionalInputExtractor `gnark:"-"`
// The functional public inputs are the "actual" statement made by the
// circuit. They are not part of the public input of the circuit for
// a number of reasons involving efficiency and simplicity in the aggregation
Expand All @@ -45,7 +39,6 @@ func Allocate(zkevm *zkevm.ZkEvm) CircuitExecution {
}
return CircuitExecution{
WizardVerifier: *wverifier,
extractor: zkevm.PublicInput.Extractor,
FuncInputs: FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{
L2MessageHashes: L2MessageHashes{
Expand Down Expand Up @@ -90,7 +83,6 @@ func (c *CircuitExecution) Define(api frontend.API) error {
api,
&c.WizardVerifier,
c.FuncInputs,
c.extractor,
)

// Add missing public input check
Expand Down
42 changes: 20 additions & 22 deletions prover/circuits/execution/pi_wizard_extraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ func checkPublicInputs(
api frontend.API,
wvc *wizard.WizardVerifierCircuit,
gnarkFuncInp FunctionalPublicInputSnark,
wizardFuncInp publicInput.FunctionalInputExtractor,
) {

var (
lastRollingHash = internal.CombineBytesIntoElements(api, gnarkFuncInp.FinalRollingHashUpdate)
firstRollingHash = internal.CombineBytesIntoElements(api, gnarkFuncInp.InitialRollingHashUpdate)
execDataHash = execDataHash(api, wvc, wizardFuncInp)
execDataHash = execDataHash(api, wvc)
)

// As we have this issue, the execDataHash will not match what we have in the
Expand All @@ -32,70 +31,70 @@ func checkPublicInputs(
shouldBeEqual(api, execDataHash, gnarkFuncInp.DataChecksum)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.L2MessageHash.ID).Y,
wvc.GetPublicInput(api, publicInput.L2MessageHash),
// TODO: this operation is done a second time when computing the final
// public input which is wasteful although not dramatic (~8000 unused
// constraints)
gnarkFuncInp.L2MessageHashes.CheckSumMiMC(api),
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.InitialStateRootHash.ID).Y,
wvc.GetPublicInput(api, publicInput.InitialStateRootHash),
gnarkFuncInp.InitialStateRootHash,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.InitialBlockNumber.ID).Y,
wvc.GetPublicInput(api, publicInput.InitialBlockNumber),
gnarkFuncInp.InitialBlockNumber,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.InitialBlockTimestamp.ID).Y,
wvc.GetPublicInput(api, publicInput.InitialBlockTimestamp),
gnarkFuncInp.InitialBlockTimestamp,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdate[0].ID).Y,
wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdate_0),
firstRollingHash[0],
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdate[1].ID).Y,
wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdate_1),
firstRollingHash[1],
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdateNumber.ID).Y,
wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdateNumber),
gnarkFuncInp.FirstRollingHashUpdateNumber,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FinalStateRootHash.ID).Y,
wvc.GetPublicInput(api, publicInput.FinalStateRootHash),
gnarkFuncInp.FinalStateRootHash,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FinalBlockNumber.ID).Y,
wvc.GetPublicInput(api, publicInput.FinalBlockNumber),
gnarkFuncInp.FinalBlockNumber,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.FinalBlockTimestamp.ID).Y,
wvc.GetPublicInput(api, publicInput.FinalBlockTimestamp),
gnarkFuncInp.FinalBlockTimestamp,
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdate[0].ID).Y,
wvc.GetPublicInput(api, publicInput.LastRollingHashUpdate_0),
lastRollingHash[0],
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdate[1].ID).Y,
wvc.GetPublicInput(api, publicInput.LastRollingHashUpdate_1),
lastRollingHash[1],
)

api.AssertIsEqual(
wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdateNumber.ID).Y,
wvc.GetPublicInput(api, publicInput.LastRollingHashNumberUpdate),
gnarkFuncInp.LastRollingHashUpdateNumber,
)

Expand All @@ -107,9 +106,9 @@ func checkPublicInputs(
bridgeAddress = api.Add(
api.Mul(
twoPow128,
wizardFuncInp.L2MessageServiceAddrHi.GetFrontendVariable(api, wvc),
wvc.GetPublicInput(api, publicInput.L2MessageServiceAddrHi),
),
wizardFuncInp.L2MessageServiceAddrLo.GetFrontendVariable(api, wvc),
wvc.GetPublicInput(api, publicInput.L2MessageServiceAddrLo),
)
)

Expand All @@ -119,10 +118,10 @@ func checkPublicInputs(
// chainID) then the traces will return a chainID of zero.
api.AssertIsEqual(
api.Mul(
wvc.GetLocalPointEvalParams(wizardFuncInp.ChainID.ID).Y,
wvc.GetPublicInput(api, publicInput.ChainID),
api.Sub(
api.Div(
wvc.GetLocalPointEvalParams(wizardFuncInp.ChainID.ID).Y,
wvc.GetPublicInput(api, publicInput.ChainID),
twoPow112,
),
gnarkFuncInp.ChainID,
Expand All @@ -141,7 +140,6 @@ func checkPublicInputs(
func execDataHash(
api frontend.API,
wvc *wizard.WizardVerifierCircuit,
wFuncInp publicInput.FunctionalInputExtractor,
) frontend.Variable {

hsh, err := mimc.NewMiMC(api)
Expand All @@ -150,8 +148,8 @@ func execDataHash(
}

hsh.Write(
wvc.GetLocalPointEvalParams(wFuncInp.DataNbBytes.ID).Y,
wvc.GetLocalPointEvalParams(wFuncInp.DataChecksum.ID).Y,
wvc.GetPublicInput(api, publicInput.DataNbBytes),
wvc.GetPublicInput(api, publicInput.DataChecksum),
)

return hsh.Sum()
Expand Down
3 changes: 1 addition & 2 deletions prover/circuits/pi-interconnection/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ func (c *Circuit) Define(api frontend.API) error {
hshM hash.FieldHasher
)
if c.UseGkrMimc {
hsh := gkrmimc.NewHasherFactory(api).NewHasher()
hshM = &hsh
hshM = gkrmimc.NewHasherFactory(api).NewHasher()
} else {
if hsh, err := mimc.NewMiMC(api); err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions prover/circuits/pi-interconnection/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func TestTinyTwoBatchBlob(t *testing.T) {
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp},
Executions: execReq,
DictPath: "../../lib/compressor/compressor_dict.bin",
Aggregation: public_input.Aggregation{
FinalShnarf: blobResp.ExpectedShnarf,
ParentAggregationFinalShnarf: blobReq.PrevShnarf,
Expand Down Expand Up @@ -208,6 +209,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1},
Executions: execReq,
DictPath: "../../lib/compressor/compressor_dict.bin",
Aggregation: public_input.Aggregation{
FinalShnarf: blobResp1.ExpectedShnarf,
ParentAggregationFinalShnarf: blobReq0.PrevShnarf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request {
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq.L2MessageHashes))

return pi_interconnection.Request{
DictPath: "../../lib/compressor/compressor_dict.bin",
Decompressions: []blobsubmission.Response{*blobResp},
Executions: []public_input.Execution{execReq},
Aggregation: public_input.Aggregation{
Expand Down
22 changes: 19 additions & 3 deletions prover/crypto/fiatshamir/fiatshamir.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package fiatshamir

import (
"hash"
"math"

"github.com/consensys/gnark-crypto/hash"
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
Expand Down Expand Up @@ -31,18 +31,34 @@ import (
//
// https://blog.trailofbits.com/2022/04/18/the-frozen-heart-vulnerability-in-plonk/
type State struct {
hasher hash.Hash
hasher hash.StateStorer
TranscriptSize int
NumCoinGenerated int
}

// NewMiMCFiatShamir constructs a fresh and empty Fiat-Shamir state.
func NewMiMCFiatShamir() *State {
return &State{
hasher: mimc.NewMiMC(),
hasher: mimc.NewMiMC().(hash.StateStorer),
}
}

// State returns the internal state of the Fiat-Shamir hasher. Only works for
// MiMC.
func (s *State) State() []field.Element {
_ = s.hasher.Sum(nil)
b := s.hasher.State()
f := new(field.Element).SetBytes(b)
return []field.Element{*f}
}

// SetState sets the fiat-shamir state to the requested value
func (s *State) SetState(f []field.Element) {
_ = s.hasher.Sum(nil)
b := f[0].Bytes()
s.hasher.SetState(b[:])
}

// Update the Fiat-Shamir state with a one or more of field elements. The
// function as no-op if the caller supplies no field elements.
func (fs *State) Update(vec ...field.Element) {
Expand Down
34 changes: 31 additions & 3 deletions prover/crypto/fiatshamir/snark.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// of the verifier of a protocol calling [State] as it allows having a very
// similar code for both tasks.
type GnarkFiatShamir struct {
hasher hash.FieldHasher
hasher hash.StateStorer
// pointer to the gnark-API (also passed to the hasher but behind an
// interface). This is needed to perform bit-decomposition.
api frontend.API
Expand All @@ -30,10 +30,10 @@ type GnarkFiatShamir struct {
// used in the scope of a [frontend.Define] function.
func NewGnarkFiatShamir(api frontend.API, factory *gkrmimc.HasherFactory) *GnarkFiatShamir {

var hasher hash.FieldHasher
var hasher hash.StateStorer
if factory != nil {
h := factory.NewHasher()
hasher = &h
hasher = h
} else {
h, err := mimc.NewMiMC(api)
if err != nil {
Expand All @@ -52,6 +52,34 @@ func NewGnarkFiatShamir(api frontend.API, factory *gkrmimc.HasherFactory) *Gnark
}
}

// SetState mutates the fiat-shamir state of
func (fs *GnarkFiatShamir) SetState(state []frontend.Variable) {

switch hsh := fs.hasher.(type) {
case interface {
SetState([]frontend.Variable) error
}:
if err := hsh.SetState(state); err != nil {
panic(err)
}
default:
panic("unexpected hasher type")
}
}

// State mutates the fiat-shamir state of
func (fs *GnarkFiatShamir) State() []frontend.Variable {

switch hsh := fs.hasher.(type) {
case interface {
State() []frontend.Variable
}:
return hsh.State()
default:
panic("unexpected hasher type")
}
}

// Update updates the Fiat-Shamir state with a vector of frontend.Variable
// representing field element each.
func (fs *GnarkFiatShamir) Update(vec ...frontend.Variable) {
Expand Down
29 changes: 27 additions & 2 deletions prover/crypto/mimc/gkrmimc/helper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gkrmimc

import (
"errors"
"math/big"

"github.com/consensys/gnark/frontend"
Expand Down Expand Up @@ -70,8 +71,8 @@ type Hasher struct {
// and will provide the same results for the same usage.
//
// However, the hasher should not be used in deferred gnark circuit execution.
func (f *HasherFactory) NewHasher() Hasher {
return Hasher{factory: f, state: frontend.Variable(0)}
func (f *HasherFactory) NewHasher() *Hasher {
return &Hasher{factory: f, state: frontend.Variable(0)}
}

// Writes fields elements into the hasher; implements [hash.FieldHasher]
Expand Down Expand Up @@ -107,6 +108,30 @@ func (h *Hasher) Sum() frontend.Variable {
return curr
}

// SetState manually sets the state of the hasher to the provided value. In the
// case of MiMC only a single frontend variable is expected to represent the
// state.
func (h *Hasher) SetState(newState []frontend.Variable) error {

if len(h.data) > 0 {
return errors.New("the hasher is not in an initial state")
}

if len(newState) != 1 {
return errors.New("the MiMC hasher expects a single field element to represent the state")
}

h.state = newState[0]
return nil
}

// State returns the inner-state of the hasher. In the context of MiMC only a
// single field element is returned.
func (h *Hasher) State() []frontend.Variable {
_ = h.Sum() // to flush the hasher
return []frontend.Variable{h.state}
}

// compress calls returns a frontend.Variable holding the result of applying
// the compression function of MiMC over state and block. The alleged returned
// result is pushed on the stack of all the claims to verify.
Expand Down
Loading

0 comments on commit 3b875fd

Please sign in to comment.