diff --git a/prover/protocol/column/column.go b/prover/protocol/column/column.go index 906bf9102..18af77a92 100644 --- a/prover/protocol/column/column.go +++ b/prover/protocol/column/column.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" @@ -172,3 +173,18 @@ func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { return length } + +// return the runtime assignments of a linear combination column +// that is computed on the fly from the columns stored in hs +func RandLinCombColAssignment(run ifaces.Runtime, coinVal field.Element, hs []ifaces.Column) smartvectors.SmartVector { + var colTableWit smartvectors.SmartVector + var witnessCollapsed smartvectors.SmartVector + x := field.One() + witnessCollapsed = smartvectors.NewConstant(field.Zero(), hs[0].Size()) + for tableCol := range hs { + colTableWit = hs[tableCol].GetColAssignment(run) + witnessCollapsed = smartvectors.Add(witnessCollapsed, smartvectors.Mul(colTableWit, smartvectors.NewConstant(x, hs[0].Size()))) + x.Mul(&x, &coinVal) + } + return witnessCollapsed +} diff --git a/prover/protocol/compiler/logderivativesum/compile.go b/prover/protocol/compiler/logderivativesum/compile.go new file mode 100644 index 000000000..edac1f402 --- /dev/null +++ b/prover/protocol/compiler/logderivativesum/compile.go @@ -0,0 +1,101 @@ +package logderiv + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// compile [query.LogDerivativeSum] query +func CompileLogDerivSum(comp *wizard.CompiledIOP) { + + // Collect all the logDerivativeSum queries + for _, qName := range comp.QueriesParams.AllUnignoredKeys() { + + // Filter out non other types of queries + logDeriv, ok := comp.QueriesParams.Data(qName).(query.LogDerivativeSum) + if !ok { + continue + } + + // This ensures that the LogDerivativeSum query is not used again in the + // compilation process. We know that the query was already ignored at + // the beginning because we are iterating over the unignored keys. + comp.QueriesParams.MarkAsIgnored(qName) + // get the Numerator and Denominator from the input and prepare their compilation. + zEntries := logDeriv.Inputs + va := FinalEvaluationCheck{} + for _, entry := range zEntries { + zC := &lookup.ZCtx{ + Round: entry.Round, + Size: entry.Size, + SigmaNumerator: entry.Numerator, + SigmaDenominator: entry.Denominator, + } + + // z-packing compile; it imposes the correct accumulation over Numerator and Denominator. + zC.Compile(comp) + // prover step; Z assignments + zAssignmentTask := lookup.ZAssignmentTask(*zC) + comp.SubProvers.AppendToInner(zC.Round, func(run *wizard.ProverRuntime) { + zAssignmentTask.Run(run) + }) + // collect all the zOpening for all the z columns + va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...) + } + + // verifer step + va.LogDerivSumID = qName + lastRound := comp.NumRounds() - 1 + comp.RegisterVerifierAction(lastRound, &va) + } + +} + +type FinalEvaluationCheck struct { + // ZOpenings lists all the openings of all the zCtx + ZOpenings []query.LocalOpening + // query ID + LogDerivSumID ifaces.QueryID +} + +// Run implements the [wizard.VerifierAction] +func (f *FinalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { + + // zSum stores the sum of the ending values of the zs as queried + // in the protocol via the local opening queries. + zSum := field.Zero() + for k := range f.ZOpenings { + temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y + zSum.Add(&zSum, &temp) + } + + claimedSum := run.GetLogDerivSumParams(f.LogDerivSumID).Sum + if zSum != claimedSum { + return fmt.Errorf("log-derivate-sum; the final evaluation check failed for %v\n"+ + "given %v but calculated %v,", + f.LogDerivSumID, claimedSum.String(), zSum.String()) + } + + return nil +} + +// RunGnark implements the [wizard.VerifierAction] +func (f *FinalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { + + claimedSum := run.GetLogDerivSumParams(f.LogDerivSumID).Sum + // SigmaSKSum stores the sum of the ending values of the SigmaSs as queried + // in the protocol via the + zSum := frontend.Variable(field.Zero()) + for k := range f.ZOpenings { + temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y + zSum = api.Add(zSum, temp) + } + + api.AssertIsEqual(zSum, claimedSum) +} diff --git a/prover/protocol/compiler/logderivativesum/logderivsum_test.go b/prover/protocol/compiler/logderivativesum/logderivsum_test.go new file mode 100644 index 000000000..0bcf350da --- /dev/null +++ b/prover/protocol/compiler/logderivativesum/logderivsum_test.go @@ -0,0 +1,75 @@ +package logderiv_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/stretchr/testify/require" +) + +// It tests that the given expression for the LogDerivativeSum adds up to the given parameter. +func TestLogDerivSum(t *testing.T) { + + define := func(b *wizard.Builder) { + var ( + comp = b.CompiledIOP + ) + + p0 := b.RegisterCommit("Num_0", 4) + p1 := b.RegisterCommit("Num_1", 4) + p2 := b.RegisterCommit("Num_2", 4) + + q0 := b.RegisterCommit("Den_0", 4) + q1 := b.RegisterCommit("Den_1", 4) + q2 := b.RegisterCommit("Den_2", 4) + + numerators := []*symbolic.Expression{ + symbolic.Mul(p0, -1), + ifaces.ColumnAsVariable(p1), + symbolic.Mul(p2, p0, 2), + } + + denominators := []*symbolic.Expression{ + ifaces.ColumnAsVariable(q0), + ifaces.ColumnAsVariable(q1), + ifaces.ColumnAsVariable(q2), + } + + key := [2]int{0, 4} + zCat1 := map[[2]int]*query.LogDerivativeSumInput{} + zCat1[key] = &query.LogDerivativeSumInput{ + Round: 0, + Size: 4, + Numerator: numerators, + Denominator: denominators, + } + comp.InsertLogDerivativeSum(0, "LogDerivSum_Test", zCat1) + + } + + prover := func(run *wizard.ProverRuntime) { + + run.AssignColumn("Num_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Num_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Num_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignColumn("Den_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Den_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Den_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignLogDerivSum("LogDerivSum_Test", field.NewElement(8)) + + } + + compiled := wizard.Compile(define, logderiv.CompileLogDerivSum, dummy.Compile) + proof := wizard.Prove(compiled, prover) + valid := wizard.Verify(compiled, proof) + require.NoError(t, valid) +} diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index a39838b48..772dce3ae 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -99,11 +99,11 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { for _, entry := range zEntries { zC := zCatalog[entry] // z-packing compile - zC.compile(comp) + zC.Compile(comp) // entry[0]:round, entry[1]: size // the round that Gamma was registered. round := entry[0] - proverActions[round].pushZAssignment(zAssignmentTask(*zC)) + proverActions[round].pushZAssignment(ZAssignmentTask(*zC)) va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...) va.Name = zC.Name } diff --git a/prover/protocol/compiler/lookup/prover.go b/prover/protocol/compiler/lookup/prover.go index 57bd5d612..dfa7d382a 100644 --- a/prover/protocol/compiler/lookup/prover.go +++ b/prover/protocol/compiler/lookup/prover.go @@ -32,7 +32,7 @@ type proverTaskAtRound struct { // ZAssignmentTasks lists all the tasks consisting of assigning the // columns SigmaS and SigmaT for the given round. - ZAssignmentTasks []zAssignmentTask + ZAssignmentTasks []ZAssignmentTask } // Run implements the [wizard.ProverAction interface]. The tasks will spawn @@ -94,7 +94,7 @@ func (p proverTaskAtRound) Run(run *wizard.ProverRuntime) { wg.Done() }() - p.ZAssignmentTasks[i].run(run) + p.ZAssignmentTasks[i].Run(run) }(i) } @@ -111,7 +111,7 @@ func (p *proverTaskAtRound) pushMAssignment(m MAssignmentTask) { } // pushZAssignment appends an [sigmaAssignmentTask] to the list of tasks -func (p *proverTaskAtRound) pushZAssignment(s zAssignmentTask) { +func (p *proverTaskAtRound) pushZAssignment(s ZAssignmentTask) { p.ZAssignmentTasks = append(p.ZAssignmentTasks, s) } @@ -298,12 +298,12 @@ func (a MAssignmentTask) Run(run *wizard.ProverRuntime) { } -// zAssignmentTask represents a prover task of assignming the columns +// ZAssignmentTask represents a prover task of assignming the columns // SigmaS and SigmaT for a specific lookup table. // sigmaAssignment -type zAssignmentTask ZCtx +type ZAssignmentTask ZCtx -func (z zAssignmentTask) run(run *wizard.ProverRuntime) { +func (z ZAssignmentTask) Run(run *wizard.ProverRuntime) { parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) { for frag := start; frag < stop; frag++ { diff --git a/prover/protocol/compiler/lookup/z_packing.go b/prover/protocol/compiler/lookup/z_packing.go index 2838f4834..7d62954bd 100644 --- a/prover/protocol/compiler/lookup/z_packing.go +++ b/prover/protocol/compiler/lookup/z_packing.go @@ -39,12 +39,12 @@ type ZCtx struct { Name string } -// check permutation and see how/where compile is called (see how to constracut z there) +// check permutation and see how/where Compile is called (see how to constracut z there) // when constructing z, check if z is T or S // and change T -> -M, S -> +Filter // S or T -> ({S,T} + X) -// compile should be called inside CompileGrandSum -func (z *ZCtx) compile(comp *wizard.CompiledIOP) { +// Compile should be called inside CompileGrandSum +func (z *ZCtx) Compile(comp *wizard.CompiledIOP) { var ( numZs = utils.DivCeil( diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index f2685ff34..73086fccb 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -25,7 +25,7 @@ type DistributionInputs struct { } // GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query. -// It insert a new LogDerivativeSum for the extracted share. +// It inserts a new LogDerivativeSum for the extracted share. func GetShareOfLogDerivativeSum(in DistributionInputs) { var ( initialComp = in.InitialComp diff --git a/prover/protocol/distributed/preparation.go b/prover/protocol/distributed/preparation.go index 9bd55cdae..e70526eac 100644 --- a/prover/protocol/distributed/preparation.go +++ b/prover/protocol/distributed/preparation.go @@ -54,7 +54,7 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) { checkTable = mainLookupCtx.CheckedTables[lookupTableName] round = mainLookupCtx.Rounds[lookupTableName] includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] - // collapse multiColumns to single Columns + // collapse multiColumns to single Columns and commit to M. tableCtx = lookup.CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters) ) diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index fc2ddabbf..903102d8e 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -16,11 +16,11 @@ func (p LocalOpeningParams) GnarkAssign() GnarkLocalOpeningParams { } type GnarkLogDerivSumParams struct { - Y frontend.Variable + Sum frontend.Variable } func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams { - return GnarkLogDerivSumParams{Y: p.Sum} + return GnarkLogDerivSumParams{Sum: p.Sum} } // A gnark circuit version of InnerProductParams @@ -64,7 +64,7 @@ func (p GnarkLocalOpeningParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { // Update the fiat-shamir state with the the present parameters func (p GnarkLogDerivSumParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { - fs.Update(p.Y) + fs.Update(p.Sum) } // Update the fiat-shamir state with the the present parameters diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go index 9734cdc7b..6a5cb5d4f 100644 --- a/prover/protocol/query/logderiv_sum.go +++ b/prover/protocol/query/logderiv_sum.go @@ -28,14 +28,7 @@ type LogDerivativeSumInput struct { // D_{i,j} is the i-th element of the underlying column of j-th Denominator type LogDerivativeSum struct { Inputs map[[2]int]*LogDerivativeSumInput - - ZNumeratorBoarded, ZDenominatorBoarded map[[2]int][]sym.ExpressionBoard - - Zs map[[2]int][]ifaces.Column - // ZOpenings are the opening queries to the end of each Z. - ZOpenings map[[2]int][]LocalOpening - - ID ifaces.QueryID + ID ifaces.QueryID } // the result of the global Sum diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index efe06be28..10d035cb2 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -319,9 +319,9 @@ func (c *WizardVerifierCircuit) GetLocalPointEvalParams(name ifaces.QueryID) que // GetLogDerivSumParams returns the parameters for the requested // [query.LogDerivativeSum] query. Its work mirrors the function // [VerifierRuntime.GetLogDerivSumParams] -func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { - qID := c.localOpeningIDs.MustGet(name) - return c.LocalOpeningParams[qID] +func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams { + qID := c.logDerivSumIDs.MustGet(name) + return c.LogDerivSumParams[qID] } // GetColumns returns the gnark assignment of a column in a gnark circuit. It