Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the compiler for logderivativesum #496

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions prover/protocol/column/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"reflect"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/variables"
Expand Down Expand Up @@ -172,3 +173,18 @@ func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int {

return length
}

// return the runtime assignments of a linear combination column
// that is computed on the fly from the columns stored in hs
func RandLinCombColAssignment(run ifaces.Runtime, coinVal field.Element, hs []ifaces.Column) smartvectors.SmartVector {
var colTableWit smartvectors.SmartVector
var witnessCollapsed smartvectors.SmartVector
x := field.One()
witnessCollapsed = smartvectors.NewConstant(field.Zero(), hs[0].Size())
for tableCol := range hs {
colTableWit = hs[tableCol].GetColAssignment(run)
witnessCollapsed = smartvectors.Add(witnessCollapsed, smartvectors.Mul(colTableWit, smartvectors.NewConstant(x, hs[0].Size())))
x.Mul(&x, &coinVal)
}
return witnessCollapsed
}
99 changes: 99 additions & 0 deletions prover/protocol/compiler/logderivative_sum.go/compile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package logderiv
Soleimani193 marked this conversation as resolved.
Show resolved Hide resolved

import (
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)

// compile [query.LogDerivativeSum] query
func CompileLogDerivSum(comp *wizard.CompiledIOP) {

// Collect all the logDerivativeSum queries
for _, qName := range comp.QueriesParams.AllUnignoredKeys() {

// Filter out non other types of queries
logDeriv, ok := comp.QueriesParams.Data(qName).(query.LogDerivativeSum)
if !ok {
continue
}

// This ensures that the LogDerivativeSum query is not used again in the
// compilation process. We know that the query was already ignored at
// the beginning because we are iterating over the unignored keys.
comp.QueriesParams.MarkAsIgnored(qName)
// get the Numerator and Denominator from the input and prepare their compilation.
zEntries := logDeriv.Inputs
va := FinalEvaluationCheck{}
for _, entry := range zEntries {
zC := &lookup.ZCtx{
Round: entry.Round,
Size: entry.Size,
SigmaNumerator: entry.Numerator,
SigmaDenominator: entry.Denominator,
}

// z-packing compile; it imposes the correct accumulation over Numerator and Denominator.
zC.Compile(comp)
// prover step; Z assignments
zAssignmentTask := lookup.ZAssignmentTask(*zC)
comp.SubProvers.AppendToInner(zC.Round, func(run *wizard.ProverRuntime) {
zAssignmentTask.Run(run)
})
// collect all the zOpening for all the z columns
va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...)
}

// verifer step
va.LogDerivSumID = qName
lastRound := comp.NumRounds() - 1
comp.RegisterVerifierAction(lastRound, &va)
}

}

type FinalEvaluationCheck struct {
// ZOpenings lists all the openings of all the zCtx
ZOpenings []query.LocalOpening
// query ID
LogDerivSumID ifaces.QueryID
}

// Run implements the [wizard.VerifierAction]
func (f *FinalEvaluationCheck) Run(run *wizard.VerifierRuntime) error {

// zSum stores the sum of the ending values of the zs as queried
// in the protocol via the local opening queries.
zSum := field.Zero()
for k := range f.ZOpenings {
temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y
zSum.Add(&zSum, &temp)
}

claimedSum := run.GetLogDerivSumParams(f.LogDerivSumID).Sum
if zSum != claimedSum {
return fmt.Errorf("log-derivate-sum, the final evaluation check failed for %v,", f.LogDerivSumID)
}

return nil
}

// RunGnark implements the [wizard.VerifierAction]
func (f *FinalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) {

claimedSum := run.GetLogDerivSumParams(f.LogDerivSumID).Sum
// SigmaSKSum stores the sum of the ending values of the SigmaSs as queried
// in the protocol via the
zSum := frontend.Variable(field.Zero())
for k := range f.ZOpenings {
temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y
zSum = api.Add(zSum, temp)
}

api.AssertIsEqual(zSum, claimedSum)
}
75 changes: 75 additions & 0 deletions prover/protocol/compiler/logderivative_sum.go/logderivsum_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package logderiv_test

