From 3b875fd8d65a1be8becb9eb8a3f9c1106a139ee1 Mon Sep 17 00:00:00 2001 From: AlexandreBelling Date: Mon, 6 Jan 2025 09:52:01 +0100 Subject: [PATCH] Prover/full recursion (#350) * (feat): Implements the full-recursion and test for a simple test * msg(sticker): better error in the sticker * test(full-recursion): adds a test for double full-recursion (overflowing memory) * fix: sort out the packages after rebasing * fix(pi): renaming of the public inputs * fix(hasher): adjust the code to using a [hash.StateStorer] * fix(pairing): pass the new format for fp12 elements * doc(plonk): adds more doc in plonk.alignment.go * doc(fs-hook): improves the documentation of the FiatShamirHook field. * docs(skipping): adds doc on the ByRoundRegister * feat(pubinp): move the zkevm public inputs to using the new public-input framework * doc(column-store): adds documentation for the more precise methods regarding the inclusion in the FS transcript. * clean(self-recursion): remove the self-recursion tuning file * doc(vortex): explain the separation between the verifier steps * doc(full-recursion): documents the prover and verifier actions * doc(columns): improve the documentation on the IncludeInProverFS --- .../large-tests/compile-batch-hasher/main.go | 2 +- .../circuits/blobdecompression/v1/circuit.go | 8 +- prover/circuits/execution/circuit.go | 8 - .../execution/pi_wizard_extraction.go | 42 ++- prover/circuits/pi-interconnection/circuit.go | 3 +- .../circuits/pi-interconnection/e2e_test.go | 2 + .../test_utils/test_utils.go | 1 + prover/crypto/fiatshamir/fiatshamir.go | 22 +- prover/crypto/fiatshamir/snark.go | 34 ++- prover/crypto/mimc/gkrmimc/helper.go | 29 +- prover/go.mod | 17 +- prover/go.sum | 34 ++- prover/protocol/column/store.go | 41 +++ .../compiler/dummy/dummy_prover_level.go | 2 - .../compiler/fullrecursion/actions.go | 207 ++++++++++++++ .../compiler/fullrecursion/circuit.go | 255 ++++++++++++++++++ .../compiler/fullrecursion/full_recursion.go | 230 ++++++++++++++++ .../fullrecursion/full_recursion_test.go | 105 ++++++++ prover/protocol/compiler/globalcs/compile.go | 2 +- .../protocol/compiler/globalcs/evaluation.go | 17 +- .../compiler/innerproduct/verifier.go | 9 + prover/protocol/compiler/lookup/compiler.go | 4 +- prover/protocol/compiler/lookup/verifier.go | 9 + .../protocol/compiler/permutation/compiler.go | 2 +- .../protocol/compiler/permutation/verifier.go | 21 +- .../compiler/selfrecursion/context.go | 4 +- .../compiler/splitter/sticker/sticker.go | 2 +- prover/protocol/compiler/vortex/compiler.go | 6 +- .../compiler/vortex/gnark_verifier.go | 6 +- prover/protocol/compiler/vortex/verifier.go | 5 - prover/protocol/dedicated/plonk/alignment.go | 42 ++- prover/protocol/dedicated/plonk/compile.go | 37 ++- .../dedicated/projection/projection.go | 15 +- prover/protocol/query/local_opening.go | 15 -- prover/protocol/wizard/actions.go | 27 ++ prover/protocol/wizard/builder.go | 14 +- prover/protocol/wizard/compiled.go | 26 +- prover/protocol/wizard/gnark_verifier.go | 213 ++++++++++----- prover/protocol/wizard/prover.go | 56 ++-- prover/protocol/wizard/public_input.go | 14 + prover/protocol/wizard/register.go | 61 ++++- prover/protocol/wizard/verifier.go | 77 ++++-- prover/zkevm/prover/ecpair/circuits.go | 115 +++++--- .../zkevm/prover/publicInput/public_input.go | 44 +++ 44 files changed, 1559 insertions(+), 326 deletions(-) create mode 100644 prover/protocol/compiler/fullrecursion/actions.go create mode 100644 prover/protocol/compiler/fullrecursion/circuit.go create mode 100644 prover/protocol/compiler/fullrecursion/full_recursion.go create mode 100644 prover/protocol/compiler/fullrecursion/full_recursion_test.go create mode 100644 prover/protocol/wizard/public_input.go diff --git a/prover/circuits/blobdecompression/large-tests/compile-batch-hasher/main.go b/prover/circuits/blobdecompression/large-tests/compile-batch-hasher/main.go index 9e20624a4..5dd6076bb 100644 --- a/prover/circuits/blobdecompression/large-tests/compile-batch-hasher/main.go +++ b/prover/circuits/blobdecompression/large-tests/compile-batch-hasher/main.go @@ -28,5 +28,5 @@ type circuit struct { func (c *circuit) Define(api frontend.API) error { hsh := gkrmimc.NewHasherFactory(api).NewHasher() - return v1.CheckBatchesSums(api, &hsh, c.NbBatches, c.BlobPayload[:], c.BatchEnds[:], c.ExpectedSums[:]) + return v1.CheckBatchesSums(api, hsh, c.NbBatches, c.BlobPayload[:], c.BatchEnds[:], c.ExpectedSums[:]) } diff --git a/prover/circuits/blobdecompression/v1/circuit.go b/prover/circuits/blobdecompression/v1/circuit.go index db864aea6..a04669002 100644 --- a/prover/circuits/blobdecompression/v1/circuit.go +++ b/prover/circuits/blobdecompression/v1/circuit.go @@ -4,11 +4,12 @@ import ( "bytes" "errors" "fmt" - "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary" - "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode" "hash" "math/big" + "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode" + "github.com/consensys/gnark-crypto/ecc" fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -192,8 +193,7 @@ func (i *FunctionalPublicInputSnark) Sum(api frontend.API, hsh snarkHash.FieldHa func (c Circuit) Define(api frontend.API) error { var hsh snarkHash.FieldHasher if c.UseGkrMiMC { - h := gkrmimc.NewHasherFactory(api).NewHasher() - hsh = &h + hsh = gkrmimc.NewHasherFactory(api).NewHasher() } else { if h, err := mimc.NewMiMC(api); err != nil { return err diff --git a/prover/circuits/execution/circuit.go b/prover/circuits/execution/circuit.go index 36f772c3d..3403911c6 100644 --- a/prover/circuits/execution/circuit.go +++ b/prover/circuits/execution/circuit.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/linea-monorepo/prover/circuits" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/zkevm" - "github.com/consensys/linea-monorepo/prover/zkevm/prover/publicInput" "github.com/sirupsen/logrus" "github.com/consensys/gnark/std/hash/mimc" @@ -23,11 +22,6 @@ import ( type CircuitExecution struct { // The wizard verifier circuit WizardVerifier wizard.WizardVerifierCircuit `gnark:",secret"` - // The extractor is not part of the circuit per se, but hold informations - // that is used to extract the public inputs from the the WizardVerifier. - // The extractor only needs to be provided during the definition of the - // circuit and is omitted during the assignment of the circuit. - extractor publicInput.FunctionalInputExtractor `gnark:"-"` // The functional public inputs are the "actual" statement made by the // circuit. They are not part of the public input of the circuit for // a number of reasons involving efficiency and simplicity in the aggregation @@ -45,7 +39,6 @@ func Allocate(zkevm *zkevm.ZkEvm) CircuitExecution { } return CircuitExecution{ WizardVerifier: *wverifier, - extractor: zkevm.PublicInput.Extractor, FuncInputs: FunctionalPublicInputSnark{ FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{ L2MessageHashes: L2MessageHashes{ @@ -90,7 +83,6 @@ func (c *CircuitExecution) Define(api frontend.API) error { api, &c.WizardVerifier, c.FuncInputs, - c.extractor, ) // Add missing public input check diff --git a/prover/circuits/execution/pi_wizard_extraction.go b/prover/circuits/execution/pi_wizard_extraction.go index 2f14948ca..e564320c4 100644 --- a/prover/circuits/execution/pi_wizard_extraction.go +++ b/prover/circuits/execution/pi_wizard_extraction.go @@ -16,13 +16,12 @@ func checkPublicInputs( api frontend.API, wvc *wizard.WizardVerifierCircuit, gnarkFuncInp FunctionalPublicInputSnark, - wizardFuncInp publicInput.FunctionalInputExtractor, ) { var ( lastRollingHash = internal.CombineBytesIntoElements(api, gnarkFuncInp.FinalRollingHashUpdate) firstRollingHash = internal.CombineBytesIntoElements(api, gnarkFuncInp.InitialRollingHashUpdate) - execDataHash = execDataHash(api, wvc, wizardFuncInp) + execDataHash = execDataHash(api, wvc) ) // As we have this issue, the execDataHash will not match what we have in the @@ -32,7 +31,7 @@ func checkPublicInputs( shouldBeEqual(api, execDataHash, gnarkFuncInp.DataChecksum) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.L2MessageHash.ID).Y, + wvc.GetPublicInput(api, publicInput.L2MessageHash), // TODO: this operation is done a second time when computing the final // public input which is wasteful although not dramatic (~8000 unused // constraints) @@ -40,62 +39,62 @@ func checkPublicInputs( ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.InitialStateRootHash.ID).Y, + wvc.GetPublicInput(api, publicInput.InitialStateRootHash), gnarkFuncInp.InitialStateRootHash, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.InitialBlockNumber.ID).Y, + wvc.GetPublicInput(api, publicInput.InitialBlockNumber), gnarkFuncInp.InitialBlockNumber, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.InitialBlockTimestamp.ID).Y, + wvc.GetPublicInput(api, publicInput.InitialBlockTimestamp), gnarkFuncInp.InitialBlockTimestamp, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdate[0].ID).Y, + wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdate_0), firstRollingHash[0], ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdate[1].ID).Y, + wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdate_1), firstRollingHash[1], ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FirstRollingHashUpdateNumber.ID).Y, + wvc.GetPublicInput(api, publicInput.FirstRollingHashUpdateNumber), gnarkFuncInp.FirstRollingHashUpdateNumber, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FinalStateRootHash.ID).Y, + wvc.GetPublicInput(api, publicInput.FinalStateRootHash), gnarkFuncInp.FinalStateRootHash, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FinalBlockNumber.ID).Y, + wvc.GetPublicInput(api, publicInput.FinalBlockNumber), gnarkFuncInp.FinalBlockNumber, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.FinalBlockTimestamp.ID).Y, + wvc.GetPublicInput(api, publicInput.FinalBlockTimestamp), gnarkFuncInp.FinalBlockTimestamp, ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdate[0].ID).Y, + wvc.GetPublicInput(api, publicInput.LastRollingHashUpdate_0), lastRollingHash[0], ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdate[1].ID).Y, + wvc.GetPublicInput(api, publicInput.LastRollingHashUpdate_1), lastRollingHash[1], ) api.AssertIsEqual( - wvc.GetLocalPointEvalParams(wizardFuncInp.LastRollingHashUpdateNumber.ID).Y, + wvc.GetPublicInput(api, publicInput.LastRollingHashNumberUpdate), gnarkFuncInp.LastRollingHashUpdateNumber, ) @@ -107,9 +106,9 @@ func checkPublicInputs( bridgeAddress = api.Add( api.Mul( twoPow128, - wizardFuncInp.L2MessageServiceAddrHi.GetFrontendVariable(api, wvc), + wvc.GetPublicInput(api, publicInput.L2MessageServiceAddrHi), ), - wizardFuncInp.L2MessageServiceAddrLo.GetFrontendVariable(api, wvc), + wvc.GetPublicInput(api, publicInput.L2MessageServiceAddrLo), ) ) @@ -119,10 +118,10 @@ func checkPublicInputs( // chainID) then the traces will return a chainID of zero. api.AssertIsEqual( api.Mul( - wvc.GetLocalPointEvalParams(wizardFuncInp.ChainID.ID).Y, + wvc.GetPublicInput(api, publicInput.ChainID), api.Sub( api.Div( - wvc.GetLocalPointEvalParams(wizardFuncInp.ChainID.ID).Y, + wvc.GetPublicInput(api, publicInput.ChainID), twoPow112, ), gnarkFuncInp.ChainID, @@ -141,7 +140,6 @@ func checkPublicInputs( func execDataHash( api frontend.API, wvc *wizard.WizardVerifierCircuit, - wFuncInp publicInput.FunctionalInputExtractor, ) frontend.Variable { hsh, err := mimc.NewMiMC(api) @@ -150,8 +148,8 @@ func execDataHash( } hsh.Write( - wvc.GetLocalPointEvalParams(wFuncInp.DataNbBytes.ID).Y, - wvc.GetLocalPointEvalParams(wFuncInp.DataChecksum.ID).Y, + wvc.GetPublicInput(api, publicInput.DataNbBytes), + wvc.GetPublicInput(api, publicInput.DataChecksum), ) return hsh.Sum() diff --git a/prover/circuits/pi-interconnection/circuit.go b/prover/circuits/pi-interconnection/circuit.go index eb6e2953d..a19573995 100644 --- a/prover/circuits/pi-interconnection/circuit.go +++ b/prover/circuits/pi-interconnection/circuit.go @@ -98,8 +98,7 @@ func (c *Circuit) Define(api frontend.API) error { hshM hash.FieldHasher ) if c.UseGkrMimc { - hsh := gkrmimc.NewHasherFactory(api).NewHasher() - hshM = &hsh + hshM = gkrmimc.NewHasherFactory(api).NewHasher() } else { if hsh, err := mimc.NewMiMC(api); err != nil { return err diff --git a/prover/circuits/pi-interconnection/e2e_test.go b/prover/circuits/pi-interconnection/e2e_test.go index 5f559ecda..ca53f69ae 100644 --- a/prover/circuits/pi-interconnection/e2e_test.go +++ b/prover/circuits/pi-interconnection/e2e_test.go @@ -112,6 +112,7 @@ func TestTinyTwoBatchBlob(t *testing.T) { req := pi_interconnection.Request{ Decompressions: []blobsubmission.Response{*blobResp}, Executions: execReq, + DictPath: "../../lib/compressor/compressor_dict.bin", Aggregation: public_input.Aggregation{ FinalShnarf: blobResp.ExpectedShnarf, ParentAggregationFinalShnarf: blobReq.PrevShnarf, @@ -208,6 +209,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) { req := pi_interconnection.Request{ Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1}, Executions: execReq, + DictPath: "../../lib/compressor/compressor_dict.bin", Aggregation: public_input.Aggregation{ FinalShnarf: blobResp1.ExpectedShnarf, ParentAggregationFinalShnarf: blobReq0.PrevShnarf, diff --git a/prover/circuits/pi-interconnection/test_utils/test_utils.go b/prover/circuits/pi-interconnection/test_utils/test_utils.go index c678f1673..36f1162ed 100644 --- a/prover/circuits/pi-interconnection/test_utils/test_utils.go +++ b/prover/circuits/pi-interconnection/test_utils/test_utils.go @@ -48,6 +48,7 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request { merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq.L2MessageHashes)) return pi_interconnection.Request{ + DictPath: "../../lib/compressor/compressor_dict.bin", Decompressions: []blobsubmission.Response{*blobResp}, Executions: []public_input.Execution{execReq}, Aggregation: public_input.Aggregation{ diff --git a/prover/crypto/fiatshamir/fiatshamir.go b/prover/crypto/fiatshamir/fiatshamir.go index c9752d7ca..3a52bdca9 100644 --- a/prover/crypto/fiatshamir/fiatshamir.go +++ b/prover/crypto/fiatshamir/fiatshamir.go @@ -1,9 +1,9 @@ package fiatshamir import ( - "hash" "math" + "github.com/consensys/gnark-crypto/hash" "github.com/consensys/linea-monorepo/prover/crypto/mimc" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" @@ -31,7 +31,7 @@ import ( // // https://blog.trailofbits.com/2022/04/18/the-frozen-heart-vulnerability-in-plonk/ type State struct { - hasher hash.Hash + hasher hash.StateStorer TranscriptSize int NumCoinGenerated int } @@ -39,10 +39,26 @@ type State struct { // NewMiMCFiatShamir constructs a fresh and empty Fiat-Shamir state. func NewMiMCFiatShamir() *State { return &State{ - hasher: mimc.NewMiMC(), + hasher: mimc.NewMiMC().(hash.StateStorer), } } +// State returns the internal state of the Fiat-Shamir hasher. Only works for +// MiMC. +func (s *State) State() []field.Element { + _ = s.hasher.Sum(nil) + b := s.hasher.State() + f := new(field.Element).SetBytes(b) + return []field.Element{*f} +} + +// SetState sets the fiat-shamir state to the requested value +func (s *State) SetState(f []field.Element) { + _ = s.hasher.Sum(nil) + b := f[0].Bytes() + s.hasher.SetState(b[:]) +} + // Update the Fiat-Shamir state with a one or more of field elements. The // function as no-op if the caller supplies no field elements. func (fs *State) Update(vec ...field.Element) { diff --git a/prover/crypto/fiatshamir/snark.go b/prover/crypto/fiatshamir/snark.go index 198549ed5..7d26df39c 100644 --- a/prover/crypto/fiatshamir/snark.go +++ b/prover/crypto/fiatshamir/snark.go @@ -19,7 +19,7 @@ import ( // of the verifier of a protocol calling [State] as it allows having a very // similar code for both tasks. type GnarkFiatShamir struct { - hasher hash.FieldHasher + hasher hash.StateStorer // pointer to the gnark-API (also passed to the hasher but behind an // interface). This is needed to perform bit-decomposition. api frontend.API @@ -30,10 +30,10 @@ type GnarkFiatShamir struct { // used in the scope of a [frontend.Define] function. func NewGnarkFiatShamir(api frontend.API, factory *gkrmimc.HasherFactory) *GnarkFiatShamir { - var hasher hash.FieldHasher + var hasher hash.StateStorer if factory != nil { h := factory.NewHasher() - hasher = &h + hasher = h } else { h, err := mimc.NewMiMC(api) if err != nil { @@ -52,6 +52,34 @@ func NewGnarkFiatShamir(api frontend.API, factory *gkrmimc.HasherFactory) *Gnark } } +// SetState mutates the fiat-shamir state of +func (fs *GnarkFiatShamir) SetState(state []frontend.Variable) { + + switch hsh := fs.hasher.(type) { + case interface { + SetState([]frontend.Variable) error + }: + if err := hsh.SetState(state); err != nil { + panic(err) + } + default: + panic("unexpected hasher type") + } +} + +// State mutates the fiat-shamir state of +func (fs *GnarkFiatShamir) State() []frontend.Variable { + + switch hsh := fs.hasher.(type) { + case interface { + State() []frontend.Variable + }: + return hsh.State() + default: + panic("unexpected hasher type") + } +} + // Update updates the Fiat-Shamir state with a vector of frontend.Variable // representing field element each. func (fs *GnarkFiatShamir) Update(vec ...frontend.Variable) { diff --git a/prover/crypto/mimc/gkrmimc/helper.go b/prover/crypto/mimc/gkrmimc/helper.go index ef7977097..dd748d9c2 100644 --- a/prover/crypto/mimc/gkrmimc/helper.go +++ b/prover/crypto/mimc/gkrmimc/helper.go @@ -1,6 +1,7 @@ package gkrmimc import ( + "errors" "math/big" "github.com/consensys/gnark/frontend" @@ -70,8 +71,8 @@ type Hasher struct { // and will provide the same results for the same usage. // // However, the hasher should not be used in deferred gnark circuit execution. -func (f *HasherFactory) NewHasher() Hasher { - return Hasher{factory: f, state: frontend.Variable(0)} +func (f *HasherFactory) NewHasher() *Hasher { + return &Hasher{factory: f, state: frontend.Variable(0)} } // Writes fields elements into the hasher; implements [hash.FieldHasher] @@ -107,6 +108,30 @@ func (h *Hasher) Sum() frontend.Variable { return curr } +// SetState manually sets the state of the hasher to the provided value. In the +// case of MiMC only a single frontend variable is expected to represent the +// state. +func (h *Hasher) SetState(newState []frontend.Variable) error { + + if len(h.data) > 0 { + return errors.New("the hasher is not in an initial state") + } + + if len(newState) != 1 { + return errors.New("the MiMC hasher expects a single field element to represent the state") + } + + h.state = newState[0] + return nil +} + +// State returns the inner-state of the hasher. In the context of MiMC only a +// single field element is returned. +func (h *Hasher) State() []frontend.Variable { + _ = h.Sum() // to flush the hasher + return []frontend.Variable{h.state} +} + // compress calls returns a frontend.Variable holding the result of applying // the compression function of MiMC over state and block. The alleged returned // result is pushed on the stack of all the claims to verify. diff --git a/prover/go.mod b/prover/go.mod index 65b10c575..d955410d0 100644 --- a/prover/go.mod +++ b/prover/go.mod @@ -5,10 +5,10 @@ go 1.22.7 toolchain go1.23.0 require ( - github.com/consensys/bavard v0.1.22 + github.com/consensys/bavard v0.1.24 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark v0.11.1-0.20240910135928-e8cb61d0be1d - github.com/consensys/gnark-crypto v0.14.1-0.20241007145620-e26bbdf97a4a + github.com/consensys/gnark v0.11.1-0.20241217141116-f3d91999250b + github.com/consensys/gnark-crypto v0.14.1-0.20241217134352-810063550bd4 github.com/consensys/go-corset v0.0.0-20241125005324-5cb0c289c021 github.com/crate-crypto/go-kzg-4844 v1.1.0 github.com/dlclark/regexp2 v1.11.2 @@ -24,9 +24,9 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 - golang.org/x/crypto v0.26.0 + golang.org/x/crypto v0.31.0 golang.org/x/net v0.27.0 - golang.org/x/sync v0.8.0 + golang.org/x/sync v0.10.0 golang.org/x/time v0.5.0 ) @@ -63,8 +63,7 @@ require ( github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/holiman/uint256 v1.3.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/ingonyama-zk/icicle v1.1.0 // indirect - github.com/ingonyama-zk/iciclegnark v0.1.0 // indirect + github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect @@ -98,7 +97,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/text v0.21.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/ini.v1 v1.67.0 // indirect rsc.io/tmplfunc v0.0.3 // indirect @@ -111,6 +110,6 @@ require ( github.com/pkg/profile v1.7.0 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 - golang.org/x/sys v0.25.0 // indirect + golang.org/x/sys v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/prover/go.sum b/prover/go.sum index d7d3ca9b4..29ee7f512 100644 --- a/prover/go.sum +++ b/prover/go.sum @@ -92,14 +92,14 @@ github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwP github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo= github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06/go.mod h1:7nc4anLGjupUW/PeY5qiNYsdNXj7zopG+eqsS7To5IQ= -github.com/consensys/bavard v0.1.22 h1:Uw2CGvbXSZWhqK59X0VG/zOjpTFuOMcPLStrp1ihI0A= -github.com/consensys/bavard v0.1.22/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.24 h1:Lfe+bjYbpaoT7K5JTFoMi5wo9V4REGLvQQbHmatoN2I= +github.com/consensys/bavard v0.1.24/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark v0.11.1-0.20240910135928-e8cb61d0be1d h1:TmNupI1+K5/LOg1K0kqEhRf5sZwRtxXah5iTHQ6fJvw= -github.com/consensys/gnark v0.11.1-0.20240910135928-e8cb61d0be1d/go.mod h1:f9CH911SPCrbSZp5z9LYzJ3rZvI7mOUzzf48lCZO/5o= -github.com/consensys/gnark-crypto v0.14.1-0.20241007145620-e26bbdf97a4a h1:yUHuYq+v1C3maTwnntLYhTDmboq3scSo1PQIl375/sE= -github.com/consensys/gnark-crypto v0.14.1-0.20241007145620-e26bbdf97a4a/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= +github.com/consensys/gnark v0.11.1-0.20241217141116-f3d91999250b h1:isTN/YOs57bOt0JlJHJ8gF8C3CdETU2Z9ao4y8R6qms= +github.com/consensys/gnark v0.11.1-0.20241217141116-f3d91999250b/go.mod h1:8YNyW/+XsYiLRzROLaj/PSktYO4VAdv6YW1b1P3UsZk= +github.com/consensys/gnark-crypto v0.14.1-0.20241217134352-810063550bd4 h1:Kp6egjRqKZf4469dfAWqFe6gi3MRs4VvNHmTfEjUlS8= +github.com/consensys/gnark-crypto v0.14.1-0.20241217134352-810063550bd4/go.mod h1:GMPeN3dUSslNBYJsK3WTjIGd3l0ccfMbcEh/d5knFrc= github.com/consensys/go-corset v0.0.0-20241125005324-5cb0c289c021 h1:zAPMHjY72pXmjuyb/niQ816pd+B9RAmZoL/W/f5uJSU= github.com/consensys/go-corset v0.0.0-20241125005324-5cb0c289c021/go.mod h1:J64guTfpmfXl4Yk2D7lsWdYg0ilP+N8JWPudP7+sZpA= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -290,10 +290,8 @@ github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBD github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/ingonyama-zk/icicle v1.1.0 h1:a2MUIaF+1i4JY2Lnb961ZMvaC8GFs9GqZgSnd9e95C8= -github.com/ingonyama-zk/icicle v1.1.0/go.mod h1:kAK8/EoN7fUEmakzgZIYdWy1a2rBnpCaZLqSHwZWxEk= -github.com/ingonyama-zk/iciclegnark v0.1.0 h1:88MkEghzjQBMjrYRJFxZ9oR9CTIpB8NG2zLeCJSvXKQ= -github.com/ingonyama-zk/iciclegnark v0.1.0/go.mod h1:wz6+IpyHKs6UhMMoQpNqz1VY+ddfKqC/gRwR/64W6WU= +github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b h1:AvQTK7l0PTHODD06PVQX1Tn2o29sRIaKIDOvTJmKurY= +github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b/go.mod h1:e0JHb27/P6WorCJS3YolbY5XffS4PGBuoW38OthLkDs= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -504,8 +502,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -615,8 +613,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -678,8 +676,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -694,8 +692,8 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/prover/protocol/column/store.go b/prover/protocol/column/store.go index cef1772e0..7e669f86f 100644 --- a/prover/protocol/column/store.go +++ b/prover/protocol/column/store.go @@ -44,6 +44,11 @@ type storedColumnInfo struct { ID ifaces.ColID // Status of the commitment Status Status + // IncludeInProverFS states the prover should include the column in his FS + // transcript. This is used for columns that are recursed using + // FullRecursion. This field is only meaningfull for [Ignored] columns as + // they are excluded by default. + IncludeInProverFS bool } // AddToRound constructs a [Natural], registers it in the [Store] and returns @@ -444,3 +449,39 @@ func assertCorrectStatusTransition(old, new Status) { utils.Panic("attempted the transition %v -> %v, which is forbidden", old.String(), new.String()) } } + +// IgnoreButKeepInProverTranscript marks a column as ignored but also asks that +// the column stays included in the FS transcript. This is used as part of +// full-recursion where the commitments to an inner-proofs should not be sent to +// the verifier but should still play a part in the FS transcript. +func (s *Store) IgnoreButKeepInProverTranscript(colName ifaces.ColID) { + in := s.info(colName) + in.Status = Ignored + in.IncludeInProverFS = true +} + +// IsIgnoredAndNotKeptInTranscript indicates whether the column can be ignored +// from the transcript and is used during the Fiat-Shamir randomness generation. +func (s *Store) IsIgnoredAndNotKeptInTranscript(colName ifaces.ColID) bool { + in := s.info(colName) + return in.Status == Ignored && !in.IncludeInProverFS +} + +// AllKeysProofsOrIgnoredButKeptInProverTranscript returns the list of the +// columns to be used as part of the FS transcript. +func (s *Store) AllKeysProofsOrIgnoredButKeptInProverTranscript(round int) []ifaces.ColID { + res := []ifaces.ColID{} + rnd := s.byRounds.MustGet(round) // precomputed are always at round zero + + for i, info := range rnd { + + ok := (info.Status == Proof) || (info.Status == Ignored && info.IncludeInProverFS) + if !ok { + continue + } + + res = append(res, rnd[i].ID) + } + + return res +} diff --git a/prover/protocol/compiler/dummy/dummy_prover_level.go b/prover/protocol/compiler/dummy/dummy_prover_level.go index 9054b203d..14ff60ba3 100644 --- a/prover/protocol/compiler/dummy/dummy_prover_level.go +++ b/prover/protocol/compiler/dummy/dummy_prover_level.go @@ -22,8 +22,6 @@ import ( // suitable for established unit-tests where we want to analyze the errors. func CompileAtProverLvl(comp *wizard.CompiledIOP) { - comp.DummyCompiled = true - /* Registers all declared commitments and query parameters as messages in the same round. This steps is only relevant diff --git a/prover/protocol/compiler/fullrecursion/actions.go b/prover/protocol/compiler/fullrecursion/actions.go new file mode 100644 index 000000000..44927aa99 --- /dev/null +++ b/prover/protocol/compiler/fullrecursion/actions.go @@ -0,0 +1,207 @@ +package fullrecursion + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// CircuitAssignment is an implementation of [wizard.ProverAction]. As such, it +// embodies the action of assigning the full-recursion Plonk circuit columns. +type CircuitAssignment fullRecursionCtx + +// ConsistencyCheck is an implementation of [wizard.VerifierAction]. As such it +// is responsible for checking that the public inputs of the full-recursion +// Plonk circuit are assigned to values that are consistent with (1) the public +// inputs of the wrapping wizard protocol and with the inputs of the +// self-recursion wizard. +type ConsistencyCheck struct { + fullRecursionCtx + isSkipped bool +} + +// ReplacementAssignment is a [wizard.ProverAction] implementation. It assigns +// the queries and columns that are "replaced" in the wizard. In essence, this +// concerns the main grail polynomial evaluation (the grail query) and the +// Merkle roots assignment. These have to be replaced so that they can be +// refered to by the self-recursion. Otherwise, they would be swallowed by the +// recursion Plonk circuit. +type ReplacementAssignment fullRecursionCtx + +// LocalOpeningAssignment assigns the local openings made over the Plonk PI. +// These are needed in order to (1) perform the consistency check (2) replace +// the "old" and recursed public inputs of the original wizard by new ones. +type LocalOpeningAssignment fullRecursionCtx + +// ResetFsActions is a [wizard.FsHook] responsible for tweaking the FS state as +// required by the self-recursion process. +type ResetFsActions struct { + fullRecursionCtx + isSkipped bool +} + +func (c CircuitAssignment) Run(run *wizard.ProverRuntime) { + c.PlonkInWizard.ProverAction.Run(run, WitnessAssigner(c)) +} + +func (c ReplacementAssignment) Run(run *wizard.ProverRuntime) { + params := run.GetUnivariateParams(c.PolyQuery.QueryID) + run.AssignUnivariate(c.PolyQueryReplacement.QueryID, params.X, params.Ys...) + + oldRoots := c.PcsCtx.Items.MerkleRoots + for i := range c.MerkleRootsReplacement { + + if c.PcsCtx.Items.MerkleRoots[i] == nil { + continue + } + + run.AssignColumn( + c.MerkleRootsReplacement[i].GetColID(), + oldRoots[i].GetColAssignment(run), + ) + } +} + +func (c LocalOpeningAssignment) Run(run *wizard.ProverRuntime) { + for i := range c.LocalOpenings { + run.AssignLocalPoint( + c.LocalOpenings[i].ID, + c.PlonkInWizard.PI.GetColAssignmentAt(run, i), + ) + } +} + +func (c *ConsistencyCheck) Run(run *wizard.VerifierRuntime) error { + + var ( + initialFsCirc = run.GetLocalPointEvalParams(c.LocalOpenings[0].ID).Y + initialFsRt = run.FiatShamirHistory[c.FirstRound+1][0][0] + piCursor = 2 + ) + + if initialFsCirc != initialFsRt { + return fmt.Errorf("full recursion: the initial FS do not match") + } + + for i := range c.NonEmptyMerkleRootPositions { + + var ( + pos = c.NonEmptyMerkleRootPositions[i] + fromRt = c.MerkleRootsReplacement[pos].GetColAssignmentAt(run, 0) + fromCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor+i].ID).Y + ) + + if fromRt != fromCirc { + return fmt.Errorf("full recursion: the commitment does not match (pos: %v)", i) + } + } + + piCursor += len(c.NonEmptyMerkleRootPositions) + + var ( + paramsRt = run.GetUnivariateParams(c.PolyQueryReplacement.QueryID) + xRt = paramsRt.X + xCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor].ID).Y + ) + + if xRt != xCirc { + return fmt.Errorf("full recursion: the Ys does not match") + } + + piCursor++ + + for i := range paramsRt.Ys { + + var ( + fromRt = paramsRt.Ys[i] + fromCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor+i].ID).Y + ) + + if fromRt != fromCirc { + return fmt.Errorf("full recursion: the Ys does not match (pos: %v)", i) + } + } + + // The public inputs do not need to be checked because they are redefined in + // term of the local openings directly. So checking it would amount to checking + // that the local openings are equal to themselves. + + return nil +} + +func (c *ConsistencyCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { + + var ( + initialFsCirc = run.GetLocalPointEvalParams(c.LocalOpenings[0].ID).Y + initialFsRt = run.FiatShamirHistory[c.FirstRound+1][0][0] + piCursor = 2 + ) + + api.AssertIsEqual(initialFsCirc, initialFsRt) + + for i := range c.NonEmptyMerkleRootPositions { + + var ( + pos = c.NonEmptyMerkleRootPositions[i] + fromRt = c.MerkleRootsReplacement[pos].GetColAssignmentGnarkAt(run, 0) + fromCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor+i].ID).Y + ) + + api.AssertIsEqual(fromRt, fromCirc) + } + + piCursor += len(c.NonEmptyMerkleRootPositions) + + var ( + paramsRt = run.GetUnivariateParams(c.PolyQueryReplacement.QueryID) + xRt = paramsRt.X + xCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor].ID).Y + ) + + api.AssertIsEqual(xRt, xCirc) + + piCursor++ + + for i := range paramsRt.Ys { + + var ( + fromRt = paramsRt.Ys[i] + fromCirc = run.GetLocalPointEvalParams(c.LocalOpenings[piCursor+i].ID).Y + ) + + api.AssertIsEqual(fromRt, fromCirc) + } + + // The public inputs do not need to be checked because they are redefined in + // term of the local openings directly. So checking it would amount to checking + // that the local openings are equal to themselves. +} + +func (c *ConsistencyCheck) Skip() { + c.isSkipped = true +} + +func (c *ConsistencyCheck) IsSkipped() bool { + return c.isSkipped +} + +func (r *ResetFsActions) Run(run *wizard.VerifierRuntime) error { + finalFsCirc := run.GetLocalPointEvalParams(r.LocalOpenings[1].ID).Y + run.FS.SetState([]field.Element{finalFsCirc}) + return nil +} + +func (r *ResetFsActions) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { + finalFsCirc := run.GetLocalPointEvalParams(r.LocalOpenings[1].ID).Y + run.FS.SetState([]frontend.Variable{finalFsCirc}) +} + +func (r *ResetFsActions) Skip() { + r.isSkipped = true +} + +func (r *ResetFsActions) IsSkipped() bool { + return r.isSkipped +} diff --git a/prover/protocol/compiler/fullrecursion/circuit.go b/prover/protocol/compiler/fullrecursion/circuit.go new file mode 100644 index 000000000..2151dfa7f --- /dev/null +++ b/prover/protocol/compiler/fullrecursion/circuit.go @@ -0,0 +1,255 @@ +package fullrecursion + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/crypto/mimc/gkrmimc" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +type gnarkCircuit struct { + InitialFsState frontend.Variable `gnark:",public"` + FinalFsState frontend.Variable `gnark:",public"` + Commitments []frontend.Variable `gnark:",public"` + X frontend.Variable `gnark:",public"` + Ys []frontend.Variable `gnark:",public"` + Pubs []frontend.Variable `gnark:",public"` + WizardVerifier *wizard.WizardVerifierCircuit + comp *wizard.CompiledIOP `gnark:"-"` + ctx *fullRecursionCtx `gnark:"-"` + withoutGkr bool `gnark:"-"` +} + +func allocateGnarkCircuit(comp *wizard.CompiledIOP, ctx *fullRecursionCtx) *gnarkCircuit { + + var ( + wizardVerifier = wizard.NewWizardVerifierCircuit() + ) + + for round := range ctx.Columns { + for _, col := range ctx.Columns[round] { + wizardVerifier.AllocColumn(col.GetColID(), col.Size()) + } + } + + for round := range ctx.QueryParams { + for _, qInfoIface := range ctx.QueryParams[round] { + switch qInfo := qInfoIface.(type) { + case query.UnivariateEval: + wizardVerifier.AllocUnivariateEval(qInfo.QueryID, qInfo) + case query.InnerProduct: + wizardVerifier.AllocInnerProduct(qInfo.ID, qInfo) + case query.LocalOpening: + wizardVerifier.AllocLocalOpening(qInfo.ID, qInfo) + } + } + } + + wizardVerifier.Spec = comp + + return &gnarkCircuit{ + ctx: ctx, + comp: comp, + WizardVerifier: wizardVerifier, + Commitments: make([]frontend.Variable, len(ctx.NonEmptyMerkleRootPositions)), + Ys: make([]frontend.Variable, len(ctx.PolyQuery.Pols)), + Pubs: make([]frontend.Variable, len(comp.PublicInputs)), + } + +} + +func (c *gnarkCircuit) Define(api frontend.API) error { + + w := c.WizardVerifier + + if c.withoutGkr { + w.FS = fiatshamir.NewGnarkFiatShamir(api, nil) + } else { + w.HasherFactory = gkrmimc.NewHasherFactory(api) + w.FS = fiatshamir.NewGnarkFiatShamir(api, w.HasherFactory) + } + + w.FiatShamirHistory = make([][2][]frontend.Variable, c.comp.NumRounds()) + + c.generateAllRandomCoins(api) + + for round := 0; round <= c.ctx.LastRound; round++ { + roundSteps := c.ctx.VerifierActions[round] + for _, step := range roundSteps { + step.RunGnark(api, w) + } + } + + for i := range c.Pubs { + api.AssertIsEqual(c.Pubs[i], c.ctx.PublicInputs[i].Acc.GetFrontendVariable(api, w)) + } + + polyParams := w.GetUnivariateParams(c.ctx.PolyQuery.Name()) + + api.AssertIsEqual(c.X, polyParams.X) + + for i := range polyParams.Ys { + api.AssertIsEqual(c.Ys[i], polyParams.Ys[i]) + } + + for i := range c.Commitments { + pos := c.ctx.NonEmptyMerkleRootPositions[i] + api.AssertIsEqual( + c.Commitments[i], + w.GetColumn(c.ctx.PcsCtx.Items.MerkleRoots[pos].GetColID())[0], + ) + } + + return nil +} + +// generateAllRandomCoins is as [VerifierRuntime.generateAllRandomCoins]. Note +// that the function does create constraints via the hasher factory that is +// inside of `c.FS`. +func (c *gnarkCircuit) generateAllRandomCoins(api frontend.API) { + + var ( + ctx = c.ctx + w = c.WizardVerifier + ) + + w.FS.SetState([]frontend.Variable{c.InitialFsState}) + + for currRound := 0; currRound <= c.ctx.LastRound; currRound++ { + + initialState := w.FS.State() + + if currRound > 0 { + + toUpdateFS := ctx.Columns[currRound-1] + for _, msg := range toUpdateFS { + val := w.GetColumn(msg.GetColID()) + w.FS.UpdateVec(val) + } + + queries := ctx.QueryParams[currRound-1] + for _, q := range queries { + params := w.GetParams(q.Name()) + params.UpdateFS(w.FS) + } + } + + for _, info := range ctx.Coins[currRound] { + switch info.Type { + case coin.Field: + value := w.FS.RandomField() + w.Coins.InsertNew(info.Name, value) + case coin.IntegerVec: + value := w.FS.RandomManyIntegers(info.Size, info.UpperBound) + w.Coins.InsertNew(info.Name, value) + } + } + + for _, fsHook := range ctx.FsHooks[currRound] { + fsHook.RunGnark(api, w) + } + + w.FiatShamirHistory[currRound] = [2][]frontend.Variable{ + initialState, + w.FS.State(), + } + } + + api.AssertIsEqual(w.FS.State()[0], c.FinalFsState) +} + +// AssignGnarkCircuit returns an assignment for the gnark circuit +func AssignGnarkCircuit(ctx *fullRecursionCtx, comp *wizard.CompiledIOP, run *wizard.ProverRuntime) *gnarkCircuit { + + var ( + wizardVerifier = wizard.NewWizardVerifierCircuit() + ) + + for round := range ctx.Columns { + for _, col := range ctx.Columns[round] { + wizardVerifier.AssignColumn(col.GetColID(), col.GetColAssignment(run)) + } + } + + for round := range ctx.QueryParams { + for _, qInfoIface := range ctx.QueryParams[round] { + switch qInfo := qInfoIface.(type) { + case query.UnivariateEval: + params := run.GetUnivariateParams(qInfo.QueryID) + wizardVerifier.AssignUnivariateEval(qInfo.QueryID, params) + case query.InnerProduct: + params := run.GetInnerProductParams(qInfo.ID) + wizardVerifier.AssignInnerProduct(qInfo.ID, params) + case query.LocalOpening: + params := run.GetLocalPointEvalParams(qInfo.ID) + wizardVerifier.AssignLocalOpening(qInfo.ID, params) + } + } + } + + c := &gnarkCircuit{ + ctx: ctx, + comp: comp, + WizardVerifier: wizardVerifier, + Pubs: make([]frontend.Variable, len(comp.PublicInputs)), + Commitments: make([]frontend.Variable, len(ctx.NonEmptyMerkleRootPositions)), + // It is important we start from the begining because of the case where + // we stack several FullRecursion. In that case, the FsHooks are going + // to automatically set the FsState to the correct value at first round. + InitialFsState: run.FiatShamirHistory[1][0][0], + FinalFsState: run.FiatShamirHistory[ctx.LastRound][1][0], + } + + polyParams := run.GetUnivariateParams(ctx.PolyQuery.QueryID).GnarkAssign() + c.X = polyParams.X + c.Ys = polyParams.Ys + + for i := range c.Pubs { + c.Pubs[i] = comp.PublicInputs[i].Acc.GetVal(run) + } + + for i := range c.Commitments { + pos := ctx.NonEmptyMerkleRootPositions[i] + c.Commitments[i] = ctx.PcsCtx.Items.MerkleRoots[pos].GetColAssignmentAt(run, 0) + } + + return c +} + +// WitnessAssign is an implementation of the [plonk.WitnessAssigner] and is used to +// generate the assignment of the fullRecursion circuit. +type WitnessAssigner fullRecursionCtx + +func (w WitnessAssigner) NumEffWitnesses(_ *wizard.ProverRuntime) int { + return 1 +} + +func (w WitnessAssigner) Assign(run *wizard.ProverRuntime, i int) (private, public witness.Witness, err error) { + + if i > 0 { + panic("only a single witness for the full-recursion") + } + + var ( + ctx = fullRecursionCtx(w) + assignment = AssignGnarkCircuit(&ctx, w.Comp, run) + ) + + witness, err := frontend.NewWitness(assignment, ecc.BLS12_377.ScalarField()) + if err != nil { + return nil, nil, fmt.Errorf("new witness: %W", err) + } + + pubWitness, err := witness.Public() + if err != nil { + return nil, nil, fmt.Errorf("public witness: %w", err) + } + + return witness, pubWitness, nil +} diff --git a/prover/protocol/compiler/fullrecursion/full_recursion.go b/prover/protocol/compiler/fullrecursion/full_recursion.go new file mode 100644 index 000000000..9c8017ae8 --- /dev/null +++ b/prover/protocol/compiler/fullrecursion/full_recursion.go @@ -0,0 +1,230 @@ +package fullrecursion + +import ( + "strconv" + + "github.com/consensys/linea-monorepo/prover/protocol/accessors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/selfrecursion" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" + "github.com/consensys/linea-monorepo/prover/protocol/dedicated/plonk" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// FullRecursion "recurses" the wizard protocol by wrapping all the verifier +// steps in a Plonk-in-Wizard context as well as all the Proof columns. The +// Vortex PCS verification is done via self-recursion. +func FullRecursion(withoutGkr bool) func(comp *wizard.CompiledIOP) { + + return func(comp *wizard.CompiledIOP) { + var ( + ctx = captureCtx(comp) + c = allocateGnarkCircuit(comp, ctx) + numPI = len(c.ctx.NonEmptyMerkleRootPositions) + + len(c.Pubs) + + len(c.Ys) + + 3 // (1.) for X (2.) for the initial FS state (3.) for the final state + funcPiOffset = 3 + len(ctx.NonEmptyMerkleRootPositions) + len(ctx.PolyQuery.Pols) + ) + + selfrecursion.SelfRecurse(comp) + + piw := plonk.PlonkCheck(comp, "full-recursion-"+strconv.Itoa(comp.SelfRecursionCount), ctx.LastRound, c, 1) + + ctx.PlonkInWizard.PI = piw.ConcatenatedTinyPIs(utils.NextPowerOfTwo(numPI)) + ctx.PlonkInWizard.ProverAction = piw.GetPlonkProverAction() + + for i := 0; i < numPI; i++ { + + var ( + pi = ctx.PlonkInWizard.PI + lo = comp.InsertLocalOpening( + ctx.PlonkInWizard.PI.Round(), + ifaces.QueryIDf("%v_LO_%v", pi.String(), i), + column.Shift(pi, i), + ) + ) + + ctx.LocalOpenings = append(ctx.LocalOpenings, lo) + } + + for i := range comp.PublicInputs { + comp.PublicInputs[i].Acc = accessors.NewLocalOpeningAccessor( + ctx.LocalOpenings[funcPiOffset+i], + ctx.PlonkInWizard.PI.Round(), + ) + } + + comp.FiatShamirHooks.AppendToInner(ctx.LastRound, &ResetFsActions{fullRecursionCtx: *ctx}) + comp.RegisterProverAction(ctx.LastRound, CircuitAssignment(*ctx)) + comp.RegisterProverAction(ctx.LastRound, ReplacementAssignment(*ctx)) + comp.RegisterProverAction(ctx.PlonkInWizard.PI.Round(), LocalOpeningAssignment(*ctx)) + comp.RegisterVerifierAction(ctx.PlonkInWizard.PI.Round(), &ConsistencyCheck{fullRecursionCtx: *ctx}) + } +} + +// fullRecursionCtx holds compilation context informations about the wizard +// protocol being compiled by a FullRecursion routine. +type fullRecursionCtx struct { + // A pointer to the compiled-IOP over which the compilation step has run + Comp *wizard.CompiledIOP + // The Vortex compilation context + PcsCtx *vortex.Ctx + PublicInputs []wizard.PublicInput + PolyQuery query.UnivariateEval + PolyQueryReplacement query.UnivariateEval + MerkleRootsReplacement []ifaces.Column + NonEmptyMerkleRootPositions []int + FirstRound, LastRound int + QueryParams [][]ifaces.Query + Columns [][]ifaces.Column + VerifierActions [][]wizard.VerifierAction + Coins [][]coin.Info + FsHooks [][]wizard.VerifierAction + PlonkInWizard struct { + ProverAction plonk.PlonkInWizardProverAction + PI ifaces.Column + } + LocalOpenings []query.LocalOpening +} + +// captureCtx scans the content of comp to store the compilation infos of the +// CompiledIOP at the beginning of the compilation. +func captureCtx(comp *wizard.CompiledIOP) *fullRecursionCtx { + + var ( + polyQuery = comp.PcsCtxs.(*vortex.Ctx).Query + lastRound = comp.QueriesParams.Round(polyQuery.QueryID) + ctx = &fullRecursionCtx{ + Comp: comp, + PcsCtx: comp.PcsCtxs.(*vortex.Ctx), + PolyQuery: polyQuery, + LastRound: lastRound, + FirstRound: lastRound, + PublicInputs: append([]wizard.PublicInput{}, comp.PublicInputs...), + } + ) + + for round := 0; round <= lastRound; round++ { + + ctx.QueryParams = append(ctx.QueryParams, []ifaces.Query{}) + ctx.Columns = append(ctx.Columns, []ifaces.Column{}) + ctx.VerifierActions = append(ctx.VerifierActions, []wizard.VerifierAction{}) + ctx.Coins = append(ctx.Coins, []coin.Info{}) + ctx.FsHooks = append(ctx.FsHooks, []wizard.VerifierAction{}) + + for _, colName := range comp.Columns.AllKeysAt(round) { + + // filter the columns by status + var ( + status = comp.Columns.Status(colName) + col = comp.Columns.GetHandle(colName) + ) + + if !status.IsPublic() { + // the column is not public so it is not part of the proof + continue + } + + if status == column.VerifyingKey { + // these are constant columns + continue + } + + ctx.FirstRound = min(ctx.FirstRound, round) + ctx.Columns[round] = append(ctx.Columns[round], col) + comp.Columns.IgnoreButKeepInProverTranscript(colName) + } + + for _, qName := range comp.QueriesParams.AllKeysAt(round) { + + if comp.QueriesParams.IsSkippedFromVerifierTranscript(qName) { + continue + } + + // Not that we do not filter the already compiled queries + qInfo := comp.QueriesParams.Data(qName) + ctx.QueryParams[round] = append(ctx.QueryParams[round], qInfo) + comp.QueriesParams.MarkAsSkippedFromVerifierTranscript(qName) + } + + for _, cname := range comp.Coins.AllKeysAt(round) { + + if comp.Coins.IsSkippedFromVerifierTranscript(cname) { + continue + } + + coin := comp.Coins.Data(cname) + ctx.Coins[round] = append(ctx.Coins[round], coin) + comp.Coins.MarkAsSkippedFromVerifierTranscript(cname) + } + + verifierActions := comp.SubVerifiers.Inner() + + for i := range verifierActions[round] { + + va := verifierActions[round][i] + if va.IsSkipped() { + continue + } + + ctx.VerifierActions[round] = append(ctx.VerifierActions[round], va) + va.Skip() + } + + if comp.FiatShamirHooks.Len() > round { + resetFs := comp.FiatShamirHooks.Inner()[round] + for i := range resetFs { + + fsHook := resetFs[i] + if fsHook.IsSkipped() { + continue + } + + ctx.FsHooks[round] = append(ctx.VerifierActions[round], fsHook) + fsHook.Skip() + } + } + } + + comp.QueriesParams.MarkAsSkippedFromProverTranscript(polyQuery.QueryID) + + ctx.PcsCtx.IsSelfrecursed = true + + pcsCtxReplacement := *ctx.PcsCtx + pcsCtxReplacement.Items.MerkleRoots = make([]ifaces.Column, len(pcsCtxReplacement.Items.MerkleRoots)) + + for i := range pcsCtxReplacement.Items.MerkleRoots { + + if ctx.PcsCtx.Items.MerkleRoots[i] == nil { + continue + } + + ctx.NonEmptyMerkleRootPositions = append(ctx.NonEmptyMerkleRootPositions, i) + pcsCtxReplacement.Items.MerkleRoots[i] = comp.InsertProof( + ctx.LastRound, + ctx.PcsCtx.Items.MerkleRoots[i].GetColID()+"_REPLACEMENT", + 1, + ) + } + + ctx.MerkleRootsReplacement = pcsCtxReplacement.Items.MerkleRoots + + comp.PcsCtxs = &pcsCtxReplacement + newPolyQuery := comp.InsertUnivariate( + lastRound, + polyQuery.QueryID+"_REPLACEMENT", + polyQuery.Pols, + ) + + comp.QueriesParams.MarkAsIgnored(newPolyQuery.QueryID) + + ctx.PolyQueryReplacement = newPolyQuery + pcsCtxReplacement.Query = ctx.PolyQueryReplacement + + return ctx +} diff --git a/prover/protocol/compiler/fullrecursion/full_recursion_test.go b/prover/protocol/compiler/fullrecursion/full_recursion_test.go new file mode 100644 index 000000000..5fd62cd43 --- /dev/null +++ b/prover/protocol/compiler/fullrecursion/full_recursion_test.go @@ -0,0 +1,105 @@ +//go:build !fuzzlight + +package fullrecursion_test + +import ( + "fmt" + "testing" + + "github.com/consensys/linea-monorepo/prover/crypto/ringsis" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/fullrecursion" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/univariates" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/sirupsen/logrus" +) + +func TestLookup(t *testing.T) { + + logrus.SetLevel(logrus.FatalLevel) + + define := func(bui *wizard.Builder) { + + var ( + a = bui.RegisterCommit("A", 8) + b = bui.RegisterCommit("B", 8) + ) + + bui.Inclusion("Q", []ifaces.Column{a}, []ifaces.Column{b}) + } + + prove := func(run *wizard.ProverRuntime) { + run.AssignColumn("A", smartvectors.ForTest(1, 2, 3, 4, 5, 6, 7, 8)) + run.AssignColumn("B", smartvectors.ForTest(1, 2, 3, 4, 5, 6, 7, 8)) + } + + suites := [][]func(*wizard.CompiledIOP){ + { + lookup.CompileLogDerivative, + localcs.Compile, + globalcs.Compile, + univariates.CompileLocalOpening, + univariates.Naturalize, + univariates.MultiPointToSinglePoint(8), + vortex.Compile(2, vortex.ForceNumOpenedColumns(4), vortex.WithSISParams(&ringsis.StdParams)), + fullrecursion.FullRecursion(true), + dummy.CompileAtProverLvl, + }, + { + lookup.CompileLogDerivative, + localcs.Compile, + globalcs.Compile, + univariates.CompileLocalOpening, + univariates.Naturalize, + univariates.MultiPointToSinglePoint(8), + vortex.Compile(2, vortex.ForceNumOpenedColumns(4), vortex.WithSISParams(&ringsis.StdParams)), + fullrecursion.FullRecursion(true), + mimc.CompileMiMC, + specialqueries.RangeProof, + lookup.CompileLogDerivative, + specialqueries.CompileFixedPermutations, + permutation.CompileGrandProduct, + innerproduct.Compile, + sticker.Sticker(1<<8, 1<<16), + splitter.SplitColumns(1 << 16), + localcs.Compile, + globalcs.Compile, + univariates.CompileLocalOpening, + univariates.Naturalize, + univariates.MultiPointToSinglePoint(1 << 16), + vortex.Compile(2, vortex.ForceNumOpenedColumns(4), vortex.WithSISParams(&ringsis.StdParams)), + fullrecursion.FullRecursion(true), + dummy.CompileAtProverLvl, + }, + } + + for i, s := range suites { + + t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) { + + comp := wizard.Compile( + define, + s..., + ) + + proof := wizard.Prove(comp, prove) + + if err := wizard.Verify(comp, proof); err != nil { + t.Fatalf("verifier failed: %v", err) + } + }) + + } +} diff --git a/prover/protocol/compiler/globalcs/compile.go b/prover/protocol/compiler/globalcs/compile.go index 11dc45494..bf7f6b436 100644 --- a/prover/protocol/compiler/globalcs/compile.go +++ b/prover/protocol/compiler/globalcs/compile.go @@ -41,7 +41,7 @@ func Compile(comp *wizard.CompiledIOP) { comp.RegisterProverAction(quotientRound, "ientCtx) comp.RegisterProverAction(evaluationRound, evaluationProver(evaluationCtx)) - comp.RegisterVerifierAction(evaluationRound, evaluationVerifier(evaluationCtx)) + comp.RegisterVerifierAction(evaluationRound, &evaluationVerifier{evaluationCtx: evaluationCtx}) } diff --git a/prover/protocol/compiler/globalcs/evaluation.go b/prover/protocol/compiler/globalcs/evaluation.go index 4a9a03eeb..b612b78ec 100644 --- a/prover/protocol/compiler/globalcs/evaluation.go +++ b/prover/protocol/compiler/globalcs/evaluation.go @@ -37,7 +37,10 @@ type evaluationProver evaluationCtx // evaluationVerifier wraps [evaluationCtx] to implement the [wizard.VerifierAction] // interface. -type evaluationVerifier evaluationCtx +type evaluationVerifier struct { + evaluationCtx + skipped bool +} // declareUnivariateQueries declares the univariate queries over all the quotient // shares, making sure that the shares needing to be evaluated over the same @@ -162,7 +165,7 @@ func (pa evaluationProver) Run(run *wizard.ProverRuntime) { } // Run evaluate the constraint and checks that -func (ctx evaluationVerifier) Run(run *wizard.VerifierRuntime) error { +func (ctx *evaluationVerifier) Run(run *wizard.VerifierRuntime) error { var ( // Will be assigned to "X", the random point at which we check the constraint. @@ -236,7 +239,7 @@ func (ctx evaluationVerifier) Run(run *wizard.VerifierRuntime) error { } // Verifier step, evaluate the constraint and checks that -func (ctx evaluationVerifier) RunGnark(api frontend.API, c *wizard.WizardVerifierCircuit) { +func (ctx *evaluationVerifier) RunGnark(api frontend.API, c *wizard.WizardVerifierCircuit) { // Will be assigned to "X", the random point at which we check the constraint. r := c.GetRandomCoinField(ctx.EvalCoin.Name) @@ -463,3 +466,11 @@ func (ctx evaluationVerifier) recombineQuotientSharesEvaluationGnark(api fronten return recombinedYs } + +func (ctx *evaluationVerifier) Skip() { + ctx.skipped = true +} + +func (ctx *evaluationVerifier) IsSkipped() bool { + return ctx.skipped +} diff --git a/prover/protocol/compiler/innerproduct/verifier.go b/prover/protocol/compiler/innerproduct/verifier.go index ec01a9016..accc5acd0 100644 --- a/prover/protocol/compiler/innerproduct/verifier.go +++ b/prover/protocol/compiler/innerproduct/verifier.go @@ -20,6 +20,7 @@ type verifierForSize struct { SummationOpening query.LocalOpening // BatchOpening is the challenge used for the linear combination BatchOpening coin.Info + skipped bool } // Run implements [wizard.VerifierAction] @@ -87,3 +88,11 @@ func (v *verifierForSize) RunGnark(api frontend.API, run *wizard.WizardVerifierC api.AssertIsEqual(expected, actual) } + +func (v *verifierForSize) Skip() { + v.skipped = true +} + +func (v *verifierForSize) IsSkipped() bool { + return v.skipped +} diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 6b4f8fea6..8bef9261e 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -36,7 +36,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { zCatalog = map[[2]int]*zCtx{} zEntries = [][2]int{} // verifier actions - va = finalEvaluationCheck{} + va = &finalEvaluationCheck{} ) // Skip the compilation phase if no lookup constraint is being used. Otherwise @@ -116,7 +116,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { } } - comp.RegisterVerifierAction(lastRound, &va) + comp.RegisterVerifierAction(lastRound, va) } // captureLookupTables inspects comp and look for Inclusion queries that are not diff --git a/prover/protocol/compiler/lookup/verifier.go b/prover/protocol/compiler/lookup/verifier.go index b5179574b..be3219d08 100644 --- a/prover/protocol/compiler/lookup/verifier.go +++ b/prover/protocol/compiler/lookup/verifier.go @@ -21,6 +21,7 @@ type finalEvaluationCheck struct { Name string // ZOpenings lists all the openings of all the zCtx ZOpenings []query.LocalOpening + skipped bool } // Run implements the [wizard.VerifierAction] @@ -54,3 +55,11 @@ func (f *finalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVeri api.AssertIsEqual(zSum, 0) } + +func (f *finalEvaluationCheck) Skip() { + f.skipped = true +} + +func (f *finalEvaluationCheck) IsSkipped() bool { + return f.skipped +} diff --git a/prover/protocol/compiler/permutation/compiler.go b/prover/protocol/compiler/permutation/compiler.go index 1c8320dba..7460ce9ea 100644 --- a/prover/protocol/compiler/permutation/compiler.go +++ b/prover/protocol/compiler/permutation/compiler.go @@ -49,7 +49,7 @@ func CompileGrandProduct(comp *wizard.CompiledIOP) { for round := range allProverActions { if len(allProverActions[round]) > 0 { comp.RegisterProverAction(round, allProverActions[round]) - comp.RegisterVerifierAction(round, VerifierCtx(allProverActions[round])) + comp.RegisterVerifierAction(round, &VerifierCtx{Ctxs: allProverActions[round]}) } } diff --git a/prover/protocol/compiler/permutation/verifier.go b/prover/protocol/compiler/permutation/verifier.go index 1fde7e51c..99924d329 100644 --- a/prover/protocol/compiler/permutation/verifier.go +++ b/prover/protocol/compiler/permutation/verifier.go @@ -11,15 +11,18 @@ import ( // The verifier gets all the query openings and multiple them together and // expect them to be one. It is represented by an array of ZCtx holding for // the same round. (we have the guarantee that they come from the same query). -type VerifierCtx []*ZCtx +type VerifierCtx struct { + Ctxs []*ZCtx + skipped bool +} // Run implements the [wizard.VerifierAction] interface and checks that the // product of the products given by the ZCtx is equal to one. -func (v VerifierCtx) Run(run *wizard.VerifierRuntime) error { +func (v *VerifierCtx) Run(run *wizard.VerifierRuntime) error { mustBeOne := field.One() - for _, zCtx := range v { + for _, zCtx := range v.Ctxs { for _, opening := range zCtx.ZOpenings { y := run.GetLocalPointEvalParams(opening.ID).Y mustBeOne.Mul(&mustBeOne, &y) @@ -35,11 +38,11 @@ func (v VerifierCtx) Run(run *wizard.VerifierRuntime) error { // Run implements the [wizard.VerifierAction] interface and is as // [VerifierCtx.Run] but in the context of a gnark circuit. -func (v VerifierCtx) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (v *VerifierCtx) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { mustBeOne := frontend.Variable(1) - for _, zCtx := range v { + for _, zCtx := range v.Ctxs { for _, opening := range zCtx.ZOpenings { y := run.GetLocalPointEvalParams(opening.ID).Y mustBeOne = api.Mul(mustBeOne, y) @@ -48,3 +51,11 @@ func (v VerifierCtx) RunGnark(api frontend.API, run *wizard.WizardVerifierCircui api.AssertIsEqual(mustBeOne, frontend.Variable(1)) } + +func (v *VerifierCtx) Skip() { + v.skipped = true +} + +func (v *VerifierCtx) IsSkipped() bool { + return v.skipped +} diff --git a/prover/protocol/compiler/selfrecursion/context.go b/prover/protocol/compiler/selfrecursion/context.go index 13156815b..5226926f9 100644 --- a/prover/protocol/compiler/selfrecursion/context.go +++ b/prover/protocol/compiler/selfrecursion/context.go @@ -284,10 +284,10 @@ func NewSelfRecursionCxt(comp *wizard.CompiledIOP) SelfRecursionCtx { func assertVortexCompiled(comp *wizard.CompiledIOP) *vortex.Ctx { // When we compiled using Vortex, we annotated the compiledIOP // that the current protocol was a result of the - ctx := comp.CryptographicCompilerCtx + ctx := comp.PcsCtxs // Take ownership of the vortex context - comp.CryptographicCompilerCtx = nil + comp.PcsCtxs = nil // Check for non-nilness if ctx == nil { diff --git a/prover/protocol/compiler/splitter/sticker/sticker.go b/prover/protocol/compiler/splitter/sticker/sticker.go index 84b87aacc..e90ec7ff4 100644 --- a/prover/protocol/compiler/splitter/sticker/sticker.go +++ b/prover/protocol/compiler/splitter/sticker/sticker.go @@ -405,7 +405,7 @@ func (ctx *stickContext) compileFixedEvaluation() { // Filters out only the q, ok := ctx.comp.QueriesParams.Data(qName).(query.LocalOpening) if !ok { - utils.Panic("got an uncompilable query %v", qName) + utils.Panic("got an uncompilable query name=%v type=%T", qName, q) } // Assumption, the query is not over an interleaved column diff --git a/prover/protocol/compiler/vortex/compiler.go b/prover/protocol/compiler/vortex/compiler.go index 26c3c7868..7aa626bd3 100644 --- a/prover/protocol/compiler/vortex/compiler.go +++ b/prover/protocol/compiler/vortex/compiler.go @@ -61,7 +61,7 @@ func Compile(blowUpFactor int, options ...VortexOp) func(*wizard.CompiledIOP) { lastRound := comp.NumRounds() - 1 // Stores a pointer to the cryptographic compiler of Vortex - comp.CryptographicCompilerCtx = &ctx + comp.PcsCtxs = &ctx // Converts the precomputed as verifying key (e.g. send // them to the verifier) in the offline phase if the @@ -86,6 +86,10 @@ func Compile(blowUpFactor int, options ...VortexOp) func(*wizard.CompiledIOP) { // Registers the prover and verifier steps comp.SubProvers.AppendToInner(lastRound+1, ctx.ComputeLinearComb) comp.SubProvers.AppendToInner(lastRound+2, ctx.OpenSelectedColumns) + // This is separated from GnarkVerify because, when doing full-recursion + // , we want to recurse this verifier step but not [ctx.Verify] which is + // already handled by the self-recursion mechanism. + comp.InsertVerifier(lastRound, ctx.explicitPublicEvaluation, ctx.gnarkExplicitPublicEvaluation) comp.InsertVerifier(lastRound+2, ctx.Verify, ctx.GnarkVerify) } } diff --git a/prover/protocol/compiler/vortex/gnark_verifier.go b/prover/protocol/compiler/vortex/gnark_verifier.go index cf217bcec..430a15604 100644 --- a/prover/protocol/compiler/vortex/gnark_verifier.go +++ b/prover/protocol/compiler/vortex/gnark_verifier.go @@ -15,8 +15,6 @@ import ( ) func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) { - // Evaluate explicitly the public columns - ctx.gnarkExplicitPublicEvaluation(api, vr) // The skip verification flag may be on, if the current vortex // context get self-recursed. In this case, the verifier does @@ -64,7 +62,7 @@ func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) // function that will defer the hashing to gkr factoryHasherFunc := func(_ frontend.API) (hash.FieldHasher, error) { h := vr.HasherFactory.NewHasher() - return &h, nil + return h, nil } packedMProofs := vr.GetColumn(ctx.MerkleProofName()) @@ -93,7 +91,7 @@ func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) } // returns the Ys as a vector -func (ctx *Ctx) gnarkGetYs(api frontend.API, vr *wizard.WizardVerifierCircuit) (ys [][]frontend.Variable) { +func (ctx *Ctx) gnarkGetYs(_ frontend.API, vr *wizard.WizardVerifierCircuit) (ys [][]frontend.Variable) { query := ctx.Query params := vr.GetUnivariateParams(ctx.Query.QueryID) diff --git a/prover/protocol/compiler/vortex/verifier.go b/prover/protocol/compiler/vortex/verifier.go index 50ffc5205..1b27c1779 100644 --- a/prover/protocol/compiler/vortex/verifier.go +++ b/prover/protocol/compiler/vortex/verifier.go @@ -15,11 +15,6 @@ import ( func (ctx *Ctx) Verify(vr *wizard.VerifierRuntime) error { - // Evaluate explicitly the public columns - if err := ctx.explicitPublicEvaluation(vr); err != nil { - return err - } - // The skip verification flag may be on, if the current vortex // context get self-recursed. In this case, the verifier does // not need to do anything diff --git a/prover/protocol/dedicated/plonk/alignment.go b/prover/protocol/dedicated/plonk/alignment.go index 975667373..59ef0deae 100644 --- a/prover/protocol/dedicated/plonk/alignment.go +++ b/prover/protocol/dedicated/plonk/alignment.go @@ -171,6 +171,8 @@ func (ci *CircuitAlignmentInput) Assign(run *wizard.ProverRuntime, i int) (priva return ci.witnesses[i], ci.witnesses[i], nil } +// NumEffWitnesses returns the effective number of Plonk witnesses that are +// collected from the assignment of the AlignmentModule. func (ci *CircuitAlignmentInput) NumEffWitnesses(run *wizard.ProverRuntime) int { ci.prepareWitnesses(run) return ci.numEffWitnesses @@ -258,6 +260,9 @@ func DefineAlignment(comp *wizard.CompiledIOP, toAlign *CircuitAlignmentInput) * return res } +// csIsActive adds the cosntraints ensuring that the [Alignment.IsActive] column +// is well-formed. Namely, that this is a sequence of 1s followed by a sequence +// of 0s. func (a *Alignment) csIsActive(comp *wizard.CompiledIOP) { // IsActive is binary column comp.InsertGlobal(a.Round, ifaces.QueryIDf("%v_IS_ACTIVE_BINARY", a.Name), symbolic.Mul(a.IsActive, symbolic.Sub(a.IsActive, 1))) @@ -265,10 +270,16 @@ func (a *Alignment) csIsActive(comp *wizard.CompiledIOP) { comp.InsertGlobal(a.Round, ifaces.QueryIDf("%v_IS_ACTIVE_SWITCH", a.Name), symbolic.Sub(a.IsActive, symbolic.Mul(a.IsActive, column.Shift(a.IsActive, -1)))) } +// csProjection ensures the data in the [Alignment.Data] column is the same as +// the data provided by the [Alignment.CircuitInput]. func (a *Alignment) csProjection(comp *wizard.CompiledIOP) { projection.InsertProjection(comp, ifaces.QueryIDf("%v_PROJECTION", a.Name), []ifaces.Column{a.DataToCircuit}, []ifaces.Column{a.CircuitInput}, a.DataToCircuitMask, a.ActualCircuitInputMask) } +// csProjectionSelector constraints that the projection selection +// [Alignment.ActualCircuitInputMask] is well-formed. This ensures that the +// imported data are correctly imported "in-front" of the public inputs of the +// Plonk. func (a *Alignment) csProjectionSelector(comp *wizard.CompiledIOP) { // ACTUAL_PI_MASK = IS_ACTIVE * STATIC_PI_MASK comp.InsertGlobal(a.Round, ifaces.QueryIDf("%v_ACTUAL_SUBSET", a.Name), symbolic.Sub(a.ActualCircuitInputMask, symbolic.Mul(a.IsActive, a.FullCircuitInputMask))) @@ -281,6 +292,8 @@ func (a *Alignment) Assign(run *wizard.ProverRuntime) { a.assignCircMaskOpenings(run) } +// assignMasks assigns the [Alignment.IsActive] and the [Alignment.ActualCircuitInputMask] +// into `run`. func (a *Alignment) assignMasks(run *wizard.ProverRuntime) { // we want to assign IS_ACTIVE and ACTUAL_MASK columns. We can construct // them at the same time from the precomputed mask and selector. @@ -320,7 +333,7 @@ func (a *Alignment) assignMasks(run *wizard.ProverRuntime) { run.AssignColumn(a.ActualCircuitInputMask.GetColID(), smartvectors.NewRegular(actualCircMaskAssignment)) } -// assignCircMaskOpenings assigns the openings queries over [actualCircMaskAssignment] +// assignCircMaskOpenings assigns the openings queries over the actualCircMaskAssignment func (a *Alignment) assignCircMaskOpenings(run *wizard.ProverRuntime) { for i := range a.circMaskOpenings { v := a.circMaskOpenings[i].Pol.GetColAssignmentAt(run, 0) @@ -328,7 +341,8 @@ func (a *Alignment) assignCircMaskOpenings(run *wizard.ProverRuntime) { } } -// getCircuitMaskValue returns the +// getCircuitMaskValue returns the static assignment of the precomputed columns +// to be assigned to [Alignment.FullCircuitInputMask]. func getCircuitMaskValue(nbPublicInputPerCircuit, nbCircuitInstance int) smartvectors.SmartVector { var ( @@ -345,7 +359,8 @@ func getCircuitMaskValue(nbPublicInputPerCircuit, nbCircuitInstance int) smartve return smartvectors.NewRegular(maskValue) } -// check the activators are well-set w.r.t to the circuit mask column +// checkActivators adds the constraints checking the activators are well-set w.r.t +// to the circuit mask column. See [compilationCtx.Columns.Activators]. func (ci *Alignment) checkActivators(comp *wizard.CompiledIOP) { var ( @@ -366,12 +381,17 @@ func (ci *Alignment) checkActivators(comp *wizard.CompiledIOP) { ci.circMaskOpenings = openings - comp.RegisterVerifierAction(ci.Round, checkActivatorAndMask(*ci)) + comp.RegisterVerifierAction(ci.Round, &checkActivatorAndMask{Alignment: *ci}) } -type checkActivatorAndMask Alignment +// checkActivatorAndMask is an implementation of [wizard.VerifierAction] and is +// used to embody the verifier checks added by [checkActivators]. +type checkActivatorAndMask struct { + Alignment + skipped bool +} -func (c checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error { +func (c *checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error { for i := range c.circMaskOpenings { var ( localOpening = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID) @@ -390,7 +410,7 @@ func (c checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error { return nil } -func (c checkActivatorAndMask) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (c *checkActivatorAndMask) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { for i := range c.circMaskOpenings { var ( valOpened = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID).Y @@ -400,3 +420,11 @@ func (c checkActivatorAndMask) RunGnark(api frontend.API, run *wizard.WizardVeri api.AssertIsEqual(valOpened, valActiv) } } + +func (c *checkActivatorAndMask) Skip() { + c.skipped = true +} + +func (c *checkActivatorAndMask) IsSkipped() bool { + return c.skipped +} diff --git a/prover/protocol/dedicated/plonk/compile.go b/prover/protocol/dedicated/plonk/compile.go index ae0fcccc2..3e9182fbe 100644 --- a/prover/protocol/dedicated/plonk/compile.go +++ b/prover/protocol/dedicated/plonk/compile.go @@ -68,7 +68,7 @@ func PlonkCheck( comp.RegisterProverAction(round+1, lroCommitProverAction{compilationCtx: ctx, proverStateLock: &sync.Mutex{}}) } - comp.RegisterVerifierAction(round, checkingActivators(ctx.Columns.Activators)) + comp.RegisterVerifierAction(round, &checkingActivators{Cols: ctx.Columns.Activators}) return ctx } @@ -299,20 +299,23 @@ func (ctx *compilationCtx) addCopyConstraint() { // checkingActivators implements the [wizard.VerifierAction] interface and // checks that the [Activators] columns are correctly assigned -type checkingActivators []ifaces.Column +type checkingActivators struct { + Cols []ifaces.Column + skipped bool +} -var _ wizard.VerifierAction = checkingActivators{} +var _ wizard.VerifierAction = &checkingActivators{} -func (ca checkingActivators) Run(run *wizard.VerifierRuntime) error { - for i := range ca { +func (ca *checkingActivators) Run(run *wizard.VerifierRuntime) error { + for i := range ca.Cols { - curr := ca[i].GetColAssignmentAt(run, 0) + curr := ca.Cols[i].GetColAssignmentAt(run, 0) if !curr.IsOne() && !curr.IsZero() { return fmt.Errorf("error the activators must be 0 or 1") } - if i+1 < len(ca) { - next := ca[i+1].GetColAssignmentAt(run, 0) + if i+1 < len(ca.Cols) { + next := ca.Cols[i+1].GetColAssignmentAt(run, 0) if curr.IsZero() && !next.IsZero() { return fmt.Errorf("the activators must never go from 0 to 1") } @@ -322,15 +325,23 @@ func (ca checkingActivators) Run(run *wizard.VerifierRuntime) error { return nil } -func (ca checkingActivators) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { - for i := range ca { +func (ca *checkingActivators) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { + for i := range ca.Cols { - curr := ca[i].GetColAssignmentGnarkAt(run, 0) + curr := ca.Cols[i].GetColAssignmentGnarkAt(run, 0) api.AssertIsBoolean(curr) - if i+1 < len(ca) { - next := ca[i+1].GetColAssignmentGnarkAt(run, 0) + if i+1 < len(ca.Cols) { + next := ca.Cols[i+1].GetColAssignmentGnarkAt(run, 0) api.AssertIsEqual(next, api.Mul(curr, next)) } } } + +func (ca *checkingActivators) Skip() { + ca.skipped = true +} + +func (ca *checkingActivators) IsSkipped() bool { + return ca.skipped +} diff --git a/prover/protocol/dedicated/projection/projection.go b/prover/protocol/dedicated/projection/projection.go index efa137fb4..149b358e7 100644 --- a/prover/protocol/dedicated/projection/projection.go +++ b/prover/protocol/dedicated/projection/projection.go @@ -64,6 +64,7 @@ type projectionProverAction struct { type projectionVerifierAction struct { Name ifaces.QueryID HornerA0, HornerB0 query.LocalOpening + skipped bool } // InsertProjection applies a projection query between sets (columnsA, filterA) @@ -212,7 +213,7 @@ func InsertProjection( pa.HornerB0 = comp.InsertLocalOpening(round, ifaces.QueryIDf("%v_HORNER_B0", queryName), pa.HornerB) comp.RegisterProverAction(round, pa) - comp.RegisterVerifierAction(round, projectionVerifierAction{HornerA0: pa.HornerA0, HornerB0: pa.HornerB0, Name: queryName}) + comp.RegisterVerifierAction(round, &projectionVerifierAction{HornerA0: pa.HornerA0, HornerB0: pa.HornerB0, Name: queryName}) } // Run implements the [wizard.ProverAction] interface. @@ -314,7 +315,7 @@ func (pa projectionProverAction) Run(run *wizard.ProverRuntime) { } // Run implements the [wizard.VerifierAction] interface. -func (va projectionVerifierAction) Run(run *wizard.VerifierRuntime) error { +func (va *projectionVerifierAction) Run(run *wizard.VerifierRuntime) error { var ( a = run.GetLocalPointEvalParams(va.HornerA0.ID).Y @@ -329,7 +330,7 @@ func (va projectionVerifierAction) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] interface. -func (va projectionVerifierAction) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (va *projectionVerifierAction) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { var ( a = run.GetLocalPointEvalParams(va.HornerA0.ID).Y @@ -339,6 +340,14 @@ func (va projectionVerifierAction) RunGnark(api frontend.API, run *wizard.Wizard api.AssertIsEqual(a, b) } +func (va *projectionVerifierAction) Skip() { + va.skipped = true +} + +func (va *projectionVerifierAction) IsSkipped() bool { + return va.skipped +} + // cmptHorner computes a random Horner accumulation of the filtered elements // starting from the last entry down to the first entry. The final value is // stored in the last entry of the returned slice. diff --git a/prover/protocol/query/local_opening.go b/prover/protocol/query/local_opening.go index 463948c04..bb77b5376 100644 --- a/prover/protocol/query/local_opening.go +++ b/prover/protocol/query/local_opening.go @@ -6,7 +6,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -30,20 +29,6 @@ func (lop LocalOpeningParams) UpdateFS(fs *fiatshamir.State) { // Constructs a new local opening query func NewLocalOpening(id ifaces.QueryID, pol ifaces.Column) LocalOpening { - // For simplicity, we enforce the `pol` to be either Natural or Shifted(Natural) - // Allegedly, this does not block any-case - switch h := pol.(type) { - case column.Natural: - // allowed - case column.Shifted: - if _, ok := h.Parent.(column.Natural); !ok { - utils.Panic("Unsupported handle should only be a shifted or a natural %v", pol) - } - // allowed - default: - utils.Panic("Unsupported handle should only be a shifted %v", pol) - } - if len(pol.GetColID()) == 0 { utils.Panic("Assigned a polynomial name with an empty length") } diff --git a/prover/protocol/wizard/actions.go b/prover/protocol/wizard/actions.go index 630e51415..398017fba 100644 --- a/prover/protocol/wizard/actions.go +++ b/prover/protocol/wizard/actions.go @@ -14,6 +14,10 @@ type ProverAction interface { // protocol. Usually, this is used to represent verifier checks. They can be // registered via [CompiledIOP.RegisterVerifierAction]. type VerifierAction interface { + // Skip indicates that the verifier action can be skipped + Skip() + // IsSkipped returns whether the current VerifierAction is skipped + IsSkipped() bool // Run executes the VerifierAction over a [VerifierRuntime] it returns an // error. Run(*VerifierRuntime) error @@ -21,3 +25,26 @@ type VerifierAction interface { // error the function enforces the passing of the verifier's checks. RunGnark(frontend.API, *WizardVerifierCircuit) } + +// genVerifierAction represents a verifier action represented by closures +type genVerifierAction struct { + skipped bool + run func(*VerifierRuntime) error + runGnark func(frontend.API, *WizardVerifierCircuit) +} + +func (gva *genVerifierAction) Run(run *VerifierRuntime) error { + return gva.run(run) +} + +func (gva *genVerifierAction) RunGnark(api frontend.API, run *WizardVerifierCircuit) { + gva.runGnark(api, run) +} + +func (gva *genVerifierAction) Skip() { + gva.skipped = true +} + +func (gva *genVerifierAction) IsSkipped() bool { + return gva.skipped +} diff --git a/prover/protocol/wizard/builder.go b/prover/protocol/wizard/builder.go index 9bab4cfb6..baf08b2ee 100644 --- a/prover/protocol/wizard/builder.go +++ b/prover/protocol/wizard/builder.go @@ -294,16 +294,8 @@ func (b *Builder) equalizeRounds(numRounds int) { /* Check and reserve for the verifiers */ - if comp.subVerifiers.Len() > numRounds { - utils.Panic("Bug : numRounds is %v but %v rounds are registered for the verifier. %v", numRounds, comp.subVerifiers.Len(), helpMsg) + if comp.SubVerifiers.Len() > numRounds { + utils.Panic("Bug : numRounds is %v but %v rounds are registered for the verifier. %v", numRounds, comp.SubVerifiers.Len(), helpMsg) } - comp.subVerifiers.Reserve(numRounds) - - /* - Check and reserve for the gnark verifiers - */ - if comp.gnarkSubVerifiers.Len() > numRounds { - utils.Panic("Bug : numRounds is %v but %v rounds are registered for the gnark verifier. %v", numRounds, comp.gnarkSubVerifiers.Len(), helpMsg) - } - comp.gnarkSubVerifiers.Reserve(numRounds) + comp.SubVerifiers.Reserve(numRounds) } diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index 6cf5c37a3..eab3ecc2b 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -71,28 +71,26 @@ type CompiledIOP struct { // manual checks that the verifier has to perform. This is useful when a check // cannot be represented in term of query but, when possible, queries should // always be preferred to express a relation that the witness must satisfy. - subVerifiers collection.VecVec[VerifierStep] + SubVerifiers collection.VecVec[VerifierAction] - // gnarkSubVerifiers does the same as [gnarkSubVerifiers] but in a gnark - // circuit. Whenever, the user add a subVerifier function into the compiled - // IOP, he should also provide an equivalent gnark function that does - // exactly the same thing, but in a gnark circuit. This used when - // instantiating a gnark verifier for the sub-protocol. - gnarkSubVerifiers collection.VecVec[GnarkVerifierStep] + // FiatShamirHooks is an action that is run during the FS sampling. Compared + // to a normal verifier action it has the possibility to interact with the + // Fiat-Shamir state. + FiatShamirHooks collection.VecVec[VerifierAction] // Precomputed stores the assignments of all the Precomputed and VerifierKey // polynomials. It is assigned directly when registering a precomputed // column. Precomputed collection.Mapping[ifaces.ColID, ifaces.ColAssignment] - // CryptographicCompilerCtx stores the compilation context of the last used + // PcsCtxs stores the compilation context of the last used // cryptographic compiler. Specifically, it is aimed to store the last // Vortex compilation context (see [github.com/consensys/linea-monorepo/prover/protocol/compiler]) that was used. And // its purpose is to provide the Vortex context to the self-recursion // compilation context; see [github.com/consensys/linea-monorepo/prover/protocol/compiler/selfrecursion]. This allows // the self-recursion context to learn about the columns to use and the // Vortex parameters. - CryptographicCompilerCtx any + PcsCtxs any // DummyCompiled that can be set internally by the compilation, when we are // using the [github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy.Compile] compilation step. This steps @@ -120,6 +118,10 @@ type CompiledIOP struct { // // For efficiency reasons, the fiatShamirSetup is derived using SHA2. fiatShamirSetup field.Element + + // FunctionalPublic inputs lists the queries representing a public inputs + // and their identifiers + PublicInputs []PublicInput } // NumRounds returns the total number of prover interactions with the verifier @@ -478,8 +480,10 @@ func (c *CompiledIOP) InsertPublicInput(round int, name ifaces.ColID, size int) // passing `nil` is fine. func (c *CompiledIOP) InsertVerifier(round int, ver VerifierStep, gnarkVer GnarkVerifierStep) { c.assertConsistentRound(round) - c.gnarkSubVerifiers.AppendToInner(round, gnarkVer) - c.subVerifiers.AppendToInner(round, ver) + c.SubVerifiers.AppendToInner(round, &genVerifierAction{ + run: ver, + runGnark: gnarkVer, + }) } // InsertRange registers [query.Range] in the CompiledIOP. Namely, it ensures diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index 82f37c718..40e48fcff 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -5,6 +5,7 @@ import ( "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" "github.com/consensys/linea-monorepo/prover/crypto/mimc/gkrmimc" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" @@ -80,6 +81,11 @@ type WizardVerifierCircuit struct { // hashes but also the MiMC Vortex column hashes that we use for the // last round of the self-recursion. HasherFactory *gkrmimc.HasherFactory `gnark:"-"` + + // FiatShamirHistory tracks the fiat-shamir state at the beginning of every + // round. The first entry is the initial state, the final entry is the final + // state. + FiatShamirHistory [][2][]frontend.Variable `gnark:"-"` } // AllocateWizardCircuit allocates the inner-slices of the verifier struct from a precompiled IOP. It @@ -89,7 +95,7 @@ type WizardVerifierCircuit struct { // the circuit. func AllocateWizardCircuit(comp *CompiledIOP) (*WizardVerifierCircuit, error) { - res := newWizardVerifierCircuit() + res := NewWizardVerifierCircuit() for i, colName := range comp.Columns.AllKeys() { // filter the columns by status @@ -109,9 +115,7 @@ func AllocateWizardCircuit(comp *CompiledIOP) (*WizardVerifierCircuit, error) { size := comp.Columns.GetSize(colName) // Allocates the column in the circuit and indexes it - colID := len(res.Columns) - res.Columns = append(res.Columns, gnarkutil.AllocateSlice(size)) - res.columnsIDs.InsertNew(colName, colID) + res.AllocColumn(colName, size) } /* @@ -129,17 +133,11 @@ func AllocateWizardCircuit(comp *CompiledIOP) (*WizardVerifierCircuit, error) { switch qInfo := qInfoIface.(type) { case query.UnivariateEval: - // Note that nil is the default value for frontend.Variable - res.univariateParamsIDs.InsertNew(qName, len(res.UnivariateParams)) - res.UnivariateParams = append(res.UnivariateParams, qInfo.GnarkAllocate()) + res.AllocUnivariateEval(qName, qInfo) case query.InnerProduct: - // Note that nil is the default value for frontend.Variable - res.innerProductIDs.InsertNew(qName, len(res.InnerProductParams)) - res.InnerProductParams = append(res.InnerProductParams, qInfo.GnarkAllocate()) + res.AllocInnerProduct(qName, qInfo) case query.LocalOpening: - // Note that nil is the default value for frontend.Variable - res.localOpeningIDs.InsertNew(qName, len(res.LocalOpeningParams)) - res.LocalOpeningParams = append(res.LocalOpeningParams, query.GnarkLocalOpeningParams{}) + res.AllocLocalOpening(qName, qInfo) } } @@ -154,67 +152,55 @@ func (c *WizardVerifierCircuit) Verify(api frontend.API) { c.HasherFactory = gkrmimc.NewHasherFactory(api) c.FS = fiatshamir.NewGnarkFiatShamir(api, c.HasherFactory) c.FS.Update(c.Spec.fiatShamirSetup) + c.FiatShamirHistory = make([][2][]frontend.Variable, c.Spec.NumRounds()) c.generateAllRandomCoins(api) - logrus.Tracef("Generated the coins") - - for _, roundSteps := range c.Spec.gnarkSubVerifiers.Inner() { + for _, roundSteps := range c.Spec.SubVerifiers.Inner() { for _, step := range roundSteps { - step(api, c) + if !step.IsSkipped() { + step.RunGnark(api, c) + } } } } // generateAllRandomCoins is as [VerifierRuntime.generateAllRandomCoins]. Note that the function // does create constraints via the hasher factory that is inside of `c.FS`. -func (c *WizardVerifierCircuit) generateAllRandomCoins(_ frontend.API) { +func (c *WizardVerifierCircuit) generateAllRandomCoins(api frontend.API) { for currRound := 0; currRound < c.Spec.NumRounds(); currRound++ { - if currRound > 0 { - /* - Sanity-check : Make sure all issued random coin have been - "consumed" by all the verifiers steps, in the round we are - "closing" - */ - toBeConsumed := c.Spec.Coins.AllKeysAt(currRound - 1) - c.Coins.Exists(toBeConsumed...) - if !c.Spec.DummyCompiled { + initialState := c.FS.State() - // Make sure that all messages have been written and use them - // to update the FS state. Note that we do not need to update - // FS using the last round of the prover because he is always - // the last one to "talk" in the protocol. - toUpdateFS := c.Spec.Columns.AllKeysProofAt(currRound - 1) - for _, msg := range toUpdateFS { - - msgID := c.columnsIDs.MustGet(msg) - msgContent := c.Columns[msgID] - - logrus.Tracef("VERIFIER CIRCUIT : Updating the FS oracle with a message - %v", msg) - c.FS.UpdateVec(msgContent) - } + if currRound > 0 { - toUpdateFS = c.Spec.Columns.AllKeysPublicInputAt(currRound - 1) - for _, msg := range toUpdateFS { + // Make sure that all messages have been written and use them + // to update the FS state. Note that we do not need to update + // FS using the last round of the prover because he is always + // the last one to "talk" in the protocol. + toUpdateFS := c.Spec.Columns.AllKeysProofAt(currRound - 1) + for _, msg := range toUpdateFS { + msgContent := c.GetColumn(msg) + c.FS.UpdateVec(msgContent) + } - msgID := c.columnsIDs.MustGet(msg) - msgContent := c.Columns[msgID] + toUpdateFS = c.Spec.Columns.AllKeysPublicInputAt(currRound - 1) + for _, msg := range toUpdateFS { + msgContent := c.GetColumn(msg) + c.FS.UpdateVec(msgContent) + } - logrus.Tracef("VERIFIER CIRCUIT : Updating the FS oracle with public input - %v", msg) - c.FS.UpdateVec(msgContent) + /* + Also include the prover's allegations for all evaluations + */ + queries := c.Spec.QueriesParams.AllKeysAt(currRound - 1) + for _, qName := range queries { + if c.Spec.QueriesParams.IsSkippedFromVerifierTranscript(qName) { + continue } - /* - Also include the prover's allegations for all evaluations - */ - queries := c.Spec.QueriesParams.AllKeysAt(currRound - 1) - for _, qName := range queries { - // Implicitly, this will panic whenever we start supporting - // a new type of query params - params := c.GetParams(qName) - params.UpdateFS(c.FS) - } + params := c.GetParams(qName) + params.UpdateFS(c.FS) } } @@ -223,8 +209,11 @@ func (c *WizardVerifierCircuit) generateAllRandomCoins(_ frontend.API) { */ toCompute := c.Spec.Coins.AllKeysAt(currRound) for _, coinName := range toCompute { + if c.Spec.Coins.IsSkippedFromVerifierTranscript(coinName) { + continue + } + info := c.Spec.Coins.Data(coinName) - logrus.Tracef("VERIFIER CIRCUIT : Generate a random coin - %v", coinName) switch info.Type { case coin.Field: value := c.FS.RandomField() @@ -234,6 +223,22 @@ func (c *WizardVerifierCircuit) generateAllRandomCoins(_ frontend.API) { c.Coins.InsertNew(coinName, value) } } + + if c.Spec.FiatShamirHooks.Len() > currRound { + fsHooks := c.Spec.FiatShamirHooks.MustGet(currRound) + for i := range fsHooks { + if fsHooks[i].IsSkipped() { + continue + } + + fsHooks[i].RunGnark(api, c) + } + } + + c.FiatShamirHistory[currRound] = [2][]frontend.Variable{ + initialState, + c.FS.State(), + } } } @@ -346,9 +351,9 @@ func (c *WizardVerifierCircuit) GetColumnAt(name ifaces.ColID, pos int) frontend return c.GetColumn(name)[pos] } -// newWizardVerifierCircuit creates an empty wizard verifier circuit. +// NewWizardVerifierCircuit creates an empty wizard verifier circuit. // Initializes the underlying structs and collections. -func newWizardVerifierCircuit() *WizardVerifierCircuit { +func NewWizardVerifierCircuit() *WizardVerifierCircuit { res := &WizardVerifierCircuit{} res.columnsIDs = collection.NewMapping[ifaces.ColID, int]() res.univariateParamsIDs = collection.NewMapping[ifaces.QueryID, int]() @@ -367,7 +372,7 @@ func newWizardVerifierCircuit() *WizardVerifierCircuit { // gnark assignment circuit involving the verification of Wizard proof. func GetWizardVerifierCircuitAssignment(comp *CompiledIOP, proof Proof) *WizardVerifierCircuit { - res := newWizardVerifierCircuit() + res := NewWizardVerifierCircuit() /* Assigns the messages. Note that the iteration order is made @@ -411,17 +416,11 @@ func GetWizardVerifierCircuitAssignment(comp *CompiledIOP, proof Proof) *WizardV switch params := paramsIface.(type) { case query.UnivariateEvalParams: - res.univariateParamsIDs.InsertNew(qName, len(res.UnivariateParams)) - res.UnivariateParams = append(res.UnivariateParams, params.GnarkAssign()) - + res.AssignUnivariateEval(qName, params) case query.InnerProductParams: - res.innerProductIDs.InsertNew(qName, len(res.InnerProductParams)) - res.InnerProductParams = append(res.InnerProductParams, params.GnarkAssign()) - + res.AssignInnerProduct(qName, params) case query.LocalOpeningParams: - res.localOpeningIDs.InsertNew(qName, len(res.LocalOpeningParams)) - res.LocalOpeningParams = append(res.LocalOpeningParams, params.GnarkAssign()) - + res.AssignLocalOpening(qName, params) default: utils.Panic("unknow type %T", params) } @@ -445,3 +444,79 @@ func (c *WizardVerifierCircuit) GetParams(id ifaces.QueryID) ifaces.GnarkQueryPa } panic("unreachable") } + +// AllocColumn inserts a column in the Wizard verifier circuit and is meant +// to be called at allocation time. +func (c *WizardVerifierCircuit) AllocColumn(id ifaces.ColID, size int) []frontend.Variable { + column := make([]frontend.Variable, size) + c.columnsIDs.InsertNew(id, len(c.Columns)) + c.Columns = append(c.Columns, column) + return column +} + +// AssignColumn assigns a column in the Wizard verifier circuit +func (c *WizardVerifierCircuit) AssignColumn(id ifaces.ColID, sv smartvectors.SmartVector) { + column := smartvectors.IntoGnarkAssignment(sv) + c.columnsIDs.InsertNew(id, len(c.Columns)) + c.Columns = append(c.Columns, column) +} + +// AllocUnivariableEval inserts a slot for a univariate query opening in the +// witness of the verifier circuit. +func (c *WizardVerifierCircuit) AllocUnivariateEval(qName ifaces.QueryID, qInfo query.UnivariateEval) { + // Note that nil is the default value for frontend.Variable + c.univariateParamsIDs.InsertNew(qName, len(c.UnivariateParams)) + c.UnivariateParams = append(c.UnivariateParams, qInfo.GnarkAllocate()) +} + +// AllocInnerProduct inserts a slot for an inner-product query opening in the +// witness of the verifier circuit. +func (c *WizardVerifierCircuit) AllocInnerProduct(qName ifaces.QueryID, qInfo query.InnerProduct) { + // Note that nil is the default value for frontend.Variable + c.innerProductIDs.InsertNew(qName, len(c.InnerProductParams)) + c.InnerProductParams = append(c.InnerProductParams, qInfo.GnarkAllocate()) +} + +// AllocLocalOpening inserts a slot for a local position opening in the witness +// of the verifier circuit. +func (c *WizardVerifierCircuit) AllocLocalOpening(qName ifaces.QueryID, qInfo query.LocalOpening) { + // Note that nil is the default value for frontend.Variable + c.localOpeningIDs.InsertNew(qName, len(c.LocalOpeningParams)) + c.LocalOpeningParams = append(c.LocalOpeningParams, query.GnarkLocalOpeningParams{}) +} + +// AssignUnivariableEval inserts a slot for a univariate query opening in the +// witness of the verifier circuit. +func (c *WizardVerifierCircuit) AssignUnivariateEval(qName ifaces.QueryID, params query.UnivariateEvalParams) { + // Note that nil is the default value for frontend.Variable + c.univariateParamsIDs.InsertNew(qName, len(c.UnivariateParams)) + c.UnivariateParams = append(c.UnivariateParams, params.GnarkAssign()) +} + +// AssignInnerProduct inserts a slot for an inner-product query opening in the +// witness of the verifier circuit. +func (c *WizardVerifierCircuit) AssignInnerProduct(qName ifaces.QueryID, params query.InnerProductParams) { + // Note that nil is the default value for frontend.Variable + c.innerProductIDs.InsertNew(qName, len(c.InnerProductParams)) + c.InnerProductParams = append(c.InnerProductParams, params.GnarkAssign()) +} + +// AssignLocalOpening inserts a slot for a local position opening in the witness +// of the verifier circuit. +func (c *WizardVerifierCircuit) AssignLocalOpening(qName ifaces.QueryID, params query.LocalOpeningParams) { + // Note that nil is the default value for frontend.Variable + c.localOpeningIDs.InsertNew(qName, len(c.LocalOpeningParams)) + c.LocalOpeningParams = append(c.LocalOpeningParams, params.GnarkAssign()) +} + +// GetPublicInput returns a public input value from its name +func (c *WizardVerifierCircuit) GetPublicInput(api frontend.API, name string) frontend.Variable { + allPubs := c.Spec.PublicInputs + for i := range allPubs { + if allPubs[i].Name == name { + return allPubs[i].Acc.GetFrontendVariable(api, c) + } + } + utils.Panic("could not find public input nb %v", name) + return field.Element{} +} diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index caa4df4aa..df61fd31b 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/utils" "github.com/consensys/linea-monorepo/prover/utils/collection" - "github.com/sirupsen/logrus" ) // ProverStep represents an operation to be performed by the prover of a @@ -120,6 +119,11 @@ type ProverRuntime struct { // lock is global lock so that the assignment maps are thread safes lock *sync.Mutex + + // FiatShamirHistory tracks the fiat-shamir state at the beginning of every + // round. The first entry is the initial state, the final entry is the final + // state. + FiatShamirHistory [][2][]field.Element } // Prove is the top-level function that runs the Prover on the user's side. It @@ -204,14 +208,20 @@ func (c *CompiledIOP) createProver() ProverRuntime { // Instantiates an empty Assignment (but link it to the CompiledIOP) runtime := ProverRuntime{ - Spec: c, - Columns: collection.NewMapping[ifaces.ColID, ifaces.ColAssignment](), - QueriesParams: collection.NewMapping[ifaces.QueryID, ifaces.QueryParams](), - Coins: collection.NewMapping[coin.Name, interface{}](), - State: collection.NewMapping[string, interface{}](), - FS: fs, - currRound: 0, - lock: &sync.Mutex{}, + Spec: c, + Columns: collection.NewMapping[ifaces.ColID, ifaces.ColAssignment](), + QueriesParams: collection.NewMapping[ifaces.QueryID, ifaces.QueryParams](), + Coins: collection.NewMapping[coin.Name, interface{}](), + State: collection.NewMapping[string, interface{}](), + FS: fs, + currRound: 0, + lock: &sync.Mutex{}, + FiatShamirHistory: make([][2][]field.Element, c.NumRounds()), + } + + runtime.FiatShamirHistory[0] = [2][]field.Element{ + fs.State(), + fs.State(), } // Pass the precomputed polynomials @@ -452,6 +462,8 @@ func (run *ProverRuntime) getRandomCoinGeneric(name coin.Name, requestedType coi // parameters. This makes all the new coins available in the prover runtime. func (run *ProverRuntime) goNextRound() { + initialState := run.FS.State() + /* Make sure all issued random coin have been "consumed" by all the prover steps, in the round we are closing. An error occuring here is more likely @@ -473,11 +485,6 @@ func (run *ProverRuntime) goNextRound() { toBeParametrized := run.Spec.QueriesParams.AllKeysAt(run.currRound) run.QueriesParams.MustExists(toBeParametrized...) - // Counts the transcript size of the round and the number of field - // element generated. - initialTranscriptSize := run.FS.TranscriptSize - initialNumCoinsGenerated := run.FS.NumCoinGenerated - if !run.Spec.DummyCompiled { /* @@ -486,13 +493,11 @@ func (run *ProverRuntime) goNextRound() { FS using the last round of the prover because he is always the last one to "talk" in the protocol. */ - start := run.FS.TranscriptSize - msgsToFS := run.Spec.Columns.AllKeysProofAt(run.currRound) + msgsToFS := run.Spec.Columns.AllKeysProofsOrIgnoredButKeptInProverTranscript(run.currRound) for _, msgName := range msgsToFS { instance := run.GetMessage(msgName) run.FS.UpdateSV(instance) } - logrus.Debugf("Fiat-shamir round %v - %v proof elements in the transcript", run.currRound, run.FS.TranscriptSize-start) /* Make sure that all messages have been written and use them @@ -500,26 +505,26 @@ func (run *ProverRuntime) goNextRound() { FS using the last round of the prover because he is always the last one to "talk" in the protocol. */ - start = run.FS.TranscriptSize msgsToFS = run.Spec.Columns.AllKeysPublicInputAt(run.currRound) for _, msgName := range msgsToFS { instance := run.GetMessage(msgName) run.FS.UpdateSV(instance) } - logrus.Debugf("Fiat-shamir round %v - %v public inputs in the transcript", run.currRound, run.FS.TranscriptSize-start) /* Also include the prover's allegations for all evaluations */ - start = run.FS.TranscriptSize paramsToFS := run.Spec.QueriesParams.AllKeysAt(run.currRound) for _, qName := range paramsToFS { + if run.Spec.QueriesParams.IsSkippedFromProverTranscript(qName) { + continue + } + // Implicitly, this will panic whenever we start supporting // a new type of query params params := run.QueriesParams.MustGet(qName) params.UpdateFS(run.FS) } - logrus.Debugf("Fiat-shamir round %v - %v query params in the transcript", run.currRound, run.FS.TranscriptSize-start) } // Increment the number of rounds @@ -537,9 +542,12 @@ func (run *ProverRuntime) goNextRound() { run.Coins.InsertNew(coin, value) } - logrus.Debugf("Ran Fiat-Shamir for round %v, transcript size %v (field element), generated %v field elements, total-transcript %v, total-generated %v", - run.currRound, run.FS.TranscriptSize-initialTranscriptSize, run.FS.NumCoinGenerated-initialNumCoinsGenerated, run.FS.TranscriptSize, run.FS.NumCoinGenerated, - ) + finalState := run.FS.State() + + run.FiatShamirHistory[run.currRound] = [2][]field.Element{ + initialState, + finalState, + } } // runProverSteps runs all the [ProverStep] specified in the underlying diff --git a/prover/protocol/wizard/public_input.go b/prover/protocol/wizard/public_input.go new file mode 100644 index 000000000..30b8c9328 --- /dev/null +++ b/prover/protocol/wizard/public_input.go @@ -0,0 +1,14 @@ +package wizard + +import "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + +// PublicInput represents a public input in a wizard protocol. Public inputs +// are materialized with a functional identifier and a local opening query. +// The identifier is what ultimately identifies the public input as the query +// may be mutated by compilation (if we use the FullRecursion compiler), therefore +// it would unsafe to use the ID of the query to identify the public input in +// the circuit. +type PublicInput struct { + Name string + Acc ifaces.Accessor +} diff --git a/prover/protocol/wizard/register.go b/prover/protocol/wizard/register.go index 45d84273c..6efdcbdb9 100644 --- a/prover/protocol/wizard/register.go +++ b/prover/protocol/wizard/register.go @@ -5,11 +5,10 @@ import ( "github.com/consensys/linea-monorepo/prover/utils/collection" ) -/* -In a nutshell, an item is an abstract type that -accounts for the fact that CompiledProtocol -registers various things for different rounds -*/ +// ByRoundRegister is a an abstract data-structure used to register the +// [column.Natural], [coin.Info] and [ifaces.Query] etc... Each item is added +// at a particular round. The structure additionally records compilation +// informations about the objects stored in the register. type ByRoundRegister[ID comparable, DATA any] struct { // All the data for each key mapping collection.Mapping[ID, DATA] @@ -19,6 +18,17 @@ type ByRoundRegister[ID comparable, DATA any] struct { byRoundsIndex collection.Mapping[ID, int] // Marks an entry as ignorable (but does not delete it) ignored collection.Set[ID] + // skippedFromVerifierTranscript marks an entry as "skipped from verifier + // transcript from the FS transcript for the verifier. This means that the + // verifier will not use this value. However, the value can still be used + // by the prover. The reason for this field is to work around subtle issues + // while dealing with recursion. + skippedFromVerifierTranscript collection.Set[ID] + // skippedFromProverTranscript marks an entry as "skipped from prover + // transcript" this means that neither the prover nor the verifier will use + // this value to update the transcript. The reason for this field is to work + // around subtle issues while dealing with recursion. + skippedFromProverTranscript collection.Set[ID] } /* @@ -26,10 +36,12 @@ Construct a new round register */ func NewRegister[ID comparable, DATA any]() ByRoundRegister[ID, DATA] { return ByRoundRegister[ID, DATA]{ - mapping: collection.NewMapping[ID, DATA](), - byRounds: collection.NewVecVec[ID](), - byRoundsIndex: collection.NewMapping[ID, int](), - ignored: collection.NewSet[ID](), + mapping: collection.NewMapping[ID, DATA](), + byRounds: collection.NewVecVec[ID](), + byRoundsIndex: collection.NewMapping[ID, int](), + ignored: collection.NewSet[ID](), + skippedFromVerifierTranscript: collection.NewSet[ID](), + skippedFromProverTranscript: collection.NewSet[ID](), } } @@ -166,3 +178,34 @@ func (r *ByRoundRegister[ID, DATA]) IsIgnored(id ID) bool { r.mapping.MustExists(id) return r.ignored.Exists(id) } + +// MarkAsSkippedFromVerifierTranscript marks an entry as skipped from the transcript +// of the verifier. Panic if the key is missing from the register. Returns true if +// the item was already ignored. +func (r *ByRoundRegister[ID, DATA]) MarkAsSkippedFromVerifierTranscript(id ID) bool { + r.mapping.MustExists(id) + return r.skippedFromVerifierTranscript.Insert(id) +} + +// IsSkippedFromVerifierTranscript returns if the entry is skipped from the +// transcript. Panics if the entry is missing from the map. +func (r *ByRoundRegister[ID, DATA]) IsSkippedFromVerifierTranscript(id ID) bool { + r.mapping.MustExists(id) + return r.skippedFromVerifierTranscript.Exists(id) +} + +// MarkAsSkippedFromProverTranscript marks an entry as skipped from the transcript +// of the verifier. Panic if the key is missing from the register. Returns true +// if the item was already ignored. +func (r *ByRoundRegister[ID, DATA]) MarkAsSkippedFromProverTranscript(id ID) bool { + r.mapping.MustExists(id) + r.skippedFromVerifierTranscript.Insert(id) + return r.skippedFromProverTranscript.Insert(id) +} + +// IsSkippedFromProverTranscript returns if the entry is skipped from the +// transcript. Panics if the entry is missing from the map. +func (r *ByRoundRegister[ID, DATA]) IsSkippedFromProverTranscript(id ID) bool { + r.mapping.MustExists(id) + return r.skippedFromProverTranscript.Exists(id) +} diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index 1a355796e..65b8d83bf 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -9,7 +9,6 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/utils" "github.com/consensys/linea-monorepo/prover/utils/collection" - "github.com/sirupsen/logrus" ) // Proof generically represents a proof obtained from the wizard. This object does not @@ -69,6 +68,11 @@ type VerifierRuntime struct { // the verifer end up having different state or the same message being // included a second time. Use it externally at your own risks. FS *fiatshamir.State + + // FiatShamirHistory tracks the fiat-shamir state at the beginning of every + // round. The first entry is the initial state, the final entry is the final + // state. + FiatShamirHistory [][2][]field.Element } // Verify verifies a wizard proof. The caller specifies a [CompiledIOP] that @@ -91,10 +95,12 @@ func Verify(c *CompiledIOP, proof Proof) error { any */ errs := []error{} - for _, roundSteps := range runtime.Spec.subVerifiers.Inner() { + for _, roundSteps := range runtime.Spec.SubVerifiers.Inner() { for _, step := range roundSteps { - if err := step(&runtime); err != nil { - errs = append(errs, err) + if !step.IsSkipped() { + if err := step.Run(&runtime); err != nil { + errs = append(errs, err) + } } } } @@ -117,11 +123,12 @@ func (c *CompiledIOP) createVerifier(proof Proof) VerifierRuntime { Instantiate an empty assigment for the verifier */ runtime := VerifierRuntime{ - Spec: c, - Coins: collection.NewMapping[coin.Name, interface{}](), - Columns: proof.Messages, - QueriesParams: proof.QueriesParams, - FS: fiatshamir.NewMiMCFiatShamir(), + Spec: c, + Coins: collection.NewMapping[coin.Name, interface{}](), + Columns: proof.Messages, + QueriesParams: proof.QueriesParams, + FS: fiatshamir.NewMiMCFiatShamir(), + FiatShamirHistory: make([][2][]field.Element, c.NumRounds()), } runtime.FS.Update(c.fiatShamirSetup) @@ -146,14 +153,10 @@ func (c *CompiledIOP) createVerifier(proof Proof) VerifierRuntime { func (run *VerifierRuntime) generateAllRandomCoins() { for currRound := 0; currRound < run.Spec.NumRounds(); currRound++ { + + initialState := run.FS.State() + if currRound > 0 { - /* - Sanity-check : Make sure all issued random coin have been - "consumed" by all the verifiers steps, in the round we are - "closing" - */ - toBeConsumed := run.Spec.Coins.AllKeysAt(currRound - 1) - run.Coins.MustExists(toBeConsumed...) if !run.Spec.DummyCompiled { @@ -166,14 +169,12 @@ func (run *VerifierRuntime) generateAllRandomCoins() { msgsToFS := run.Spec.Columns.AllKeysProofAt(currRound - 1) for _, msgName := range msgsToFS { instance := run.GetColumn(msgName) - logrus.Tracef("VERIFIER : Update fiat-shamir with proof message %v", msgName) run.FS.UpdateSV(instance) } msgsToFS = run.Spec.Columns.AllKeysPublicInputAt(currRound - 1) for _, msgName := range msgsToFS { instance := run.GetColumn(msgName) - logrus.Tracef("VERIFIER : Update fiat-shamir with public input %v", msgName) run.FS.UpdateSV(instance) } @@ -182,9 +183,10 @@ func (run *VerifierRuntime) generateAllRandomCoins() { */ queries := run.Spec.QueriesParams.AllKeysAt(currRound - 1) for _, qName := range queries { - // Implicitly, this will panic whenever we start supporting - // a new type of query params - logrus.Tracef("VERIFIER : Update fiat-shamir with query parameters %v", qName) + if run.Spec.QueriesParams.IsSkippedFromVerifierTranscript(qName) { + continue + } + params := run.QueriesParams.MustGet(qName) params.UpdateFS(run.FS) } @@ -197,11 +199,30 @@ func (run *VerifierRuntime) generateAllRandomCoins() { */ toCompute := run.Spec.Coins.AllKeysAt(currRound) for _, coin := range toCompute { - logrus.Tracef("VERIFIER : Generate coin %v", coin) + if run.Spec.Coins.IsSkippedFromVerifierTranscript(coin) { + continue + } + info := run.Spec.Coins.Data(coin) value := info.Sample(run.FS) run.Coins.InsertNew(coin, value) } + + if run.Spec.FiatShamirHooks.Len() > currRound { + fsHooks := run.Spec.FiatShamirHooks.MustGet(currRound) + for i := range fsHooks { + if fsHooks[i].IsSkipped() { + continue + } + + fsHooks[i].Run(run) + } + } + + run.FiatShamirHistory[currRound] = [2][]field.Element{ + initialState, + run.FS.State(), + } } } @@ -358,3 +379,15 @@ func (run VerifierRuntime) GetColumnAt(name ifaces.ColID, pos int) field.Element func (run *VerifierRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams { return run.QueriesParams.MustGet(name) } + +// GetPublicInput returns a public input from its name +func (run *VerifierRuntime) GetPublicInput(name string) field.Element { + allPubs := run.Spec.PublicInputs + for i := range allPubs { + if allPubs[i].Name == name { + return allPubs[i].Acc.GetVal(run) + } + } + utils.Panic("could not find public input nb %v", name) + return field.Element{} +} diff --git a/prover/zkevm/prover/ecpair/circuits.go b/prover/zkevm/prover/ecpair/circuits.go index ca35d00b6..0eb122669 100644 --- a/prover/zkevm/prover/ecpair/circuits.go +++ b/prover/zkevm/prover/ecpair/circuits.go @@ -2,6 +2,7 @@ package ecpair import ( "fmt" + "math/big" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" @@ -9,10 +10,16 @@ import ( "github.com/consensys/gnark/std/evmprecompiles" "github.com/consensys/gnark/std/math/bitslice" "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" ) var fpParams sw_bn254.BaseField +type ( + fpField = emulated.Field[emparams.BN254Fp] + fpElement = emulated.Element[emparams.BN254Fp] +) + // G1ElementWizard represents G1 element as Wizard limbs (2 limbs of 128 bits) type G1ElementWizard struct { P [nbG1Limbs]frontend.Variable @@ -126,50 +133,22 @@ func (c *GtElementWizard) ToGtElement(api frontend.API, fp *emulated.Field[sw_bn C1B2YLimbs[2], C1B2YLimbs[3] = bitslice.Partition(api, c.T[22], 64, bitslice.WithNbDigits(128)) C1B2YLimbs[0], C1B2YLimbs[1] = bitslice.Partition(api, c.T[23], 64, bitslice.WithNbDigits(128)) - C0B0X := fp.NewElement(C0B0XLimbs) - C0B0Y := fp.NewElement(C0B0YLimbs) - C0B1X := fp.NewElement(C0B1XLimbs) - C0B1Y := fp.NewElement(C0B1YLimbs) - C0B2X := fp.NewElement(C0B2XLimbs) - C0B2Y := fp.NewElement(C0B2YLimbs) - C1B0X := fp.NewElement(C1B0XLimbs) - C1B0Y := fp.NewElement(C1B0YLimbs) - C1B1X := fp.NewElement(C1B1XLimbs) - C1B1Y := fp.NewElement(C1B1YLimbs) - C1B2X := fp.NewElement(C1B2XLimbs) - C1B2Y := fp.NewElement(C1B2YLimbs) - - T := sw_bn254.GTEl{ - C0: fields_bn254.E6{ - B0: fields_bn254.E2{ - A0: *C0B0X, - A1: *C0B0Y, - }, - B1: fields_bn254.E2{ - A0: *C0B1X, - A1: *C0B1Y, - }, - B2: fields_bn254.E2{ - A0: *C0B2X, - A1: *C0B2Y, - }, - }, - C1: fields_bn254.E6{ - B0: fields_bn254.E2{ - A0: *C1B0X, - A1: *C1B0Y, - }, - B1: fields_bn254.E2{ - A0: *C1B1X, - A1: *C1B1Y, - }, - B2: fields_bn254.E2{ - A0: *C1B2X, - A1: *C1B2Y, - }, - }, + e12Tower := [12]*fpElement{ + fp.NewElement(C0B0XLimbs), + fp.NewElement(C0B0YLimbs), + fp.NewElement(C0B1XLimbs), + fp.NewElement(C0B1YLimbs), + fp.NewElement(C0B2XLimbs), + fp.NewElement(C0B2YLimbs), + fp.NewElement(C1B0XLimbs), + fp.NewElement(C1B0YLimbs), + fp.NewElement(C1B1XLimbs), + fp.NewElement(C1B1YLimbs), + fp.NewElement(C1B2XLimbs), + fp.NewElement(C1B2YLimbs), } - return T + + return intoGtNoTower(fp, e12Tower) } // MultiG2GroupcheckCircuit is a circuit that checks multiple G2 group @@ -309,3 +288,53 @@ func (c *MillerLoopFinalExpInstance) Check(api frontend.API, fp *emulated.Field[ return evmprecompiles.ECPairMillerLoopAndFinalExpCheck(api, &prev, &P, &Q, c.Expected[1]) } + +// intoGtNoTower converts an E12 element as in the outputs of the pairing +// precompile on Ethereum into a non-tower representation of the same E12 +// element. +func intoGtNoTower(api *fpField, coordinates [12]*fpElement) sw_bn254.GTEl { + + var ( + C0B0X = coordinates[0] + C0B0Y = coordinates[1] + C0B1X = coordinates[2] + C0B1Y = coordinates[3] + C0B2X = coordinates[4] + C0B2Y = coordinates[5] + C1B0X = coordinates[6] + C1B0Y = coordinates[7] + C1B1X = coordinates[8] + C1B1Y = coordinates[9] + C1B2X = coordinates[10] + C1B2Y = coordinates[11] + ) + + var t *fpElement + t = api.MulConst(C0B0Y, big.NewInt(9)) + c0 := api.Sub(C0B0X, t) + t = api.MulConst(C1B0Y, big.NewInt(9)) + c1 := api.Sub(C1B0X, t) + t = api.MulConst(C0B1Y, big.NewInt(9)) + c2 := api.Sub(C0B1X, t) + t = api.MulConst(C1B1Y, big.NewInt(9)) + c3 := api.Sub(C1B1X, t) + t = api.MulConst(C0B2Y, big.NewInt(9)) + c4 := api.Sub(C0B2X, t) + t = api.MulConst(C1B2Y, big.NewInt(9)) + c5 := api.Sub(C1B2X, t) + + return sw_bn254.GTEl{ + A0: *c0, + A1: *c1, + A2: *c2, + A3: *c3, + A4: *c4, + A5: *c5, + A6: *C0B0Y, + A7: *C1B0Y, + A8: *C0B1Y, + A9: *C1B1Y, + A10: *C0B2Y, + A11: *C1B2Y, + } +} diff --git a/prover/zkevm/prover/publicInput/public_input.go b/prover/zkevm/prover/publicInput/public_input.go index 7ab94cd5d..555e8577a 100644 --- a/prover/zkevm/prover/publicInput/public_input.go +++ b/prover/zkevm/prover/publicInput/public_input.go @@ -17,6 +17,28 @@ import ( "github.com/ethereum/go-ethereum/common" ) +var ( + DataNbBytes = "DataNbBytes" + DataChecksum = "DataChecksum" + L2MessageHash = "L2MessageHash" + InitialStateRootHash = "InitialStateRootHash" + FinalStateRootHash = "FinalStateRootHash" + InitialBlockNumber = "InitialBlockNumber" + FinalBlockNumber = "FinalBlockNumber" + InitialBlockTimestamp = "InitialBlockTimestamp" + FinalBlockTimestamp = "FinalBlockTimestamp" + FirstRollingHashUpdate_0 = "FirstRollingHashUpdate_0" + FirstRollingHashUpdate_1 = "FirstRollingHashUpdate_1" + LastRollingHashUpdate_0 = "LastRollingHashUpdate_0" + LastRollingHashUpdate_1 = "LastRollingHashUpdate_1" + FirstRollingHashUpdateNumber = "FirstRollingHashUpdateNumber" + LastRollingHashNumberUpdate = "LastRollingHashNumberUpdate" + ChainID = "ChainID" + NBytesChainID = "NBytesChainID" + L2MessageServiceAddrHi = "L2MessageServiceAddrHi" + L2MessageServiceAddrLo = "L2MessageServiceAddrLo" +) + // PublicInput collects a number of submodules responsible for collecting the // wizard witness data holding the public inputs of the execution circuit. type PublicInput struct { @@ -302,4 +324,26 @@ func (pi *PublicInput) generateExtractor(comp *wizard.CompiledIOP) { L2MessageServiceAddrHi: accessors.NewFromPublicColumn(pi.Aux.logSelectors.L2BridgeAddressColHI, 0), L2MessageServiceAddrLo: accessors.NewFromPublicColumn(pi.Aux.logSelectors.L2BridgeAddressColLo, 0), } + + comp.PublicInputs = append(comp.PublicInputs, + wizard.PublicInput{Name: "DataNbBytes", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.DataNbBytes, 0)}, + wizard.PublicInput{Name: "DataChecksum", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.DataChecksum, 0)}, + wizard.PublicInput{Name: "L2MessageHash", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.L2MessageHash, 0)}, + wizard.PublicInput{Name: "InitialStateRootHash", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.InitialStateRootHash, 0)}, + wizard.PublicInput{Name: "FinalStateRootHash", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FinalStateRootHash, 0)}, + wizard.PublicInput{Name: "InitialBlockNumber", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.InitialBlockNumber, 0)}, + wizard.PublicInput{Name: "FinalBlockNumber", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FinalBlockNumber, 0)}, + wizard.PublicInput{Name: "InitialBlockTimestamp", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.InitialBlockTimestamp, 0)}, + wizard.PublicInput{Name: "FinalBlockTimestamp", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FinalBlockTimestamp, 0)}, + wizard.PublicInput{Name: "FirstRollingHashUpdate[0]", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FirstRollingHashUpdate[0], 0)}, + wizard.PublicInput{Name: "FirstRollingHashUpdate[1]", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FirstRollingHashUpdate[1], 0)}, + wizard.PublicInput{Name: "LastRollingHashUpdate[0]", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.LastRollingHashUpdate[0], 0)}, + wizard.PublicInput{Name: "LastRollingHashUpdate[1]", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.LastRollingHashUpdate[1], 0)}, + wizard.PublicInput{Name: "FirstRollingHashUpdateNumber", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.FirstRollingHashUpdateNumber, 0)}, + wizard.PublicInput{Name: "LastRollingHashNumberUpdate", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.LastRollingHashUpdateNumber, 0)}, + wizard.PublicInput{Name: "ChainID", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.ChainID, 0)}, + wizard.PublicInput{Name: "NBytesChainID", Acc: accessors.NewLocalOpeningAccessor(pi.Extractor.NBytesChainID, 0)}, + wizard.PublicInput{Name: "L2MessageServiceAddrHi", Acc: pi.Extractor.L2MessageServiceAddrHi}, + wizard.PublicInput{Name: "L2MessageServiceAddrLo", Acc: pi.Extractor.L2MessageServiceAddrLo}, + ) }