Skip to content

Commit

Permalink
shared randomness for limitless prover (#587)
Browse files Browse the repository at this point in the history
* generate lpp compilediop  

* generating coins from seed

*  coin generation on the verifier side

* testing the coin equalities among different module-segments

* adjusted inclusion compilation based on the seed
  • Loading branch information
Soleimani193 authored Jan 24, 2025
1 parent fb4ee00 commit 07ae16a
Show file tree
Hide file tree
Showing 12 changed files with 530 additions and 44 deletions.
1 change: 1 addition & 0 deletions prover/backend/execution/prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func mustProveAndPass(
case config.ProverModeEncodeOnly:

profiling.ProfileTrace("encode-decode-no-circuit", true, false, func() {
//nolint:gosec // Ignoring weak randomness error
filepath := "/tmp/wizard-assignment/blob-" + strconv.Itoa(rand.Int()) + ".bin"

encodeOnlyZkEvm := zkevm.EncodeOnlyZkEvm(traces)
Expand Down
14 changes: 14 additions & 0 deletions prover/crypto/fiatshamir/fiatshamir.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ func (fs *State) RandomField() field.Element {
return res
}

// RandomField generates and returns a single field element from the seed and the given name.
func (fs *State) RandomFieldFromSeed(seed field.Element, name string) field.Element {
challBytes := []byte(name)
seedBytes := seed.Bytes()
challBytes = append(challBytes, seedBytes[:]...)

var res field.Element
res.SetBytes(challBytes)

// increase the counter by one
fs.NumCoinGenerated++
return res
}

// RandomManyIntegers returns a list of challenge small integers. That is, a
// list of positive integer bounded by `upperBound`. The upperBound is strict
// and is restricted to being only be a power of two.
Expand Down
13 changes: 12 additions & 1 deletion prover/protocol/coin/coin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strconv"

"github.com/consensys/linea-monorepo/prover/crypto/fiatshamir"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
)

Expand Down Expand Up @@ -61,6 +62,7 @@ type Type int
const (
Field Type = iota
IntegerVec
FieldFromSeed
)

// MarshalJSON implements [json.Marshaler] directly returning the Itoa of the
Expand Down Expand Up @@ -88,12 +90,17 @@ func (t *Type) UnmarshalJSON(b []byte) error {
/*
Sample a random coin, according to its `spec`
*/
func (info *Info) Sample(fs *fiatshamir.State) interface{} {
func (info *Info) Sample(fs *fiatshamir.State, seed ...field.Element) interface{} {
switch info.Type {
case Field:
return fs.RandomField()
case IntegerVec:
return fs.RandomManyIntegers(info.Size, info.UpperBound)
case FieldFromSeed:
if len(seed) == 0 {
panic("expected a SEED as the input")
}
return fs.RandomFieldFromSeed(seed[0], string(info.Name))
}
panic("Unreachable")
}
Expand All @@ -117,6 +124,10 @@ func NewInfo(name Name, type_ Type, round int, size ...int) Info {
if len(size) > 0 {
utils.Panic("size for Field")
}
case FieldFromSeed:
if len(size) > 0 {
utils.Panic("size for Field")
}
default:
panic("unreachable")
}
Expand Down
3 changes: 2 additions & 1 deletion prover/protocol/column/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvec
metadata = board.ListVariableMetadata()
inputs = make([]smartvectors.SmartVector, len(metadata))
length = ExprIsOnSameLengthHandles(&board)
v field.Element
)

// Attempt to recover the size of the
Expand All @@ -119,7 +120,7 @@ func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvec
case ifaces.Column:
inputs[i] = m.GetColAssignment(run)
case coin.Info:
v := run.GetRandomCoinField(m.Name)
v = run.GetRandomCoinField(m.Name)
inputs[i] = smartvectors.NewConstant(v, length)
case ifaces.Accessor:
v := m.GetVal(run)
Expand Down
10 changes: 8 additions & 2 deletions prover/protocol/distributed/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ func ReplaceExternalCoinsVerifCols(
utils.Panic("Coin %v is declared in round %v != 1", v.Name, v.Round)
}
if !moduleComp.Coins.Exists(v.Name) {
moduleComp.InsertCoin(v.Round, v.Name, coin.Field)
translationMap.InsertNew(v.String(), symbolic.NewVariable(v))

if !initialComp.Coins.Exists("SEED") {
utils.Panic("Expect to find a seed in the initialComp")
}
// register a local coin of type FieldFromSeed.
name := coin.Namef("%v_%v", v.Name, "FieldFromSeed")
localV := moduleComp.InsertCoin(v.Round, name, coin.FieldFromSeed)
translationMap.InsertNew(v.String(), symbolic.NewVariable(localV))
}
case ifaces.Column:
// create the local verfiercols and add them to the translationMap.
Expand Down
61 changes: 37 additions & 24 deletions prover/protocol/distributed/compiler/inclusion/inclusion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum"
"github.com/consensys/linea-monorepo/prover/protocol/distributed"
"github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/inclusion"
"github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp"
md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/stretchr/testify/require"
)

// It tests DistributedLogDerivSum.
func TestDistributedLogDerivSum(t *testing.T) {
func TestSeedGeneration(t *testing.T) {
const (
numSegModule0 = 2
numSegModule1 = 2
Expand All @@ -24,22 +25,25 @@ func TestDistributedLogDerivSum(t *testing.T) {

//initialComp
define := func(b *wizard.Builder) {
// columns from module0
col01 := b.CompiledIOP.InsertCommit(0, "module0.col1", 4)
col02 := b.CompiledIOP.InsertCommit(0, "module0.col2", 8)

// columns from module1
col10 := b.CompiledIOP.InsertCommit(0, "module1.col0", 8)
col11 := b.CompiledIOP.InsertCommit(0, "module1.col1", 16)
col12 := b.CompiledIOP.InsertCommit(0, "module1.col2", 4)
col13 := b.CompiledIOP.InsertCommit(0, "module1.col3", 4)
col14 := b.CompiledIOP.InsertCommit(0, "module1.col4", 16)
col15 := b.CompiledIOP.InsertCommit(0, "module1.col5", 16)

// columns from module2
col20 := b.CompiledIOP.InsertCommit(0, "module2.col0", 4)
col21 := b.CompiledIOP.InsertCommit(0, "module2.col1", 4)
col22 := b.CompiledIOP.InsertCommit(0, "module2.col2", 4)

var (
// columns from module0
col01 = b.CompiledIOP.InsertCommit(0, "module0.col1", 4)
col02 = b.CompiledIOP.InsertCommit(0, "module0.col2", 8)

// columns from module1
col10 = b.CompiledIOP.InsertCommit(0, "module1.col0", 8)
col11 = b.CompiledIOP.InsertCommit(0, "module1.col1", 16)
col12 = b.CompiledIOP.InsertCommit(0, "module1.col2", 4)
col13 = b.CompiledIOP.InsertCommit(0, "module1.col3", 4)
col14 = b.CompiledIOP.InsertCommit(0, "module1.col4", 16)
col15 = b.CompiledIOP.InsertCommit(0, "module1.col5", 16)

// columns from module2
col20 = b.CompiledIOP.InsertCommit(0, "module2.col0", 4)
col21 = b.CompiledIOP.InsertCommit(0, "module2.col1", 4)
col22 = b.CompiledIOP.InsertCommit(0, "module2.col2", 4)
)

// inclusion query: S \subset T , S in module0, T in module1.
b.CompiledIOP.InsertInclusion(0, "lookup0",
Expand Down Expand Up @@ -71,9 +75,10 @@ func TestDistributedLogDerivSum(t *testing.T) {
run.AssignColumn("module2.col2", smartvectors.ForTest(1, 1, 0, 1))
}

// in initialComp replace inclusion queries with a global LogDerivativeSum
// it also creates new columns relevant to the preparation such as multiplicity columns.
initialComp := wizard.Compile(define, distributed.IntoLogDerivativeSum)
// initial compiledIOP is the parent to LPPComp and all the SegmentModuleComp objects.
initialComp := wizard.Compile(define)
// apply the LPP relevant compilers and generate the seed for initialComp
lppComp := lpp.CompileLPPAndGetSeed(initialComp, distributed.IntoLogDerivativeSum)

// Initialize the period separating module discoverer
disc := &md.PeriodSeperatingModuleDiscoverer{}
Expand Down Expand Up @@ -105,6 +110,7 @@ func TestDistributedLogDerivSum(t *testing.T) {
})

// distribute the query LogDerivativeSum among modules.
// The seed is used to generate randomness for each moduleComp.
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp0, "module0", disc, numSegModule0)
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp1, "module1", disc, numSegModule1)
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp2, "module2", disc, numSegModule2)
Expand All @@ -115,7 +121,14 @@ func TestDistributedLogDerivSum(t *testing.T) {
wizard.ContinueCompilation(moduleComp2, logderiv.CompileLogDerivSum, dummy.Compile)

// run the initial runtime
initialRuntime := wizard.RunProver(initialComp, prover)
initialRuntime := wizard.ProverOnlyFirstRound(initialComp, prover)

// compile and verify for lpp-Prover
lppProof := wizard.Prove(lppComp, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
lppVerifierRuntime, valid := wizard.VerifierWithRuntime(lppComp, lppProof)
require.NoError(t, valid)

// Compile and prove for module0
for proverID := 0; proverID < numSegModule0; proverID++ {
Expand All @@ -124,7 +137,7 @@ func TestDistributedLogDerivSum(t *testing.T) {
// inputs for vertical splitting of the witness
run.ProverID = proverID
})
valid := wizard.Verify(moduleComp0, proof0)
valid := wizard.Verify(moduleComp0, proof0, lppVerifierRuntime)
require.NoError(t, valid)
}

Expand All @@ -135,7 +148,7 @@ func TestDistributedLogDerivSum(t *testing.T) {
// inputs for vertical splitting of the witness
run.ProverID = proverID
})
valid1 := wizard.Verify(moduleComp1, proof1)
valid1 := wizard.Verify(moduleComp1, proof1, lppVerifierRuntime)
require.NoError(t, valid1)
}

Expand All @@ -146,7 +159,7 @@ func TestDistributedLogDerivSum(t *testing.T) {
// inputs for vertical splitting of the witness
run.ProverID = proverID
})
valid2 := wizard.Verify(moduleComp2, proof2)
valid2 := wizard.Verify(moduleComp2, proof2, lppVerifierRuntime)
require.NoError(t, valid2)
}
}
162 changes: 162 additions & 0 deletions prover/protocol/distributed/lpp/lpp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package lpp

import (
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)

// It applies the Compilation steps concerning the LPP queries over comp.
// It generates a LPP-CompiledIOP object internally, that is used for seed generation.
func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard.CompiledIOP)) *wizard.CompiledIOP {

var (
lppComp = wizard.NewCompiledIOP()
oldColumns = []ifaces.Column{}
lppCols = []ifaces.Column{}
)

// get the LPP columns from comp.
lppCols = append(lppCols, getLPPColumns(comp)...)

for _, col := range comp.Columns.AllHandlesAtRound(0) {
oldColumns = append(oldColumns, col)
}

// applies lppCompiler; this would add a new round and probably new columns to the current round
// but no new column to the new round.
for _, lppCompiler := range lppCompilers {
lppCompiler(comp)

if comp.NumRounds() != 2 || comp.Columns.NumRounds() != 1 {
panic("we expect to have new round while no column is yet registered for the new round")
}

numRounds := comp.NumRounds()
comp.EqualizeRounds(numRounds)
}

// filter the new lpp columns.
for _, col := range comp.Columns.AllHandlesAtRound(0) {
isOld := false
for _, oldCol := range oldColumns {
if col.GetColID() == oldCol.GetColID() {
isOld = true
break
}
}
if !isOld {
// if it is not in the oldColumns it is a new lpp column.
lppCols = append(lppCols, col)
}
}

// add the LPP columns to the lppComp.
for _, col := range lppCols {
lppComp.InsertCommit(0, col.GetColID(), col.Size())
}

// register the seed, generated from LPP, in comp
// for the sake of the assignment it also should be registered in lppComp
lppComp.InsertCoin(1, "SEED", coin.Field)
comp.InsertCoin(1, "SEED", coin.Field)

// prepare and register prover actions.
lppProver := &lppProver{
cols: lppCols,
}

lppComp.RegisterProverAction(1, lppProver)

return lppComp

}