import (
"testing"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy"
logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivative_sum.go"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/stretchr/testify/require"
)

// It tests that the given expression for the LogDerivativeSum adds up to the given parameter.
func TestLogDerivSum(t *testing.T) {

define := func(b *wizard.Builder) {
var (
comp = b.CompiledIOP
)

p0 := b.RegisterCommit("Num_0", 4)
p1 := b.RegisterCommit("Num_1", 4)
p2 := b.RegisterCommit("Num_2", 4)

q0 := b.RegisterCommit("Den_0", 4)
q1 := b.RegisterCommit("Den_1", 4)
q2 := b.RegisterCommit("Den_2", 4)

numerators := []*symbolic.Expression{
symbolic.Mul(p0, -1),
ifaces.ColumnAsVariable(p1),
symbolic.Mul(p2, p0, 2),
}

denominators := []*symbolic.Expression{
ifaces.ColumnAsVariable(q0),
ifaces.ColumnAsVariable(q1),
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 4}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Round: 0,
Size: 4,
Numerator: numerators,
Denominator: denominators,
}
comp.InsertLogDerivativeSum(0, "LogDerivSum_Test", zCat1)

}

prover := func(run *wizard.ProverRuntime) {

run.AssignColumn("Num_0", smartvectors.ForTest(1, 1, 1, 1))
run.AssignColumn("Num_1", smartvectors.ForTest(2, 3, 7, 9))
run.AssignColumn("Num_2", smartvectors.ForTest(5, 6, 1, 1))

run.AssignColumn("Den_0", smartvectors.ForTest(1, 1, 1, 1))
run.AssignColumn("Den_1", smartvectors.ForTest(2, 3, 7, 9))
run.AssignColumn("Den_2", smartvectors.ForTest(5, 6, 1, 1))

run.AssignLogDerivSum("LogDerivSum_Test", field.NewElement(8))

}

compiled := wizard.Compile(define, logderiv.CompileLogDerivSum, dummy.Compile)
proof := wizard.Prove(compiled, prover)
valid := wizard.Verify(compiled, proof)
require.NoError(t, valid)
}
4 changes: 2 additions & 2 deletions prover/protocol/compiler/lookup/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) {
for _, entry := range zEntries {
zC := zCatalog[entry]
// z-packing compile
zC.compile(comp)
zC.Compile(comp)
// entry[0]:round, entry[1]: size
// the round that Gamma was registered.
round := entry[0]
proverActions[round].pushZAssignment(zAssignmentTask(*zC))
proverActions[round].pushZAssignment(ZAssignmentTask(*zC))
va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...)
va.Name = zC.Name
}
Expand Down
12 changes: 6 additions & 6 deletions prover/protocol/compiler/lookup/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type proverTaskAtRound struct {

// ZAssignmentTasks lists all the tasks consisting of assigning the
// columns SigmaS and SigmaT for the given round.
ZAssignmentTasks []zAssignmentTask
ZAssignmentTasks []ZAssignmentTask
}

// Run implements the [wizard.ProverAction interface]. The tasks will spawn
Expand Down Expand Up @@ -94,7 +94,7 @@ func (p proverTaskAtRound) Run(run *wizard.ProverRuntime) {
wg.Done()
}()

p.ZAssignmentTasks[i].run(run)
p.ZAssignmentTasks[i].Run(run)
}(i)
}

Expand All @@ -111,7 +111,7 @@ func (p *proverTaskAtRound) pushMAssignment(m MAssignmentTask) {
}

// pushZAssignment appends an [sigmaAssignmentTask] to the list of tasks
func (p *proverTaskAtRound) pushZAssignment(s zAssignmentTask) {
func (p *proverTaskAtRound) pushZAssignment(s ZAssignmentTask) {
p.ZAssignmentTasks = append(p.ZAssignmentTasks, s)
}

Expand Down Expand Up @@ -298,12 +298,12 @@ func (a MAssignmentTask) Run(run *wizard.ProverRuntime) {

}

// zAssignmentTask represents a prover task of assignming the columns
// ZAssignmentTask represents a prover task of assignming the columns
// SigmaS and SigmaT for a specific lookup table.
// sigmaAssignment
type zAssignmentTask ZCtx
type ZAssignmentTask ZCtx

func (z zAssignmentTask) run(run *wizard.ProverRuntime) {
func (z ZAssignmentTask) Run(run *wizard.ProverRuntime) {
parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) {
for frag := start; frag < stop; frag++ {

Expand Down
6 changes: 3 additions & 3 deletions prover/protocol/compiler/lookup/z_packing.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ type ZCtx struct {
Name string
}

// check permutation and see how/where compile is called (see how to constracut z there)
// check permutation and see how/where Compile is called (see how to constracut z there)
// when constructing z, check if z is T or S
// and change T -> -M, S -> +Filter
// S or T -> ({S,T} + X)
// compile should be called inside CompileGrandSum
func (z *ZCtx) compile(comp *wizard.CompiledIOP) {
// Compile should be called inside CompileGrandSum
func (z *ZCtx) Compile(comp *wizard.CompiledIOP) {

var (
numZs = utils.DivCeil(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type DistributionInputs struct {
}

// GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query.
// It insert a new LogDerivativeSum for the extracted share.
// It inserts a new LogDerivativeSum for the extracted share.
func GetShareOfLogDerivativeSum(in DistributionInputs) {
var (
initialComp = in.InitialComp
Expand Down
2 changes: 1 addition & 1 deletion prover/protocol/distributed/preparation.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) {
checkTable = mainLookupCtx.CheckedTables[lookupTableName]
round = mainLookupCtx.Rounds[lookupTableName]
includedFilters = mainLookupCtx.IncludedFilters[lookupTableName]
// collapse multiColumns to single Columns
// collapse multiColumns to single Columns and commit to M.
tableCtx = lookup.CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters)
)

Expand Down
6 changes: 3 additions & 3 deletions prover/protocol/query/gnark_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ func (p LocalOpeningParams) GnarkAssign() GnarkLocalOpeningParams {
}

type GnarkLogDerivSumParams struct {
Y frontend.Variable
Sum frontend.Variable
}

func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams {
return GnarkLogDerivSumParams{Y: p.Sum}
return GnarkLogDerivSumParams{Sum: p.Sum}
}

// A gnark circuit version of InnerProductParams
Expand Down Expand Up @@ -64,7 +64,7 @@ func (p GnarkLocalOpeningParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) {

// Update the fiat-shamir state with the the present parameters
func (p GnarkLogDerivSumParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) {
fs.Update(p.Y)
fs.Update(p.Sum)
}

// Update the fiat-shamir state with the the present parameters
Expand Down
9 changes: 1 addition & 8 deletions prover/protocol/query/logderiv_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@ type LogDerivativeSumInput struct {
// D_{i,j} is the i-th element of the underlying column of j-th Denominator
type LogDerivativeSum struct {
Inputs map[[2]int]*LogDerivativeSumInput

ZNumeratorBoarded, ZDenominatorBoarded map[[2]int][]sym.ExpressionBoard

Zs map[[2]int][]ifaces.Column
// ZOpenings are the opening queries to the end of each Z.
ZOpenings map[[2]int][]LocalOpening

ID ifaces.QueryID
ID ifaces.QueryID
}

// the result of the global Sum
Expand Down
6 changes: 3 additions & 3 deletions prover/protocol/wizard/gnark_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,9 @@ func (c *WizardVerifierCircuit) GetLocalPointEvalParams(name ifaces.QueryID) que
// GetLogDerivSumParams returns the parameters for the requested
// [query.LogDerivativeSum] query. Its work mirrors the function
// [VerifierRuntime.GetLogDerivSumParams]
func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLocalOpeningParams {
qID := c.localOpeningIDs.MustGet(name)
return c.LocalOpeningParams[qID]
func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams {
qID := c.logDerivSumIDs.MustGet(name)
return c.LogDerivSumParams[qID]
}

// GetColumns returns the gnark assignment of a column in a gnark circuit. It
Expand Down
Loading