type lppProver struct {
cols []ifaces.Column
}

func (p *lppProver) Run(run *wizard.ProverRuntime) {

for _, col := range p.cols {
colWitness := run.ParentRuntime.GetColumn(col.GetColID())
run.AssignColumn(col.GetColID(), colWitness, col.Round())
}

// generate the seed based on LPP run time.
seed := run.GetRandomCoinField("SEED")

// pass the seed to the parent run time.
// note that the parent of LPP is also parent to all compiledIOP of segment-Modules.
// thus, this gives access to the seed for all segment-module-compiledIOPs.
run.ParentRuntime.Coins.InsertNew("SEED", seed)
}

// GetLPPComp take the and old CompiledIOP object.
// It creates a fresh CompiledIOP object holding only the LPP columns.
// old CompiledIOP includes the LPP queries and new LPP Columns includes the new columns generated at round 0,
// due to the application of a compilation step (i.e., during the preparation).
// for example : multiplicity columns, for inclusion query, are retrieved from new LPP columns.
func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard.CompiledIOP {

var (
// initialize LPPComp
lppComp = wizard.NewCompiledIOP()
lppCols = []ifaces.Column{}
)

// get the LPP columns
lppCols = append(lppCols, getLPPColumns(oldComp)...)
lppCols = append(lppCols, newLPPCols...)

for _, col := range lppCols {
lppComp.InsertCommit(0, col.GetColID(), col.Size())
}
return lppComp
}

// it extract LPP columns from the context of each LPP query.
func getLPPColumns(c *wizard.CompiledIOP) []ifaces.Column {

var (
lppColumns = []ifaces.Column{}
)

for _, qName := range c.QueriesNoParams.AllKeysAt(0) {
q := c.QueriesNoParams.Data(qName)
switch v := q.(type) {
case query.Inclusion:

for i := range v.Including {
lppColumns = append(lppColumns, v.Including[i]...)
}

lppColumns = append(lppColumns, v.Included...)

if v.IncludingFilter != nil {
lppColumns = append(lppColumns, v.IncludingFilter...)
}

if v.IncludedFilter != nil {
lppColumns = append(lppColumns, v.IncludedFilter)
}

case query.Permutation:
for i := range v.A {
lppColumns = append(lppColumns, v.A[i]...)
lppColumns = append(lppColumns, v.B[i]...)
}
case query.Projection:
lppColumns = append(lppColumns, v.ColumnsA...)
lppColumns = append(lppColumns, v.ColumnsB...)
lppColumns = append(lppColumns, v.FilterA, v.FilterB)

default:
//do noting
}

}

return lppColumns
}
Loading

0 comments on commit 07ae16a

Please sign in to comment.